Skip to content

Commit

Permalink
jwt_backends: create backend mechanism and add authlib support
Browse files Browse the repository at this point in the history
  • Loading branch information
hasB4K committed Feb 27, 2024
1 parent 2f733ce commit 06d3132
Show file tree
Hide file tree
Showing 16 changed files with 697 additions and 397 deletions.
1 change: 1 addition & 0 deletions fastapi_jwt/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .jwt import * # noqa: F401, F403
from .jwt_backends import * # noqa: F401, F403
91 changes: 33 additions & 58 deletions fastapi_jwt/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,21 @@
from fastapi.responses import Response
from fastapi.security import APIKeyCookie, HTTPBearer
from starlette.status import HTTP_401_UNAUTHORIZED
from .jwt_backends import AuthlibJWTBackend, PythonJoseJWTBackend

try:
from jose import jwt
except ImportError: # pragma: nocover
jwt = None # type: ignore[assignment]

DEFAULT_JWT_BACKEND = None


def define_default_jwt_backend(cls):
global DEFAULT_JWT_BACKEND
DEFAULT_JWT_BACKEND = cls


if AuthlibJWTBackend is not None:
define_default_jwt_backend(AuthlibJWTBackend)
elif PythonJoseJWTBackend is not None:
define_default_jwt_backend(PythonJoseJWTBackend)


def utcnow():
Expand All @@ -27,6 +37,7 @@ def utcnow():


