Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate from python-jose to PyJWT #194

Merged
merged 3 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions demo_project/api/api_v1/endpoints/graph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any

import httpx
import jwt
from demo_project.api.dependencies import azure_scheme
from demo_project.core.config import settings
from fastapi import APIRouter, Depends, Request
from httpx import AsyncClient
from jose import jwt

router = APIRouter()

Expand Down Expand Up @@ -47,7 +47,7 @@ async def graph_world(request: Request) -> Any: # noqa: ANN401

# Return all the information to the end user
return (
{'claims': jwt.get_unverified_claims(token=request.state.user.access_token)}
{'claims': jwt.decode(request.state.user.access_token, options={'verify_signature': False})}
| {'obo_response': obo_response.json()}
| {'graph_response': graph}
)
91 changes: 62 additions & 29 deletions fastapi_azure_auth/auth.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
import inspect
import logging
from typing import Any, Awaitable, Callable, Dict, Literal, Optional
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Literal, Optional
from warnings import warn

import jwt
from fastapi.exceptions import HTTPException
from fastapi.security import OAuth2AuthorizationCodeBearer, SecurityScopes
from fastapi.security.base import SecurityBase
from jose import jwt
from jose.exceptions import ExpiredSignatureError, JWTClaimsError, JWTError
from jwt.exceptions import (
ExpiredSignatureError,
ImmatureSignatureError,
InvalidAlgorithmError,
InvalidAudienceError,
InvalidIssuedAtError,
InvalidIssuerError,
InvalidTokenError,
MissingRequiredClaimError,
)
from starlette.requests import Request

from fastapi_azure_auth.exceptions import InvalidAuth
from fastapi_azure_auth.openid_config import OpenIdConfig
from fastapi_azure_auth.user import User
from fastapi_azure_auth.utils import is_guest
from fastapi_azure_auth.utils import get_unverified_claims, get_unverified_header, is_guest

if TYPE_CHECKING: # pragma: no cover
from jwt.algorithms import AllowedPublicKeys
else:
AllowedPublicKeys = Any

log = logging.getLogger('fastapi_azure_auth')

