diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2711796..1fe32cd 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -7,13 +7,14 @@ To contribute, please follow these steps: 1. Create an issue explaining what you'd like to fix or add. This way, we can approve and discuss the solution before any time is spent on developing it. 2. Fork the upstream repository into a personal account. -3. Install [poetry](https://python-poetry.org/), and install all dependencies using ``poetry install`` +3. Install [poetry](https://python-poetry.org/), and install all dependencies using ``poetry install --with dev`` 4. Activate the environment by running ``poetry shell`` 5. Install [pre-commit](https://pre-commit.com/) (for project linting) by running ``pre-commit install`` -6. Create a new branch for your changes, and make sure to add tests! -7. Push the topic branch to your personal fork -8. Run `pre-commit run --all-files` locally to ensure proper linting -9. Create a pull request to the Intility repository with a detailed summary of your changes and what motivated the change +6. Create a new branch for your changes. +7. Create and run tests with full coverage by running `poetry run pytest --cov fastapi_azure_auth --cov-report=term-missing` +8. Push the topic branch to your personal fork. +9. Run `pre-commit run --all-files` locally to ensure proper linting. +10. Create a pull request to the intility repository with a detailed summary of your changes and what motivated the change. If you need a more detailed walk through, please see this [issue comment](https://github.com/Intility/fastapi-azure-auth/issues/49#issuecomment-1056962282). diff --git a/demo_project/api/dependencies.py b/demo_project/api/dependencies.py index 07c2727..700fca9 100644 --- a/demo_project/api/dependencies.py +++ b/demo_project/api/dependencies.py @@ -11,7 +11,7 @@ MultiTenantAzureAuthorizationCodeBearer, SingleTenantAzureAuthorizationCodeBearer, ) -from fastapi_azure_auth.exceptions import InvalidAuthHttp +from fastapi_azure_auth.exceptions import ForbiddenHttp, UnauthorizedHttp from fastapi_azure_auth.user import User log = logging.getLogger(__name__) @@ -30,7 +30,7 @@ async def validate_is_admin_user(user: User = Depends(azure_scheme)) -> None: Raises a 401 authentication error if not. """ if 'AdminUser' not in user.roles: - raise InvalidAuthHttp('User is not an AdminUser') + raise ForbiddenHttp('User is not an AdminUser') class IssuerFetcher: @@ -44,7 +44,7 @@ def __init__(self) -> None: async def __call__(self, tid: str) -> str: """ Check if memory cache needs to be updated or not, and then returns an issuer for a given tenant - :raises InvalidAuth when it's not a valid tenant + :raises Unauthorized when it's not a valid tenant """ refresh_time = datetime.now() - timedelta(hours=1) if not self._config_timestamp or self._config_timestamp < refresh_time: @@ -58,7 +58,7 @@ async def __call__(self, tid: str) -> str: return self.tid_to_iss[tid] except Exception as error: log.exception('`iss` not found for `tid` %s. Error %s', tid, error) - raise InvalidAuthHttp('You must be an Intility customer to access this resource') + raise UnauthorizedHttp('You must be an Intility customer to access this resource') issuer_fetcher = IssuerFetcher() @@ -101,7 +101,7 @@ async def multi_auth( return azure_auth if api_key == 'JonasIsCool': return api_key - raise InvalidAuthHttp('You must either provide a valid bearer token or API key') + raise UnauthorizedHttp('You must either provide a valid bearer token or API key') async def multi_auth_b2c( @@ -115,4 +115,4 @@ async def multi_auth_b2c( return azure_auth if api_key == 'JonasIsCool': return api_key - raise InvalidAuthHttp('You must either provide a valid bearer token or API key') + raise UnauthorizedHttp('You must either provide a valid bearer token or API key') diff --git a/demo_project/core/config.py b/demo_project/core/config.py index e813b79..fe64b7f 100644 --- a/demo_project/core/config.py +++ b/demo_project/core/config.py @@ -13,9 +13,9 @@ class AzureActiveDirectory(BaseSettings): # type: ignore[misc, valid-type] OPENAPI_CLIENT_ID: str = Field(default='') TENANT_ID: str = Field(default='') APP_CLIENT_ID: str = Field(default='') - AUTH_URL: AnyHttpUrl = Field(default='https://dummy.com/') - CONFIG_URL: AnyHttpUrl = Field(default='https://dummy.com/') - TOKEN_URL: AnyHttpUrl = Field(default='https://dummy.com/') + AUTH_URL: AnyHttpUrl = Field(default=AnyHttpUrl('https://dummy.com/')) + CONFIG_URL: AnyHttpUrl = Field(default=AnyHttpUrl('https://dummy.com/')) + TOKEN_URL: AnyHttpUrl = Field(default=AnyHttpUrl('https://dummy.com/')) GRAPH_SECRET: str = Field(default='') CLIENT_SECRET: str = Field(default='') diff --git a/docs/docs/multi-tenant/accept_specific_tenants_only.mdx b/docs/docs/multi-tenant/accept_specific_tenants_only.mdx index 40a2c97..d9d2d1d 100644 --- a/docs/docs/multi-tenant/accept_specific_tenants_only.mdx +++ b/docs/docs/multi-tenant/accept_specific_tenants_only.mdx @@ -15,7 +15,7 @@ from fastapi.middleware.cors import CORSMiddleware from pydantic import AnyHttpUrl from pydantic_settings import BaseSettings, SettingsConfigDict from fastapi_azure_auth import MultiTenantAzureAuthorizationCodeBearer -from fastapi_azure_auth.exceptions import InvalidAuth +from fastapi_azure_auth.exceptions import Unauthorized class Settings(BaseSettings): @@ -56,7 +56,7 @@ async def check_if_valid_tenant(tid: str) -> str: try: return tid_to_iss_mapping[tid] except KeyError: - raise InvalidAuth('Tenant not allowed') + raise Unauthorized('Tenant not allowed') azure_scheme = MultiTenantAzureAuthorizationCodeBearer( app_client_id=settings.APP_CLIENT_ID, @@ -86,7 +86,7 @@ if __name__ == '__main__': ``` We're first creating an `async function`, which takes a `tid` as an argument, and returns the tenant ID's `iss` if it's a valid tenant. -If it's not a valid tenant, it has to raise an `InvalidAuth()` exception. +If it's not a valid tenant, it has to raise an `Unauthorized()` exception. ## More sophisticated callable If you want to cache these results in memory, you can do so by creating a more sophisticated callable: @@ -103,7 +103,7 @@ class IssuerFetcher: async def __call__(self, tid: str) -> str: """ Check if memory cache needs to be updated or not, and then returns an issuer for a given tenant - :raises InvalidAuth when it's not a valid tenant + :raises Unauthorized when it's not a valid tenant """ refresh_time = datetime.now() - timedelta(hours=1) if not self._config_timestamp or self._config_timestamp < refresh_time: @@ -117,7 +117,7 @@ class IssuerFetcher: return self.tid_to_iss[tid] except Exception as error: log.exception('`iss` not found for `tid` %s. Error %s', tid, error) - raise InvalidAuth('You must be an Intility customer to access this resource') + raise Unauthorized('You must be an Intility customer to access this resource') issuer_fetcher = IssuerFetcher() diff --git a/docs/docs/usage-and-faq/guest_users.mdx b/docs/docs/usage-and-faq/guest_users.mdx index dbbfc9f..2a93800 100644 --- a/docs/docs/usage-and-faq/guest_users.mdx +++ b/docs/docs/usage-and-faq/guest_users.mdx @@ -39,7 +39,7 @@ would like to lock down specific endpoints. ```python title="security.py" from fastapi import Depends -from fastapi_azure_auth.exceptions import InvalidAuth +from fastapi_azure_auth.exceptions import Unauthorized from fastapi_azure_auth.user import User async def deny_guest_users(user: User = Depends(azure_scheme)) -> None: @@ -47,7 +47,7 @@ async def deny_guest_users(user: User = Depends(azure_scheme)) -> None: Deny guest users """ if user.is_guest: - raise InvalidAuth('Guest user not allowed') + raise Unauthorized('Guest user not allowed') ``` @@ -57,7 +57,7 @@ Alternatively, after [FastAPI 0.95.0](https://github.com/tiangolo/fastapi/releas ```python title="security.py" from typing import Annotated from fastapi import Depends -from fastapi_azure_auth.exceptions import InvalidAuth +from fastapi_azure_auth.exceptions import Unauthorized from fastapi_azure_auth.user import User async def deny_guest_users(user: User = Depends(azure_scheme)) -> None: @@ -65,7 +65,7 @@ async def deny_guest_users(user: User = Depends(azure_scheme)) -> None: Deny guest users """ if user.is_guest: - raise InvalidAuth('Guest user not allowed') + raise Unauthorized('Guest user not allowed') NonGuestUser = Annotated[User, Depends(deny_guest_users)] ``` diff --git a/docs/docs/usage-and-faq/locking_down_on_roles.mdx b/docs/docs/usage-and-faq/locking_down_on_roles.mdx index e19a7ec..a783d48 100644 --- a/docs/docs/usage-and-faq/locking_down_on_roles.mdx +++ b/docs/docs/usage-and-faq/locking_down_on_roles.mdx @@ -30,7 +30,7 @@ You can lock down on roles by creating your own wrapper dependency: ```python title="dependencies.py" from fastapi import Depends -from fastapi_azure_auth.exceptions import InvalidAuth +from fastapi_azure_auth.exceptions import Unauthorized from fastapi_azure_auth.user import User async def validate_is_admin_user(user: User = Depends(azure_scheme)) -> None: @@ -39,7 +39,7 @@ async def validate_is_admin_user(user: User = Depends(azure_scheme)) -> None: Raises a 401 authentication error if not. """ if 'AdminUser' not in user.roles: - raise InvalidAuth('User is not an AdminUser') + raise Unauthorized('User is not an AdminUser') ``` and then use this dependency over `azure_scheme`. @@ -51,7 +51,7 @@ Alternatively, after [FastAPI 0.95.0](https://github.com/tiangolo/fastapi/releas ```python title="security.py" from typing import Annotated from fastapi import Depends -from fastapi_azure_auth.exceptions import InvalidAuth +from fastapi_azure_auth.exceptions import Unauthorized from fastapi_azure_auth.user import User async def validate_is_admin_user(user: User = Depends(azure_scheme)) -> None: @@ -60,7 +60,7 @@ async def validate_is_admin_user(user: User = Depends(azure_scheme)) -> None: Raises a 401 authentication error if not. """ if 'AdminUser' not in user.roles: - raise InvalidAuth('User is not an AdminUser') + raise Unauthorized('User is not an AdminUser') AdminUser = Annotated[User, Depends(validate_is_admin_user)] ``` diff --git a/fastapi_azure_auth/auth.py b/fastapi_azure_auth/auth.py index 38a6e3f..bf82873 100644 --- a/fastapi_azure_auth/auth.py +++ b/fastapi_azure_auth/auth.py @@ -17,7 +17,18 @@ ) from starlette.requests import HTTPConnection -from fastapi_azure_auth.exceptions import InvalidAuth, InvalidAuthHttp, InvalidAuthWebSocket +from fastapi_azure_auth.exceptions import ( + Forbidden, + ForbiddenHttp, + ForbiddenWebSocket, + InvalidAuthHttp, + InvalidAuthWebSocket, + InvalidRequest, + InvalidRequestHttp, + Unauthorized, + UnauthorizedHttp, + UnauthorizedWebSocket, +) from fastapi_azure_auth.openid_config import OpenIdConfig from fastapi_azure_auth.user import User from fastapi_azure_auth.utils import get_unverified_claims, get_unverified_header, is_guest @@ -148,28 +159,28 @@ async def __call__(self, request: HTTPConnection, security_scopes: SecurityScope access_token = await self.extract_access_token(request) try: if access_token is None: - raise InvalidAuth('No access token provided', request=request) + raise InvalidRequest('No access token provided', request=request) # Extract header information of the token. header: dict[str, Any] = get_unverified_header(access_token) claims: dict[str, Any] = get_unverified_claims(access_token) except Exception as error: log.warning('Malformed token received. %s. Error: %s', access_token, error, exc_info=True) - raise InvalidAuth(detail='Invalid token format', request=request) from error + raise Unauthorized(detail='Invalid token format', request=request) from error user_is_guest: bool = is_guest(claims=claims) if not self.allow_guest_users and user_is_guest: log.info('User denied, is a guest user', claims) - raise InvalidAuth(detail='Guest users not allowed', request=request) + raise Forbidden(detail='Guest users not allowed', request=request) for scope in security_scopes.scopes: token_scope_string = claims.get('scp', '') log.debug('Scopes: %s', token_scope_string) if not isinstance(token_scope_string, str): - raise InvalidAuth('Token contains invalid formatted scopes', request=request) + raise Forbidden('Token contains invalid formatted scopes', request=request) token_scopes = token_scope_string.split(' ') if scope not in token_scopes: - raise InvalidAuth('Required scope missing', request=request) + raise Forbidden('Required scope missing', request=request) # Load new config if old await self.openid_config.load_config() @@ -211,27 +222,36 @@ async def __call__(self, request: HTTPConnection, security_scopes: SecurityScope MissingRequiredClaimError, ) as error: log.info('Token contains invalid claims. %s', error) - raise InvalidAuth(detail='Token contains invalid claims', request=request) from error + raise Unauthorized(detail='Token contains invalid claims', request=request) from error except ExpiredSignatureError as error: log.info('Token signature has expired. %s', error) - raise InvalidAuth(detail='Token signature has expired', request=request) from error + raise Unauthorized(detail='Token signature has expired', request=request) from error except InvalidTokenError as error: log.warning('Invalid token. Error: %s', error, exc_info=True) - raise InvalidAuth(detail='Unable to validate token', request=request) from error + raise Unauthorized(detail='Unable to validate token', request=request) from error except Exception as error: # Extra failsafe in case of a bug in a future version of the jwt library log.exception('Unable to process jwt token. Uncaught error: %s', error) - raise InvalidAuth(detail='Unable to process token', request=request) from error + raise Unauthorized(detail='Unable to process token', request=request) from error log.warning('Unable to verify token. No signing keys found') - raise InvalidAuth(detail='Unable to verify token, no signing keys found', request=request) - except (InvalidAuthHttp, InvalidAuthWebSocket, HTTPException): + raise Unauthorized(detail='Unable to verify token, no signing keys found', request=request) + except ( + InvalidAuthHttp, + InvalidAuthWebSocket, + InvalidRequestHttp, + UnauthorizedHttp, + UnauthorizedWebSocket, + ForbiddenHttp, + ForbiddenWebSocket, + HTTPException, + ): if not self.auto_error: return None raise except Exception as error: if not self.auto_error: return None - raise InvalidAuth(detail='Unable to validate token', request=request) from error + raise InvalidRequest(detail='Unable to validate token', request=request) from error async def extract_access_token(self, request: HTTPConnection) -> Optional[str]: """ diff --git a/fastapi_azure_auth/exceptions.py b/fastapi_azure_auth/exceptions.py index 8e778bf..23c87ff 100644 --- a/fastapi_azure_auth/exceptions.py +++ b/fastapi_azure_auth/exceptions.py @@ -4,33 +4,114 @@ from starlette.requests import HTTPConnection -class InvalidAuthHttp(HTTPException): - """ - Exception raised when the user is not authorized over HTTP - """ +class InvalidRequestHttp(HTTPException): + """HTTP exception for malformed/invalid requests""" def __init__(self, detail: str) -> None: super().__init__( - status_code=status.HTTP_401_UNAUTHORIZED, detail=detail, headers={'WWW-Authenticate': 'Bearer'} + status_code=status.HTTP_400_BAD_REQUEST, detail={"error": "invalid_request", "message": detail} ) -class InvalidAuthWebSocket(WebSocketException): - """ - Exception raised when the user is not authorized over WebSockets - """ +class InvalidRequestWebSocket(WebSocketException): + """WebSocket exception for malformed/invalid requests""" + + def __init__(self, detail: str) -> None: + super().__init__( + code=status.WS_1008_POLICY_VIOLATION, reason=str({"error": "invalid_request", "message": detail}) + ) + + +class UnauthorizedHttp(HTTPException): + """HTTP exception for authentication failures""" + + def __init__(self, detail: str) -> None: + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, + detail={"error": "invalid_token", "message": detail}, + headers={"WWW-Authenticate": "Bearer"}, + ) + + +class UnauthorizedWebSocket(WebSocketException): + """WebSocket exception for authentication failures""" + + def __init__(self, detail: str) -> None: + super().__init__( + code=status.WS_1008_POLICY_VIOLATION, reason=str({"error": "invalid_token", "message": detail}) + ) + + +class ForbiddenHttp(HTTPException): + """HTTP exception for insufficient permissions""" def __init__(self, detail: str) -> None: super().__init__( - code=status.WS_1008_POLICY_VIOLATION, - reason=detail, + status_code=status.HTTP_403_FORBIDDEN, + detail={"error": "insufficient_scope", "message": detail}, + headers={"WWW-Authenticate": "Bearer"}, ) -def InvalidAuth(detail: str, request: HTTPConnection) -> InvalidAuthHttp | InvalidAuthWebSocket: +class ForbiddenWebSocket(WebSocketException): + """WebSocket exception for insufficient permissions""" + + def __init__(self, detail: str) -> None: + super().__init__( + code=status.WS_1008_POLICY_VIOLATION, reason=str({"error": "insufficient_scope", "message": detail}) + ) + + +# --- start backwards-compatible code --- +def InvalidAuth(detail: str, request: HTTPConnection) -> UnauthorizedHttp | UnauthorizedWebSocket: """ - Returns the correct exception based on the connection type + Legacy factory function that maps to Unauthorized for backwards compatibility. + Returns the correct exception based on the connection type. + TODO: Remove in v6.0.0 """ if request.scope['type'] == 'http': - return InvalidAuthHttp(detail) - return InvalidAuthWebSocket(detail) + # Convert the legacy format to new format + return UnauthorizedHttp(detail) + return UnauthorizedWebSocket(detail) + + +class InvalidAuthHttp(UnauthorizedHttp): + """Legacy HTTP exception class that maps to UnauthorizedHttp + TODO: Remove in v6.0.0 + """ + + def __init__(self, detail: str) -> None: + super().__init__(detail) + + +class InvalidAuthWebSocket(UnauthorizedWebSocket): + """Legacy WebSocket exception class that maps to UnauthorizedWebSocket + TODO: Remove in v6.0.0 + """ + + def __init__(self, detail: str) -> None: + super().__init__(detail) + + +# --- end backwards-compatible code --- + + +def InvalidRequest(detail: str, request: HTTPConnection) -> InvalidRequestHttp | InvalidRequestWebSocket: + """Factory function for invalid request exceptions (HTTP only, as request validation happens pre-connection)""" + if request.scope['type'] == 'http': + return InvalidRequestHttp(detail) + return InvalidRequestWebSocket(detail) + + +def Unauthorized(detail: str, request: HTTPConnection) -> UnauthorizedHttp | UnauthorizedWebSocket: + """Factory function for unauthorized exceptions""" + if request.scope["type"] == "http": + return UnauthorizedHttp(detail) + return UnauthorizedWebSocket(detail) + + +def Forbidden(detail: str, request: HTTPConnection) -> ForbiddenHttp | ForbiddenWebSocket: + """Factory function for forbidden exceptions""" + if request.scope["type"] == "http": + return ForbiddenHttp(detail) + return ForbiddenWebSocket(detail) diff --git a/tests/multi_tenant/multi_auth/test_auto_error.py b/tests/multi_tenant/multi_auth/test_auto_error.py index 5c404c3..0c4e95d 100644 --- a/tests/multi_tenant/multi_auth/test_auto_error.py +++ b/tests/multi_tenant/multi_auth/test_auto_error.py @@ -12,6 +12,7 @@ async def test_normal_azure_user_valid_token(multi_tenant_app, mock_openid_and_k ) as ac: response = await ac.get('api/v1/hello-multi-auth') assert response.json() == {'api_key': False, 'azure_auth': True} + assert response.status_code == 200 @pytest.mark.anyio @@ -21,6 +22,7 @@ async def test_api_key_valid_key(multi_tenant_app, mock_openid_and_keys): ) as ac: response = await ac.get('api/v1/hello-multi-auth') assert response.json() == {'api_key': True, 'azure_auth': False} + assert response.status_code == 200 @pytest.mark.anyio @@ -30,7 +32,10 @@ async def test_normal_azure_user_but_invalid_token(multi_tenant_app, mock_openid transport=ASGITransport(app=app), base_url='http://test', headers={'Authorization': 'Bearer ' + access_token} ) as ac: response = await ac.get('api/v1/hello-multi-auth') - assert response.json() == {'detail': 'You must either provide a valid bearer token or API key'} + assert response.json() == { + 'detail': {'error': 'invalid_token', 'message': 'You must either provide a valid bearer token or API key'} + } + assert response.status_code == 401 @pytest.mark.anyio @@ -39,4 +44,7 @@ async def test_api_key_but_invalid_key(multi_tenant_app, mock_openid_and_keys): transport=ASGITransport(app=app), base_url='http://test', headers={'TEST-API-KEY': 'JonasIsNotCool'} ) as ac: response = await ac.get('api/v1/hello-multi-auth') - assert response.json() == {'detail': 'You must either provide a valid bearer token or API key'} + assert response.json() == { + 'detail': {'error': 'invalid_token', 'message': 'You must either provide a valid bearer token or API key'} + } + assert response.status_code == 401 diff --git a/tests/multi_tenant/test_multi_tenant.py b/tests/multi_tenant/test_multi_tenant.py index 9c9c3fa..d648921 100644 --- a/tests/multi_tenant/test_multi_tenant.py +++ b/tests/multi_tenant/test_multi_tenant.py @@ -19,7 +19,7 @@ from fastapi_azure_auth import MultiTenantAzureAuthorizationCodeBearer from fastapi_azure_auth.auth import AzureAuthorizationCodeBearerBase -from fastapi_azure_auth.exceptions import InvalidAuthHttp +from fastapi_azure_auth.exceptions import UnauthorizedHttp @pytest.mark.anyio @@ -124,13 +124,16 @@ async def test_no_keys_to_decode_with(multi_tenant_app, mock_openid_and_empty_ke app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token()} ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Unable to verify token, no signing keys found'} + assert response.json() == { + 'detail': {'error': 'invalid_token', 'message': 'Unable to verify token, no signing keys found'} + } + assert response.status_code == 401 @pytest.mark.anyio async def test_iss_callable_raise_error(mock_openid_and_keys): async def issuer_fetcher(tid): - raise InvalidAuthHttp(f'Tenant {tid} not a valid tenant') + raise UnauthorizedHttp(f'Tenant {tid} not a valid tenant') azure_scheme_overrides = generate_azure_scheme_multi_tenant_object(issuer_fetcher) @@ -139,7 +142,10 @@ async def issuer_fetcher(tid): app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token()} ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Tenant intility_tenant_id not a valid tenant'} + assert response.json() == { + 'detail': {'error': 'invalid_token', 'message': 'Tenant intility_tenant_id not a valid tenant'} + } + assert response.status_code == 401 @pytest.mark.anyio @@ -167,7 +173,8 @@ async def test_normal_user_rejected(multi_tenant_app, mock_openid_and_keys): headers={'Authorization': 'Bearer ' + build_access_token_normal_user()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'User is not an AdminUser'} + assert response.json() == {'detail': {'error': 'insufficient_scope', 'message': 'User is not an AdminUser'}} + assert response.status_code == 403 @pytest.mark.anyio @@ -178,7 +185,8 @@ async def test_guest_user_rejected(multi_tenant_app, mock_openid_and_keys): headers={'Authorization': 'Bearer ' + build_access_token_guest_user()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Guest users not allowed'} + assert response.json() == {'detail': {'error': 'insufficient_scope', 'message': 'Guest users not allowed'}} + assert response.status_code == 403 @pytest.mark.anyio @@ -189,7 +197,8 @@ async def test_invalid_token_claims(multi_tenant_app, mock_openid_and_keys): headers={'Authorization': 'Bearer ' + build_access_token_invalid_claims()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Token contains invalid claims'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Token contains invalid claims'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -200,7 +209,10 @@ async def test_no_valid_keys_for_token(multi_tenant_app, mock_openid_and_no_vali headers={'Authorization': 'Bearer ' + build_access_token_invalid_claims()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Unable to verify token, no signing keys found'} + assert response.json() == { + 'detail': {'error': 'invalid_token', 'message': 'Unable to verify token, no signing keys found'} + } + assert response.status_code == 401 @pytest.mark.anyio @@ -211,7 +223,8 @@ async def test_no_valid_scopes(multi_tenant_app, mock_openid_and_no_valid_keys): headers={'Authorization': 'Bearer ' + build_access_token_invalid_scopes()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Required scope missing'} + assert response.json() == {'detail': {'error': 'insufficient_scope', 'message': 'Required scope missing'}} + assert response.status_code == 403 @pytest.mark.anyio @@ -222,7 +235,10 @@ async def test_no_valid_invalid_formatted_scope(multi_tenant_app, mock_openid_an headers={'Authorization': 'Bearer ' + build_access_token_invalid_scopes(scopes=None)}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Token contains invalid formatted scopes'} + assert response.json() == { + 'detail': {'error': 'insufficient_scope', 'message': 'Token contains invalid formatted scopes'} + } + assert response.status_code == 403 @pytest.mark.anyio @@ -233,7 +249,8 @@ async def test_expired_token(multi_tenant_app, mock_openid_and_keys): headers={'Authorization': 'Bearer ' + build_access_token_expired()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Token signature has expired'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Token signature has expired'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -245,7 +262,12 @@ async def test_evil_token(multi_tenant_app, mock_openid_and_keys): headers={'Authorization': 'Bearer ' + build_evil_access_token()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Unable to validate token'} + assert ( + response.json() + == {'detail': {'error': 'invalid_token', 'message': 'Unable to validate token'}} + != {'detail': 'Unable to validate token'} + ) + assert response.status_code == 401 @pytest.mark.anyio @@ -255,7 +277,8 @@ async def test_malformed_token(multi_tenant_app, mock_openid_and_keys): app=app, base_url='http://test', headers={'Authorization': 'Bearer eyJhbGciOiJSUzI1NiIsInR5cI6IkpXVCJ9'} ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Invalid token format'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Invalid token format'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -270,7 +293,8 @@ async def test_only_header(multi_tenant_app, mock_openid_and_keys): }, # {'kid': 'real thumbprint', 'x5t': 'another thumbprint'} ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Invalid token format'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Invalid token format'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -282,7 +306,8 @@ async def test_exception_raised(multi_tenant_app, mock_openid_and_keys, mocker): headers={'Authorization': 'Bearer ' + build_access_token_expired()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Unable to process token'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Unable to process token'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -306,4 +331,7 @@ async def test_change_of_keys_works(multi_tenant_app, mock_openid_ok_then_empty, app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token()} ) as ac: second_resonse = await ac.get('api/v1/hello') - assert second_resonse.json() == {'detail': 'Unable to verify token, no signing keys found'} + assert second_resonse.json() == { + 'detail': {'error': 'invalid_token', 'message': 'Unable to verify token, no signing keys found'} + } + assert second_resonse.status_code == 401 diff --git a/tests/multi_tenant/test_websocket.py b/tests/multi_tenant/test_websocket.py index 09d4535..ed54e4c 100644 --- a/tests/multi_tenant/test_websocket.py +++ b/tests/multi_tenant/test_websocket.py @@ -1,3 +1,4 @@ +import json from typing import Annotated import pytest @@ -20,7 +21,7 @@ from fastapi_azure_auth import MultiTenantAzureAuthorizationCodeBearer from fastapi_azure_auth.auth import AzureAuthorizationCodeBearerBase -from fastapi_azure_auth.exceptions import InvalidAuthWebSocket +from fastapi_azure_auth.exceptions import ForbiddenWebSocket, UnauthorizedWebSocket from fastapi_azure_auth.openid_config import OpenIdConfig from fastapi_azure_auth.user import User @@ -31,7 +32,7 @@ async def validate_is_admin_user(user: User = Depends(azure_scheme)) -> User: Raises a 401 authentication error if not. """ if 'AdminUser' not in user.roles: - raise InvalidAuthWebSocket('User is not an AdminUser') + raise ForbiddenWebSocket('User is not an AdminUser') return user @@ -73,13 +74,16 @@ async def test_no_keys_to_decode_with(multi_tenant_app, mock_openid_and_empty_ke with pytest.raises(WebSocketDisconnect) as error: with client.websocket_connect("/ws", headers={'Authorization': 'Bearer ' + build_access_token()}): pass - assert error.value.reason == 'Unable to verify token, no signing keys found' + assert error.value.reason == str( + {'error': 'invalid_token', 'message': 'Unable to verify token, no signing keys found'} + ) + assert error.value.code == 1008 @pytest.mark.anyio async def test_iss_callable_raise_error(mock_openid_and_keys): async def issuer_fetcher(tid): - raise InvalidAuthWebSocket(f'Tenant {tid} not a valid tenant') + raise UnauthorizedWebSocket(f'Tenant {tid} not a valid tenant') azure_scheme_overrides = generate_azure_scheme_multi_tenant_object(issuer_fetcher) @@ -87,7 +91,10 @@ async def issuer_fetcher(tid): with pytest.raises(WebSocketDisconnect) as error: with client.websocket_connect("/ws", headers={'Authorization': 'Bearer ' + build_access_token()}): pass - assert error.value.reason == 'Tenant intility_tenant_id not a valid tenant' + assert error.value.reason == str( + {'error': 'invalid_token', 'message': 'Tenant intility_tenant_id not a valid tenant'} + ) + assert error.value.code == 1008 @pytest.mark.anyio @@ -112,7 +119,8 @@ async def test_normal_user_rejected(multi_tenant_app, mock_openid_and_keys): "/ws/admin", headers={'Authorization': 'Bearer ' + build_access_token_normal_user()} ): pass - assert error.value.reason == 'User is not an AdminUser' + assert error.value.reason == str({'error': 'insufficient_scope', 'message': 'User is not an AdminUser'}) + assert error.value.code == 1008 @pytest.mark.anyio @@ -120,7 +128,8 @@ async def test_guest_user_rejected(multi_tenant_app, mock_openid_and_keys): with pytest.raises(WebSocketDisconnect) as error: with client.websocket_connect("/ws", headers={'Authorization': 'Bearer ' + build_access_token_guest_user()}): pass - assert error.value.reason == 'Guest users not allowed' + assert error.value.reason == str({'error': 'insufficient_scope', 'message': 'Guest users not allowed'}) + assert error.value.code == 1008 @pytest.mark.anyio @@ -130,7 +139,8 @@ async def test_invalid_token_claims(multi_tenant_app, mock_openid_and_keys): "/ws", headers={'Authorization': 'Bearer ' + build_access_token_invalid_claims()} ): pass - assert error.value.reason == 'Token contains invalid claims' + assert error.value.reason == str({'error': 'invalid_token', 'message': 'Token contains invalid claims'}) + assert error.value.code == 1008 @pytest.mark.anyio @@ -140,7 +150,10 @@ async def test_no_valid_keys_for_token(multi_tenant_app, mock_openid_and_no_vali "/ws", headers={'Authorization': 'Bearer ' + build_access_token_invalid_claims()} ): pass - assert error.value.reason == 'Unable to verify token, no signing keys found' + assert error.value.reason == str( + {'error': 'invalid_token', 'message': 'Unable to verify token, no signing keys found'} + ) + assert error.value.code == 1008 @pytest.mark.anyio @@ -150,7 +163,8 @@ async def test_no_valid_scopes(multi_tenant_app, mock_openid_and_no_valid_keys): "/ws/scope", headers={'Authorization': 'Bearer ' + build_access_token_invalid_scopes()} ): pass - assert error.value.reason == 'Required scope missing' + assert error.value.reason == str({'error': 'insufficient_scope', 'message': 'Required scope missing'}) + assert error.value.code == 1008 @pytest.mark.anyio @@ -160,7 +174,10 @@ async def test_no_valid_invalid_formatted_scope(multi_tenant_app, mock_openid_an "/ws/scope", headers={'Authorization': 'Bearer ' + build_access_token_invalid_scopes(scopes=None)} ): pass - assert error.value.reason == 'Token contains invalid formatted scopes' + assert error.value.reason == str( + {'error': 'insufficient_scope', 'message': 'Token contains invalid formatted scopes'} + ) + assert error.value.code == 1008 @pytest.mark.anyio @@ -168,7 +185,8 @@ async def test_expired_token(multi_tenant_app, mock_openid_and_keys): with pytest.raises(WebSocketDisconnect) as error: with client.websocket_connect("/ws", headers={'Authorization': 'Bearer ' + build_access_token_expired()}): pass - assert error.value.reason == 'Token signature has expired' + assert error.value.reason == str({'error': 'invalid_token', 'message': 'Token signature has expired'}) + assert error.value.code == 1008 @pytest.mark.anyio @@ -177,7 +195,8 @@ async def test_evil_token(multi_tenant_app, mock_openid_and_keys): with pytest.raises(WebSocketDisconnect) as error: with client.websocket_connect("/ws", headers={'Authorization': 'Bearer ' + build_evil_access_token()}): pass - assert error.value.reason == 'Unable to validate token' + assert error.value.reason == str({'error': 'invalid_token', 'message': 'Unable to validate token'}) + assert error.value.code == 1008 @pytest.mark.anyio @@ -186,7 +205,8 @@ async def test_malformed_token(multi_tenant_app, mock_openid_and_keys): with pytest.raises(WebSocketDisconnect) as error: with client.websocket_connect("/ws", headers={'Authorization': 'Bearer eyJhbGciOiJSUzI1NiIsInR5cI6IkpXVCJ9'}): pass - assert error.value.reason == 'Invalid token format' + assert error.value.reason == str({'error': 'invalid_token', 'message': 'Invalid token format'}) + assert error.value.code == 1008 @pytest.mark.anyio @@ -201,17 +221,30 @@ async def test_only_header(multi_tenant_app, mock_openid_and_keys): }, # {'kid': 'real thumbprint', 'x5t': 'another thumbprint'} ): pass - assert error.value.reason == 'Invalid token format' + assert error.value.reason == str({'error': 'invalid_token', 'message': 'Invalid token format'}) + assert error.value.code == 1008 + + +@pytest.mark.anyio +async def test_exception_raised_extraction(multi_tenant_app, mock_openid_and_keys, mocker): + mocker.patch.object(AzureAuthorizationCodeBearerBase, 'extract_access_token', side_effect=ValueError('oops')) + + with pytest.raises(WebSocketDisconnect) as error: + with client.websocket_connect("/ws", headers={'Authorization': 'Bearer ' + build_access_token()}): + pass + assert error.value.reason == str({'error': 'invalid_request', 'message': 'Unable to validate token'}) + assert error.value.code == 1008 @pytest.mark.anyio -async def test_exception_raised(multi_tenant_app, mock_openid_and_keys, mocker): +async def test_exception_raised_validation(multi_tenant_app, mock_openid_and_keys, mocker): mocker.patch.object(AzureAuthorizationCodeBearerBase, 'validate', side_effect=ValueError('lol')) with pytest.raises(WebSocketDisconnect) as error: with client.websocket_connect("/ws", headers={'Authorization': 'Bearer ' + build_access_token()}): pass - assert error.value.reason == 'Unable to process token' + assert error.value.reason == str({'error': 'invalid_token', 'message': 'Unable to process token'}) + assert error.value.code == 1008 @pytest.mark.anyio @@ -221,7 +254,8 @@ async def test_exception_raised_unknown(multi_tenant_app, mock_openid_and_keys, with pytest.raises(WebSocketDisconnect) as error: with client.websocket_connect("/ws", headers={'Authorization': 'Bearer ' + build_access_token()}): pass - assert error.value.reason == 'Unable to validate token' + assert error.value.reason == str({'error': 'invalid_request', 'message': 'Unable to validate token'}) + assert error.value.code == 1008 @pytest.mark.anyio diff --git a/tests/multi_tenant_b2c/multi_auth/test_auto_error.py b/tests/multi_tenant_b2c/multi_auth/test_auto_error.py index ab26c0c..7ca19df 100644 --- a/tests/multi_tenant_b2c/multi_auth/test_auto_error.py +++ b/tests/multi_tenant_b2c/multi_auth/test_auto_error.py @@ -10,6 +10,7 @@ async def test_api_key_valid_key(multi_tenant_app, mock_openid_and_keys, freezer ) as ac: response = await ac.get('api/v1/hello-multi-auth-b2c') assert response.json() == {'api_key': True, 'azure_auth': False} + assert response.status_code == 200 @pytest.mark.anyio @@ -18,4 +19,7 @@ async def test_api_key_but_invalid_key(multi_tenant_app, mock_openid_and_keys, f transport=ASGITransport(app=app), base_url='http://test', headers={'TEST-API-KEY': 'JonasIsNotCool'} ) as ac: response = await ac.get('api/v1/hello-multi-auth-b2c') - assert response.json() == {'detail': 'You must either provide a valid bearer token or API key'} + assert response.json() == { + 'detail': {'error': 'invalid_token', 'message': 'You must either provide a valid bearer token or API key'} + } + assert response.status_code == 401 diff --git a/tests/multi_tenant_b2c/test_multi_tenant.py b/tests/multi_tenant_b2c/test_multi_tenant.py index 6329a30..c2864a7 100644 --- a/tests/multi_tenant_b2c/test_multi_tenant.py +++ b/tests/multi_tenant_b2c/test_multi_tenant.py @@ -120,7 +120,10 @@ async def test_no_keys_to_decode_with(multi_tenant_app, mock_openid_and_empty_ke app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token()} ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Unable to verify token, no signing keys found'} + assert response.json() == { + 'detail': {'error': 'invalid_token', 'message': 'Unable to verify token, no signing keys found'} + } + assert response.status_code == 401 @pytest.mark.anyio @@ -131,7 +134,8 @@ async def test_normal_user_rejected(multi_tenant_app, mock_openid_and_keys): headers={'Authorization': 'Bearer ' + build_access_token_normal_user()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'User is not an AdminUser'} + assert response.json() == {'detail': {'error': 'insufficient_scope', 'message': 'User is not an AdminUser'}} + assert response.status_code == 403 @pytest.mark.anyio @@ -156,7 +160,8 @@ async def test_invalid_token_claims(multi_tenant_app, mock_openid_and_keys): headers={'Authorization': 'Bearer ' + build_access_token_invalid_claims()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Token contains invalid claims'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Token contains invalid claims'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -167,7 +172,10 @@ async def test_no_valid_keys_for_token(multi_tenant_app, mock_openid_and_no_vali headers={'Authorization': 'Bearer ' + build_access_token_invalid_claims()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Unable to verify token, no signing keys found'} + assert response.json() == { + 'detail': {'error': 'invalid_token', 'message': 'Unable to verify token, no signing keys found'} + } + assert response.status_code == 401 @pytest.mark.anyio @@ -178,7 +186,8 @@ async def test_no_valid_scopes(multi_tenant_app, mock_openid_and_no_valid_keys): headers={'Authorization': 'Bearer ' + build_access_token_invalid_scopes()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Required scope missing'} + assert response.json() == {'detail': {'error': 'insufficient_scope', 'message': 'Required scope missing'}} + assert response.status_code == 403 @pytest.mark.anyio @@ -189,7 +198,8 @@ async def test_no_valid_invalid_scope(multi_tenant_app, mock_openid_and_no_valid headers={'Authorization': 'Bearer ' + build_access_token_invalid_scopes()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Required scope missing'} + assert response.json() == {'detail': {'error': 'insufficient_scope', 'message': 'Required scope missing'}} + assert response.status_code == 403 @pytest.mark.anyio @@ -200,7 +210,10 @@ async def test_no_valid_invalid_formatted_scope(multi_tenant_app, mock_openid_an headers={'Authorization': 'Bearer ' + build_access_token_invalid_scopes(scopes=None)}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Token contains invalid formatted scopes'} + assert response.json() == { + 'detail': {'error': 'insufficient_scope', 'message': 'Token contains invalid formatted scopes'} + } + assert response.status_code == 403 @pytest.mark.anyio @@ -211,7 +224,8 @@ async def test_expired_token(multi_tenant_app, mock_openid_and_keys): headers={'Authorization': 'Bearer ' + build_access_token_expired()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Token signature has expired'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Token signature has expired'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -223,7 +237,8 @@ async def test_evil_token(multi_tenant_app, mock_openid_and_keys): headers={'Authorization': 'Bearer ' + build_evil_access_token()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Unable to validate token'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Unable to validate token'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -233,7 +248,8 @@ async def test_malformed_token(multi_tenant_app, mock_openid_and_keys): app=app, base_url='http://test', headers={'Authorization': 'Bearer eyJhbGciOiJSUzI1NiIsInR5cI6IkpXVCJ9'} ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Invalid token format'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Invalid token format'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -248,7 +264,8 @@ async def test_only_header(multi_tenant_app, mock_openid_and_keys): }, # {'kid': 'real thumbprint', 'x5t': 'another thumbprint'} ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Invalid token format'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Invalid token format'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -261,7 +278,8 @@ async def test_exception_raised(multi_tenant_app, mock_openid_and_keys, mocker): headers={'Authorization': 'Bearer ' + build_access_token_expired()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Unable to process token'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Unable to process token'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -285,4 +303,7 @@ async def test_change_of_keys_works(multi_tenant_app, mock_openid_ok_then_empty, app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token()} ) as ac: second_resonse = await ac.get('api/v1/hello') - assert second_resonse.json() == {'detail': 'Unable to verify token, no signing keys found'} + assert second_resonse.json() == { + 'detail': {'error': 'invalid_token', 'message': 'Unable to verify token, no signing keys found'} + } + assert second_resonse.status_code == 401 diff --git a/tests/single_tenant/test_single_tenant.py b/tests/single_tenant/test_single_tenant.py index 3f8399d..6798d65 100644 --- a/tests/single_tenant/test_single_tenant.py +++ b/tests/single_tenant/test_single_tenant.py @@ -119,7 +119,10 @@ async def test_no_keys_to_decode_with(single_tenant_app, mock_openid_and_empty_k app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token()} ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Unable to verify token, no signing keys found'} + assert response.json() == { + 'detail': {'error': 'invalid_token', 'message': 'Unable to verify token, no signing keys found'} + } + assert response.status_code == 401 @pytest.mark.anyio @@ -130,7 +133,8 @@ async def test_normal_user_rejected(single_tenant_app, mock_openid_and_keys): headers={'Authorization': 'Bearer ' + build_access_token_normal_user()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'User is not an AdminUser'} + assert response.json() == {'detail': {'error': 'insufficient_scope', 'message': 'User is not an AdminUser'}} + assert response.status_code == 403 @pytest.mark.anyio @@ -141,7 +145,8 @@ async def test_guest_user_rejected(single_tenant_app, mock_openid_and_keys): headers={'Authorization': 'Bearer ' + build_access_token_guest_user()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Guest users not allowed'} + assert response.json() == {'detail': {'error': 'insufficient_scope', 'message': 'Guest users not allowed'}} + assert response.status_code == 403 @pytest.mark.anyio @@ -152,7 +157,8 @@ async def test_invalid_token_claims(single_tenant_app, mock_openid_and_keys): headers={'Authorization': 'Bearer ' + build_access_token_invalid_claims()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Token contains invalid claims'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Token contains invalid claims'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -163,7 +169,10 @@ async def test_no_valid_keys_for_token(single_tenant_app, mock_openid_and_no_val headers={'Authorization': 'Bearer ' + build_access_token_invalid_claims()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Unable to verify token, no signing keys found'} + assert response.json() == { + 'detail': {'error': 'invalid_token', 'message': 'Unable to verify token, no signing keys found'} + } + assert response.status_code == 401 @pytest.mark.anyio @@ -174,7 +183,8 @@ async def test_no_valid_scopes(single_tenant_app, mock_openid_and_no_valid_keys) headers={'Authorization': 'Bearer ' + build_access_token_invalid_scopes()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Required scope missing'} + assert response.json() == {'detail': {'error': 'insufficient_scope', 'message': 'Required scope missing'}} + assert response.status_code == 403 @pytest.mark.anyio @@ -185,7 +195,8 @@ async def test_no_valid_invalid_scope(single_tenant_app, mock_openid_and_no_vali headers={'Authorization': 'Bearer ' + build_access_token_invalid_scopes()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Required scope missing'} + assert response.json() == {'detail': {'error': 'insufficient_scope', 'message': 'Required scope missing'}} + assert response.status_code == 403 @pytest.mark.anyio @@ -196,7 +207,10 @@ async def test_no_valid_invalid_formatted_scope(single_tenant_app, mock_openid_a headers={'Authorization': 'Bearer ' + build_access_token_invalid_scopes(scopes=None)}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Token contains invalid formatted scopes'} + assert response.json() == { + 'detail': {'error': 'insufficient_scope', 'message': 'Token contains invalid formatted scopes'} + } + assert response.status_code == 403 @pytest.mark.anyio @@ -207,7 +221,8 @@ async def test_expired_token(single_tenant_app, mock_openid_and_keys): headers={'Authorization': 'Bearer ' + build_access_token_expired()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Token signature has expired'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Token signature has expired'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -219,7 +234,8 @@ async def test_evil_token(single_tenant_app, mock_openid_and_keys): headers={'Authorization': 'Bearer ' + build_evil_access_token()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Unable to validate token'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Unable to validate token'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -229,7 +245,8 @@ async def test_malformed_token(single_tenant_app, mock_openid_and_keys): app=app, base_url='http://test', headers={'Authorization': 'Bearer eyJhbGciOiJSUzI1NiIsInR5cI6IkpXVCJ9'} ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Invalid token format'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Invalid token format'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -244,7 +261,8 @@ async def test_only_header(single_tenant_app, mock_openid_and_keys): }, # {'kid': 'real thumbprint', 'x5t': 'another thumbprint'} ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Invalid token format'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Invalid token format'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -256,7 +274,8 @@ async def test_none_token(single_tenant_app, mock_openid_and_keys, mocker): headers={'Authorization': 'Bearer ' + build_access_token_expired()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Invalid token format'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Invalid token format'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -268,7 +287,8 @@ async def test_exception_raised(single_tenant_app, mock_openid_and_keys, mocker) headers={'Authorization': 'Bearer ' + build_access_token_expired()}, ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Unable to process token'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Unable to process token'}} + assert response.status_code == 401 @pytest.mark.anyio @@ -292,4 +312,7 @@ async def test_change_of_keys_works(single_tenant_app, mock_openid_ok_then_empty app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token()} ) as ac: second_resonse = await ac.get('api/v1/hello') - assert second_resonse.json() == {'detail': 'Unable to verify token, no signing keys found'} + assert second_resonse.json() == { + 'detail': {'error': 'invalid_token', 'message': 'Unable to verify token, no signing keys found'} + } + assert second_resonse.status_code == 401 diff --git a/tests/test_exception_compat.py b/tests/test_exception_compat.py new file mode 100644 index 0000000..c25504a --- /dev/null +++ b/tests/test_exception_compat.py @@ -0,0 +1,72 @@ +""" +This module tests the exception handling and backwards compatibility of the exceptions module, introduced in +issue https://github.com/intility/fastapi-azure-auth/issues/229. +TODO: Remove this test module in v6.0.0 +""" +import pytest +from fastapi import HTTPException, WebSocketException, status + +from fastapi_azure_auth.exceptions import ( + InvalidAuth, + InvalidAuthHttp, + InvalidAuthWebSocket, + UnauthorizedHttp, + UnauthorizedWebSocket, +) + + +def test_invalid_auth_backwards_compatibility(): + """Test that InvalidAuth maps to correct exceptions and maintains format""" + # Mock HTTP request scope + http_conn = type('HTTPConnection', (), {'scope': {'type': 'http'}})() + + # Mock WebSocket scope + ws_conn = type('HTTPConnection', (), {'scope': {'type': 'websocket'}})() + + # Test HTTP path + http_exc = InvalidAuth("test message", http_conn) + assert isinstance(http_exc, UnauthorizedHttp) + assert isinstance(http_exc, HTTPException) + assert http_exc.status_code == status.HTTP_401_UNAUTHORIZED + assert http_exc.detail == {"error": "invalid_token", "message": "test message"} + + # Test WebSocket path + ws_exc = InvalidAuth("test message", ws_conn) + assert isinstance(ws_exc, UnauthorizedWebSocket) + assert isinstance(ws_exc, WebSocketException) + assert ws_exc.code == status.WS_1008_POLICY_VIOLATION + assert ws_exc.reason == str({"error": "invalid_token", "message": "test message"}) + + +def test_legacy_exception_catching(): + """Test that old exception catching patterns still work""" + # Test HTTP exceptions + http_conn = type('HTTPConnection', (), {'scope': {'type': 'http'}})() + + with pytest.raises((InvalidAuthHttp, UnauthorizedHttp)) as exc_info: + raise InvalidAuth("test message", http_conn) + + assert isinstance(exc_info.value, UnauthorizedHttp) + assert exc_info.value.detail == {"error": "invalid_token", "message": "test message"} + + # Test WebSocket exceptions + ws_conn = type('HTTPConnection', (), {'scope': {'type': 'websocket'}})() + + with pytest.raises((InvalidAuthWebSocket, UnauthorizedWebSocket)) as exc_info: + raise InvalidAuth("test message", ws_conn) + + assert isinstance(exc_info.value, UnauthorizedWebSocket) + assert exc_info.value.reason == str({"error": "invalid_token", "message": "test message"}) + + +def test_new_exceptions_can_be_caught_as_legacy(): + """Test that new exceptions can be caught with legacy catch blocks""" + with pytest.raises((InvalidAuthHttp, UnauthorizedHttp)) as exc_info: + raise UnauthorizedHttp("test message") + + assert exc_info.value.detail == {"error": "invalid_token", "message": "test message"} + + with pytest.raises((InvalidAuthWebSocket, UnauthorizedWebSocket)) as exc_info: + raise UnauthorizedWebSocket("test message") + + assert exc_info.value.reason == str({"error": "invalid_token", "message": "test message"}) diff --git a/tests/test_openapi_scheme.py b/tests/test_openapi_scheme.py index 907f9fb..a48a72b 100644 --- a/tests/test_openapi_scheme.py +++ b/tests/test_openapi_scheme.py @@ -476,4 +476,4 @@ def test_incorrect_token(test_client): def test_token(test_client): response = test_client.get('/api/v1/hello', headers={'Authorization': 'Bearer '}) assert response.status_code == 401, response.text - assert response.json() == {'detail': 'Invalid token format'} + assert response.json() == {'detail': {'error': 'invalid_token', 'message': 'Invalid token format'}}