Skip to content

Commit

Permalink
feat: jwt_backends - create backend mechanism and add authlib support (
Browse files Browse the repository at this point in the history
…#41)

Co-authored-by: Konstantin Chernyshev <38007247+k4black@users.noreply.github.com>
Co-authored-by: Konstantin Chernyshev <k4black@ya.ru>
  • Loading branch information
3 people authored May 6, 2024
1 parent 19c6038 commit 0ea6c04
Show file tree
Hide file tree
Showing 18 changed files with 699 additions and 612 deletions.
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ FastAPI native extension, easy and simple JWT auth
## Installation
You can access package [fastapi-jwt in pypi](https://pypi.org/project/fastapi-jwt/)
```shell
pip install fastapi-jwt
pip install fastapi-jwt[authlib]
# or
pip install fastapi-jwt[python_jose]
```

The fastapi-jwt will choose the backend automatically if library is installed with the following priority:
1. authlib
2. python_jose (deprecated)

## Usage
This library made in fastapi style, so it can be used as standard security features
Expand Down Expand Up @@ -81,7 +86,7 @@ There it is open and maintained [Pull Request #3305](https://github.com/tiangolo
## Requirements

* `fastapi`
* `python-jose[cryptography]`
* `authlib` or `python-jose[cryptography]` (deprecated)

## License
This project is licensed under the terms of the MIT license.
This project is licensed under the terms of the MIT license.
1 change: 1 addition & 0 deletions fastapi_jwt/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .jwt import * # noqa: F401, F403
from .jwt_backends import * # noqa: F401, F403
159 changes: 58 additions & 101 deletions fastapi_jwt/jwt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Set
from typing import Any, Dict, Optional, Set, Type
from uuid import uuid4

from fastapi.exceptions import HTTPException
Expand All @@ -9,13 +9,24 @@
from fastapi.security import APIKeyCookie, HTTPBearer
from starlette.status import HTTP_401_UNAUTHORIZED

try:
from jose import jwt
except ImportError: # pragma: nocover
jwt = None # type: ignore[assignment]
from .jwt_backends import AbstractJWTBackend, authlib_backend, python_jose_backend
from .jwt_backends.abstract_backend import BackendException

DEFAULT_JWT_BACKEND: Optional[Type[AbstractJWTBackend]] = None
if authlib_backend.authlib_jose is not None:
DEFAULT_JWT_BACKEND = authlib_backend.AuthlibJWTBackend
elif python_jose_backend.jose is not None:
DEFAULT_JWT_BACKEND = python_jose_backend.PythonJoseJWTBackend
else: # pragma: nocover
raise ImportError("No JWT backend found, please install 'python-jose' or 'authlib'")

def utcnow():

def force_jwt_backend(cls: Type[AbstractJWTBackend]) -> None:
global DEFAULT_JWT_BACKEND
DEFAULT_JWT_BACKEND = cls


def utcnow() -> datetime:
try:
from datetime import UTC
except ImportError: # pragma: nocover
Expand All @@ -27,6 +38,7 @@ def utcnow():


__all__ = [
"force_jwt_backend",
"JwtAuthorizationCredentials",
"JwtAccessBearer",
"JwtAccessCookie",
Expand All @@ -49,15 +61,11 @@ def __getitem__(self, item: str) -> Any:
class JwtAuthBase(ABC):
class JwtAccessCookie(APIKeyCookie):
def __init__(self, *args: Any, **kwargs: Any):
APIKeyCookie.__init__(
self, *args, name="access_token_cookie", auto_error=False, **kwargs
)
APIKeyCookie.__init__(self, *args, name="access_token_cookie", auto_error=False, **kwargs)

class JwtRefreshCookie(APIKeyCookie):
def __init__(self, *args: Any, **kwargs: Any):
APIKeyCookie.__init__(
self, *args, name="refresh_token_cookie", auto_error=False, **kwargs
)
APIKeyCookie.__init__(self, *args, name="refresh_token_cookie", auto_error=False, **kwargs)

class JwtAccessBearer(HTTPBearer):
def __init__(self, *args: Any, **kwargs: Any):
Expand All @@ -72,38 +80,35 @@ def __init__(
secret_key: str,
places: Optional[Set[str]] = None,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
assert jwt is not None, "python-jose must be installed to use JwtAuth"
if places:
assert places.issubset(
{"header", "cookie"}
), "only 'header'/'cookie' are supported"
algorithm = algorithm.upper()
assert (
hasattr(jwt.ALGORITHMS, algorithm) is True # type: ignore[attr-defined]
), f"{algorithm} algorithm is not supported by python-jose library"
assert DEFAULT_JWT_BACKEND is not None, "No JWT backend found, please install 'python-jose' or 'authlib'"

self.jwt_backend = DEFAULT_JWT_BACKEND(algorithm)
self.secret_key = secret_key

self.places = places or {"header"}
assert self.places.issubset({"header", "cookie"}), "only 'header' and/or 'cookie' places are supported"
self.auto_error = auto_error
self.algorithm = algorithm
self.access_expires_delta = access_expires_delta or timedelta(minutes=15)
self.refresh_expires_delta = refresh_expires_delta or timedelta(days=31)

@property
def algorithm(self) -> str:
return self.jwt_backend.algorithm

@classmethod
def from_other(
cls,
other: 'JwtAuthBase',
other: "JwtAuthBase",
secret_key: Optional[str] = None,
auto_error: Optional[bool] = None,
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
) -> 'JwtAuthBase':
) -> "JwtAuthBase":
return cls(
secret_key=secret_key or other.secret_key,
auto_error=auto_error or other.auto_error,
Expand All @@ -112,30 +117,6 @@ def from_other(
refresh_expires_delta=refresh_expires_delta or other.refresh_expires_delta,
)

def _decode(self, token: str) -> Optional[Dict[str, Any]]:
try:
payload: Dict[str, Any] = jwt.decode(
token,
self.secret_key,
algorithms=[self.algorithm],
options={"leeway": 10},
)
return payload
except jwt.ExpiredSignatureError as e: # type: ignore[attr-defined]
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail=f"Token time expired: {e}"
)
else:
return None
except jwt.JWTError as e: # type: ignore[attr-defined]
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail=f"Wrong token: {e}"
)
else:
return None

def _generate_payload(
self,
subject: Dict[str, Any],
Expand All @@ -144,7 +125,6 @@ def _generate_payload(
token_type: str,
) -> Dict[str, Any]:
now = utcnow()

return {
"subject": subject.copy(), # main subject
"type": token_type, # 'access' or 'refresh' token
Expand All @@ -165,15 +145,18 @@ async def _get_payload(
# Check token exist
if not token:
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail="Credentials are not provided"
)
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Credentials are not provided")
else:
return None

# Try to decode jwt token. auto_error on error
payload = self._decode(token)
return payload
try:
return self.jwt_backend.decode(token, self.secret_key)
except BackendException as e:
if self.auto_error:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=str(e))
else:
return None

def create_access_token(
self,
Expand All @@ -183,14 +166,8 @@ def create_access_token(
) -> str:
expires_delta = expires_delta or self.access_expires_delta
unique_identifier = unique_identifier or str(uuid4())
to_encode = self._generate_payload(
subject, expires_delta, unique_identifier, "access"
)

jwt_encoded: str = jwt.encode(
to_encode, self.secret_key, algorithm=self.algorithm
)
return jwt_encoded
to_encode = self._generate_payload(subject, expires_delta, unique_identifier, "access")
return self.jwt_backend.encode(to_encode, self.secret_key)

def create_refresh_token(
self,
Expand All @@ -200,22 +177,12 @@ def create_refresh_token(
) -> str:
expires_delta = expires_delta or self.refresh_expires_delta
unique_identifier = unique_identifier or str(uuid4())
to_encode = self._generate_payload(
subject, expires_delta, unique_identifier, "refresh"
)

jwt_encoded: str = jwt.encode(
to_encode, self.secret_key, algorithm=self.algorithm
)
return jwt_encoded
to_encode = self._generate_payload(subject, expires_delta, unique_identifier, "refresh")
return self.jwt_backend.encode(to_encode, self.secret_key)

@staticmethod
def set_access_cookie(
response: Response, access_token: str, expires_delta: Optional[timedelta] = None
) -> None:
seconds_expires: Optional[int] = (
int(expires_delta.total_seconds()) if expires_delta else None
)
def set_access_cookie(response: Response, access_token: str, expires_delta: Optional[timedelta] = None) -> None:
seconds_expires: Optional[int] = int(expires_delta.total_seconds()) if expires_delta else None
response.set_cookie(
key="access_token_cookie",
value=access_token,
Expand All @@ -229,9 +196,7 @@ def set_refresh_cookie(
refresh_token: str,
expires_delta: Optional[timedelta] = None,
) -> None:
seconds_expires: Optional[int] = (
int(expires_delta.total_seconds()) if expires_delta else None
)
seconds_expires: Optional[int] = int(expires_delta.total_seconds()) if expires_delta else None
response.set_cookie(
key="refresh_token_cookie",
value=refresh_token,
Expand All @@ -241,15 +206,11 @@ def set_refresh_cookie(

@staticmethod
def unset_access_cookie(response: Response) -> None:
response.set_cookie(
key="access_token_cookie", value="", httponly=False, max_age=-1
)
response.set_cookie(key="access_token_cookie", value="", httponly=False, max_age=-1)

@staticmethod
def unset_refresh_cookie(response: Response) -> None:
response.set_cookie(
key="refresh_token_cookie", value="", httponly=True, max_age=-1
)
response.set_cookie(key="refresh_token_cookie", value="", httponly=True, max_age=-1)


class JwtAccess(JwtAuthBase):
Expand All @@ -261,7 +222,7 @@ def __init__(
secret_key: str,
places: Optional[Set[str]] = None,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -282,9 +243,7 @@ async def _get_credentials(
payload = await self._get_payload(bearer, cookie)

if payload:
return JwtAuthorizationCredentials(
payload["subject"], payload.get("jti", None)
)
return JwtAuthorizationCredentials(payload["subject"], payload.get("jti", None))
return None


Expand All @@ -293,7 +252,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -317,7 +276,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -342,7 +301,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand Down Expand Up @@ -372,7 +331,7 @@ def __init__(
secret_key: str,
places: Optional[Set[str]] = None,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -399,22 +358,20 @@ async def _get_credentials(
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Wrong token: 'type' is not 'refresh'",
detail="Invalid token: 'type' is not 'refresh'",
)
else:
return None

return JwtAuthorizationCredentials(
payload["subject"], payload.get("jti", None)
)
return JwtAuthorizationCredentials(payload["subject"], payload.get("jti", None))


class JwtRefreshBearer(JwtRefresh):
def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -438,7 +395,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -463,7 +420,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand Down
4 changes: 4 additions & 0 deletions fastapi_jwt/jwt_backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from . import abstract_backend, authlib_backend, python_jose_backend # noqa: F401
from .abstract_backend import AbstractJWTBackend # noqa: F401
from .authlib_backend import AuthlibJWTBackend # noqa: F401
from .python_jose_backend import PythonJoseJWTBackend # noqa: F401
Loading

0 comments on commit 0ea6c04

Please sign in to comment.