Expand Down Expand Up @@ -145,11 +159,13 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
Extends call to also validate the token.
"""
try:
access_token = await self.oauth(request=request)
access_token = await self.extract_access_token(request)
try:
if access_token is None:
raise Exception('No access token provided')
# Extract header information of the token.
header: dict[str, str] = jwt.get_unverified_header(token=access_token) or {}
claims: dict[str, Any] = jwt.get_unverified_claims(token=access_token) or {}
header: dict[str, Any] = get_unverified_header(access_token)
claims: dict[str, Any] = get_unverified_claims(access_token)
except Exception as error:
log.warning('Malformed token received. %s. Error: %s', access_token, error, exc_info=True)
raise InvalidAuth(detail='Invalid token format') from error
Expand Down Expand Up @@ -180,48 +196,41 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
try:
if key := self.openid_config.signing_keys.get(header.get('kid', '')):
# We require and validate all fields in an Azure AD token
required_claims = ['exp', 'aud', 'iat', 'nbf', 'sub']
if self.validate_iss:
required_claims.append('iss')

options = {
'verify_signature': True,
'verify_aud': True,
'verify_iat': True,
'verify_exp': True,
'verify_nbf': True,
'verify_iss': self.validate_iss,
'verify_sub': True,
'verify_jti': True,
'verify_at_hash': True,
'require_aud': True,
'require_iat': True,
'require_exp': True,
'require_nbf': True,
'require_iss': self.validate_iss,
'require_sub': True,
'require_jti': False,
'require_at_hash': False,
'leeway': self.leeway,
'require': required_claims,
}
# Validate token
token = jwt.decode(
access_token,
key=key,
algorithms=['RS256'],
audience=self.app_client_id if self.token_version == 2 else f'api://{self.app_client_id}',
issuer=iss,
options=options,
)
token = self.validate(access_token=access_token, iss=iss, key=key, options=options)
# Attach the user to the request. Can be accessed through `request.state.user`
user: User = User(
**{**token, 'claims': token, 'access_token': access_token, 'is_guest': user_is_guest}
)
request.state.user = user
return user
except JWTClaimsError as error:
except (
InvalidAudienceError,
InvalidIssuerError,
InvalidIssuedAtError,
ImmatureSignatureError,
JonasKs marked this conversation as resolved.
Show resolved Hide resolved
InvalidAlgorithmError,
JonasKs marked this conversation as resolved.
Show resolved Hide resolved
MissingRequiredClaimError,
) as error:
log.info('Token contains invalid claims. %s', error)
raise InvalidAuth(detail='Token contains invalid claims') from error
except ExpiredSignatureError as error:
log.info('Token signature has expired. %s', error)
raise InvalidAuth(detail='Token signature has expired') from error
except JWTError as error:
except InvalidTokenError as error:
log.warning('Invalid token. Error: %s', error, exc_info=True)
raise InvalidAuth(detail='Unable to validate token') from error
except Exception as error:
Expand All @@ -235,6 +244,30 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
return None
raise

async def extract_access_token(self, request: Request) -> Optional[str]:
"""
Extracts the access token from the request.
"""
return await self.oauth(request=request)

def validate(self, access_token: str, key: AllowedPublicKeys, iss: str, options: Dict[str, Any]) -> Dict[str, Any]:
"""
Validates the token using the provided key and options.
"""
alg = 'RS256'
aud = self.app_client_id if self.token_version == 2 else f'api://{self.app_client_id}'
return dict(
jwt.decode(
access_token,
key=key,
algorithms=[alg],
audience=aud,
issuer=iss,
leeway=self.leeway,
options=options,
)
)


class SingleTenantAzureAuthorizationCodeBearer(AzureAuthorizationCodeBearerBase):
def __init__(
Expand Down
16 changes: 10 additions & 6 deletions fastapi_azure_auth/openid_config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes as KeyTypes
import jwt
from fastapi import HTTPException, status
from httpx import AsyncClient
from jose import jwk

if TYPE_CHECKING: # pragma: no cover
from jwt.algorithms import AllowedPublicKeys
else:
AllowedPublicKeys = Any
JonasKs marked this conversation as resolved.
Show resolved Hide resolved

log = logging.getLogger('fastapi_azure_auth')

Expand All @@ -27,7 +31,7 @@ def __init__(
self.config_url = config_url

self.authorization_endpoint: str
self.signing_keys: dict[str, KeyTypes]
self.signing_keys: dict[str, AllowedPublicKeys]
self.token_endpoint: str
self.issuer: str

Expand Down Expand Up @@ -98,6 +102,6 @@ def _load_keys(self, keys: List[Dict[str, Any]]) -> None:
for key in keys:
if key.get('use') == 'sig': # Only care about keys that are used for signatures, not encryption
log.debug('Loading public key from certificate: %s', key)
cert_obj = jwk.construct(key, 'RS256')
cert_obj = jwt.PyJWK(key, 'RS256')
if kid := key.get('kid'): # In case a key would not have a thumbprint we can match, we don't want it.
self.signing_keys[kid] = cert_obj
self.signing_keys[kid] = cert_obj.key
16 changes: 16 additions & 0 deletions fastapi_azure_auth/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Dict

import jwt


def is_guest(claims: Dict[str, Any]) -> bool:
"""
Expand All @@ -12,3 +14,17 @@ def is_guest(claims: Dict[str, Any]) -> bool:
claims_iss: str = claims.get('iss', '')
idp: str = claims.get('idp', claims_iss)
return idp != claims_iss


def get_unverified_header(access_token: str) -> Dict[str, Any]:
"""
Get header from the access token without verifying the signature
"""
return dict(jwt.get_unverified_header(access_token))


def get_unverified_claims(access_token: str) -> Dict[str, Any]:
"""
Get claims from the access token without verifying the signature
"""
return dict(jwt.decode(access_token, options={'verify_signature': False}))
85 changes: 19 additions & 66 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ classifiers = [
python = "^3.8"
fastapi = ">0.68.0"
cryptography = ">=40.0.1"
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
httpx = ">0.18.2"
pyjwt = "^2.8.0"


[tool.poetry.group.dev.dependencies]
Expand Down
3 changes: 2 additions & 1 deletion tests/multi_tenant/test_multi_tenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)

from fastapi_azure_auth import MultiTenantAzureAuthorizationCodeBearer
from fastapi_azure_auth.auth import AzureAuthorizationCodeBearerBase
from fastapi_azure_auth.exceptions import InvalidAuth


Expand Down Expand Up @@ -283,7 +284,7 @@ async def test_only_header(multi_tenant_app, mock_openid_and_keys):

@pytest.mark.anyio
async def test_exception_raised(multi_tenant_app, mock_openid_and_keys, mocker):
mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol'))
mocker.patch.object(AzureAuthorizationCodeBearerBase, 'validate', side_effect=ValueError('lol'))
async with AsyncClient(
app=app,
base_url='http://test',
Expand Down
Loading