__all__ = [
"define_default_jwt_backend",
"JwtAuthorizationCredentials",
"JwtAccessBearer",
"JwtAccessCookie",
Expand Down Expand Up @@ -72,28 +83,26 @@ def __init__(
secret_key: str,
places: Optional[Set[str]] = None,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
assert jwt is not None, "python-jose must be installed to use JwtAuth"
self.jwt_backend = DEFAULT_JWT_BACKEND(algorithm)
self.secret_key = secret_key
if places:
assert places.issubset(
{"header", "cookie"}
), "only 'header'/'cookie' are supported"
algorithm = algorithm.upper()
assert (
hasattr(jwt.ALGORITHMS, algorithm) is True # type: ignore[attr-defined]
), f"{algorithm} algorithm is not supported by python-jose library"

self.secret_key = secret_key

self.places = places or {"header"}
self.auto_error = auto_error
self.algorithm = algorithm
self.access_expires_delta = access_expires_delta or timedelta(minutes=15)
self.refresh_expires_delta = refresh_expires_delta or timedelta(days=31)

@property
def algorithm(self):
return self.jwt_backend.algorithm

@classmethod
def from_other(
cls,
Expand All @@ -112,30 +121,6 @@ def from_other(
refresh_expires_delta=refresh_expires_delta or other.refresh_expires_delta,
)

def _decode(self, token: str) -> Optional[Dict[str, Any]]:
try:
payload: Dict[str, Any] = jwt.decode(
token,
self.secret_key,
algorithms=[self.algorithm],
options={"leeway": 10},
)
return payload
except jwt.ExpiredSignatureError as e: # type: ignore[attr-defined]
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail=f"Token time expired: {e}"
)
else:
return None
except jwt.JWTError as e: # type: ignore[attr-defined]
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail=f"Wrong token: {e}"
)
else:
return None

def _generate_payload(
self,
subject: Dict[str, Any],
Expand All @@ -144,7 +129,6 @@ def _generate_payload(
token_type: str,
) -> Dict[str, Any]:
now = utcnow()

return {
"subject": subject.copy(), # main subject
"type": token_type, # 'access' or 'refresh' token
Expand Down Expand Up @@ -172,8 +156,7 @@ async def _get_payload(
return None

# Try to decode jwt token. auto_error on error
payload = self._decode(token)
return payload
return self.jwt_backend.decode(token, self.secret_key, self.auto_error)

def create_access_token(
self,
Expand All @@ -186,11 +169,7 @@ def create_access_token(
to_encode = self._generate_payload(
subject, expires_delta, unique_identifier, "access"
)

jwt_encoded: str = jwt.encode(
to_encode, self.secret_key, algorithm=self.algorithm
)
return jwt_encoded
return self.jwt_backend.encode(to_encode, self.secret_key)

def create_refresh_token(
self,
Expand All @@ -203,11 +182,7 @@ def create_refresh_token(
to_encode = self._generate_payload(
subject, expires_delta, unique_identifier, "refresh"
)

jwt_encoded: str = jwt.encode(
to_encode, self.secret_key, algorithm=self.algorithm
)
return jwt_encoded
return self.jwt_backend.encode(to_encode, self.secret_key)

@staticmethod
def set_access_cookie(
Expand Down Expand Up @@ -261,7 +236,7 @@ def __init__(
secret_key: str,
places: Optional[Set[str]] = None,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand Down Expand Up @@ -293,7 +268,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -317,7 +292,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -342,7 +317,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand Down Expand Up @@ -372,7 +347,7 @@ def __init__(
secret_key: str,
places: Optional[Set[str]] = None,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand Down Expand Up @@ -414,7 +389,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -438,7 +413,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -463,7 +438,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand Down
9 changes: 9 additions & 0 deletions fastapi_jwt/jwt_backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
try:
from .authlib_backend import AuthlibJWTBackend
except ImportError:
AuthlibJWTBackend = None

try:
from .python_jose_backend import PythonJoseJWTBackend
except ImportError:
PythonJoseJWTBackend = None
31 changes: 31 additions & 0 deletions fastapi_jwt/jwt_backends/abstract_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from abc import ABCMeta, abstractmethod, abstractproperty
from typing import Any, Dict, Optional, Self



class AbstractJWTBackend(metaclass=ABCMeta):

# simple "SingletonArgs" implementation to keep a JWTBackend per algorithm
_instances = {}

def __new__(cls, algorithm) -> Self:
instance_key = (cls, algorithm)
if instance_key not in cls._instances:
cls._instances[instance_key] = super(AbstractJWTBackend, cls).__new__(cls)
return cls._instances[instance_key]

@abstractmethod
def __init__(self, algorithm) -> None:
pass

@abstractproperty
def default_algorithm(self) -> str:
pass

@abstractmethod
def encode(self, to_encode, secret_key) -> str:
pass

@abstractmethod
def decode(self, token, secret_key, auto_error) -> Optional[Dict[str, Any]]:
pass
51 changes: 51 additions & 0 deletions fastapi_jwt/jwt_backends/authlib_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from fastapi import HTTPException
from typing import Any, Dict, Optional
from starlette.status import HTTP_401_UNAUTHORIZED

from authlib.jose import JsonWebSignature, JsonWebToken
from authlib.jose.errors import (
DecodeError, ExpiredTokenError, InvalidClaimError, InvalidTokenError
)
from .abstract_backend import AbstractJWTBackend


class AuthlibJWTBackend(AbstractJWTBackend):

def __init__(self, algorithm) -> None:
self.algorithm = algorithm if algorithm is not None else self.default_algorithm
# from https://github.com/lepture/authlib/blob/85f9ff/authlib/jose/__init__.py#L45
valid_algorithms = JsonWebSignature.ALGORITHMS_REGISTRY.keys()
assert (
self.algorithm in valid_algorithms
), f"{self.algorithm} algorithm is not supported by authlib"
self.jwt = JsonWebToken(algorithms=[self.algorithm])

@property
def default_algorithm(self) -> str:
return "HS256"

def encode(self, to_encode, secret_key) -> str:
token = self.jwt.encode(header={"alg": self.algorithm}, payload=to_encode, key=secret_key)
return token.decode() # convert to string

def decode(self, token, secret_key, auto_error) -> Optional[Dict[str, Any]]:
try:
payload = self.jwt.decode(token, secret_key)
payload.validate(leeway=10)
return dict(payload)
except ExpiredTokenError as e: # type: ignore[attr-defined]
if auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail=f"Token time expired: {e}"
)
else:
return None
except (InvalidClaimError,
InvalidTokenError,
DecodeError) as e: # type: ignore[attr-defined]
if auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail=f"Wrong token: {e}"
)
else:
return None
47 changes: 47 additions & 0 deletions fastapi_jwt/jwt_backends/python_jose_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from fastapi import HTTPException
from typing import Any, Dict, Optional
from starlette.status import HTTP_401_UNAUTHORIZED

from jose import jwt

from .abstract_backend import AbstractJWTBackend


class PythonJoseJWTBackend(AbstractJWTBackend):

def __init__(self, algorithm) -> None:
self.algorithm = algorithm if algorithm is not None else self.default_algorithm
assert (
hasattr(jwt.ALGORITHMS, self.algorithm) is True # type: ignore[attr-defined]
), f"{algorithm} algorithm is not supported by python-jose library"

@property
def default_algorithm(self) -> str:
return jwt.ALGORITHMS.HS256

def encode(self, to_encode, secret_key) -> str:
return jwt.encode(to_encode, secret_key, algorithm=self.algorithm)

def decode(self, token, secret_key, auto_error) -> Optional[Dict[str, Any]]:
try:
payload: Dict[str, Any] = jwt.decode(
token,
secret_key,
algorithms=[self.algorithm],
options={"leeway": 10},
)
return payload
except jwt.ExpiredSignatureError as e: # type: ignore[attr-defined]
if auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail=f"Token time expired: {e}"
)
else:
return None
except jwt.JWTError as e: # type: ignore[attr-defined]
if auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail=f"Wrong token: {e}"
)
else:
return None
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ classifiers = [

dependencies = [
"fastapi >=0.50.0",
"python-jose[cryptography] >=3.3.0"
]


Expand All @@ -37,7 +36,15 @@ documentation = "https://k4black.github.io/fastapi-jwt/"


[project.optional-dependencies]
authlib = [
"Authlib >=1.3.0"
]
python_jose = [
"python-jose[cryptography] >=3.3.0"
]
test = [
"Authlib >=1.3.0",
"python-jose[cryptography] >=3.3.0",
"httpx >=0.23.0,<1.0.0",
"pytest >=7.0.0,<9.0.0",
"pytest-cov >=4.0.0,<5.0.0",
Expand Down
Loading

0 comments on commit 06d3132

Please sign in to comment.