From 0ea6c045ab2821ffa47b4bad15fedb4cebbadab7 Mon Sep 17 00:00:00 2001 From: Mathis Felardos <3902859+hasB4K@users.noreply.github.com> Date: Mon, 6 May 2024 15:09:50 +0200 Subject: [PATCH] feat: jwt_backends - create backend mechanism and add authlib support (#41) Co-authored-by: Konstantin Chernyshev <38007247+k4black@users.noreply.github.com> Co-authored-by: Konstantin Chernyshev --- README.md | 11 +- fastapi_jwt/__init__.py | 1 + fastapi_jwt/jwt.py | 159 +++++-------- fastapi_jwt/jwt_backends/__init__.py | 4 + fastapi_jwt/jwt_backends/abstract_backend.py | 25 ++ fastapi_jwt/jwt_backends/authlib_backend.py | 47 ++++ .../jwt_backends/python_jose_backend.py | 46 ++++ pyproject.toml | 12 +- tests/conftest.py | 10 + tests/mock_datetime_utils.py | 37 +++ tests/test_security_jwt_bearer.py | 95 ++++---- tests/test_security_jwt_bearer_optional.py | 107 ++++----- tests/test_security_jwt_cookie.py | 81 ++++--- tests/test_security_jwt_cookie_optional.py | 93 ++++---- tests/test_security_jwt_general.py | 220 +++++++----------- tests/test_security_jwt_general_optional.py | 212 +++++++---------- tests/test_security_jwt_multiple_places.py | 95 ++++---- tests/test_security_jwt_set_cookie.py | 56 ++--- 18 files changed, 699 insertions(+), 612 deletions(-) create mode 100644 fastapi_jwt/jwt_backends/__init__.py create mode 100644 fastapi_jwt/jwt_backends/abstract_backend.py create mode 100644 fastapi_jwt/jwt_backends/authlib_backend.py create mode 100644 fastapi_jwt/jwt_backends/python_jose_backend.py create mode 100644 tests/conftest.py create mode 100644 tests/mock_datetime_utils.py diff --git a/README.md b/README.md index 90e7c33..3e675dc 100644 --- a/README.md +++ b/README.md @@ -25,9 +25,14 @@ FastAPI native extension, easy and simple JWT auth ## Installation You can access package [fastapi-jwt in pypi](https://pypi.org/project/fastapi-jwt/) ```shell -pip install fastapi-jwt +pip install fastapi-jwt[authlib] +# or +pip install fastapi-jwt[python_jose] ``` +The fastapi-jwt will choose the backend automatically if library is installed with the following priority: +1. authlib +2. python_jose (deprecated) ## Usage This library made in fastapi style, so it can be used as standard security features @@ -81,7 +86,7 @@ There it is open and maintained [Pull Request #3305](https://github.com/tiangolo ## Requirements * `fastapi` -* `python-jose[cryptography]` +* `authlib` or `python-jose[cryptography]` (deprecated) ## License -This project is licensed under the terms of the MIT license. \ No newline at end of file +This project is licensed under the terms of the MIT license. diff --git a/fastapi_jwt/__init__.py b/fastapi_jwt/__init__.py index 25585c5..a8099f8 100644 --- a/fastapi_jwt/__init__.py +++ b/fastapi_jwt/__init__.py @@ -1 +1,2 @@ from .jwt import * # noqa: F401, F403 +from .jwt_backends import * # noqa: F401, F403 diff --git a/fastapi_jwt/jwt.py b/fastapi_jwt/jwt.py index d9041ed..4b86a1f 100644 --- a/fastapi_jwt/jwt.py +++ b/fastapi_jwt/jwt.py @@ -1,6 +1,6 @@ from abc import ABC from datetime import datetime, timedelta -from typing import Any, Dict, Optional, Set +from typing import Any, Dict, Optional, Set, Type from uuid import uuid4 from fastapi.exceptions import HTTPException @@ -9,13 +9,24 @@ from fastapi.security import APIKeyCookie, HTTPBearer from starlette.status import HTTP_401_UNAUTHORIZED -try: - from jose import jwt -except ImportError: # pragma: nocover - jwt = None # type: ignore[assignment] +from .jwt_backends import AbstractJWTBackend, authlib_backend, python_jose_backend +from .jwt_backends.abstract_backend import BackendException +DEFAULT_JWT_BACKEND: Optional[Type[AbstractJWTBackend]] = None +if authlib_backend.authlib_jose is not None: + DEFAULT_JWT_BACKEND = authlib_backend.AuthlibJWTBackend +elif python_jose_backend.jose is not None: + DEFAULT_JWT_BACKEND = python_jose_backend.PythonJoseJWTBackend +else: # pragma: nocover + raise ImportError("No JWT backend found, please install 'python-jose' or 'authlib'") -def utcnow(): + +def force_jwt_backend(cls: Type[AbstractJWTBackend]) -> None: + global DEFAULT_JWT_BACKEND + DEFAULT_JWT_BACKEND = cls + + +def utcnow() -> datetime: try: from datetime import UTC except ImportError: # pragma: nocover @@ -27,6 +38,7 @@ def utcnow(): __all__ = [ + "force_jwt_backend", "JwtAuthorizationCredentials", "JwtAccessBearer", "JwtAccessCookie", @@ -49,15 +61,11 @@ def __getitem__(self, item: str) -> Any: class JwtAuthBase(ABC): class JwtAccessCookie(APIKeyCookie): def __init__(self, *args: Any, **kwargs: Any): - APIKeyCookie.__init__( - self, *args, name="access_token_cookie", auto_error=False, **kwargs - ) + APIKeyCookie.__init__(self, *args, name="access_token_cookie", auto_error=False, **kwargs) class JwtRefreshCookie(APIKeyCookie): def __init__(self, *args: Any, **kwargs: Any): - APIKeyCookie.__init__( - self, *args, name="refresh_token_cookie", auto_error=False, **kwargs - ) + APIKeyCookie.__init__(self, *args, name="refresh_token_cookie", auto_error=False, **kwargs) class JwtAccessBearer(HTTPBearer): def __init__(self, *args: Any, **kwargs: Any): @@ -72,38 +80,35 @@ def __init__( secret_key: str, places: Optional[Set[str]] = None, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): - assert jwt is not None, "python-jose must be installed to use JwtAuth" - if places: - assert places.issubset( - {"header", "cookie"} - ), "only 'header'/'cookie' are supported" - algorithm = algorithm.upper() - assert ( - hasattr(jwt.ALGORITHMS, algorithm) is True # type: ignore[attr-defined] - ), f"{algorithm} algorithm is not supported by python-jose library" + assert DEFAULT_JWT_BACKEND is not None, "No JWT backend found, please install 'python-jose' or 'authlib'" + self.jwt_backend = DEFAULT_JWT_BACKEND(algorithm) self.secret_key = secret_key self.places = places or {"header"} + assert self.places.issubset({"header", "cookie"}), "only 'header' and/or 'cookie' places are supported" self.auto_error = auto_error - self.algorithm = algorithm self.access_expires_delta = access_expires_delta or timedelta(minutes=15) self.refresh_expires_delta = refresh_expires_delta or timedelta(days=31) + @property + def algorithm(self) -> str: + return self.jwt_backend.algorithm + @classmethod def from_other( cls, - other: 'JwtAuthBase', + other: "JwtAuthBase", secret_key: Optional[str] = None, auto_error: Optional[bool] = None, algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, - ) -> 'JwtAuthBase': + ) -> "JwtAuthBase": return cls( secret_key=secret_key or other.secret_key, auto_error=auto_error or other.auto_error, @@ -112,30 +117,6 @@ def from_other( refresh_expires_delta=refresh_expires_delta or other.refresh_expires_delta, ) - def _decode(self, token: str) -> Optional[Dict[str, Any]]: - try: - payload: Dict[str, Any] = jwt.decode( - token, - self.secret_key, - algorithms=[self.algorithm], - options={"leeway": 10}, - ) - return payload - except jwt.ExpiredSignatureError as e: # type: ignore[attr-defined] - if self.auto_error: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, detail=f"Token time expired: {e}" - ) - else: - return None - except jwt.JWTError as e: # type: ignore[attr-defined] - if self.auto_error: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, detail=f"Wrong token: {e}" - ) - else: - return None - def _generate_payload( self, subject: Dict[str, Any], @@ -144,7 +125,6 @@ def _generate_payload( token_type: str, ) -> Dict[str, Any]: now = utcnow() - return { "subject": subject.copy(), # main subject "type": token_type, # 'access' or 'refresh' token @@ -165,15 +145,18 @@ async def _get_payload( # Check token exist if not token: if self.auto_error: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, detail="Credentials are not provided" - ) + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Credentials are not provided") else: return None # Try to decode jwt token. auto_error on error - payload = self._decode(token) - return payload + try: + return self.jwt_backend.decode(token, self.secret_key) + except BackendException as e: + if self.auto_error: + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=str(e)) + else: + return None def create_access_token( self, @@ -183,14 +166,8 @@ def create_access_token( ) -> str: expires_delta = expires_delta or self.access_expires_delta unique_identifier = unique_identifier or str(uuid4()) - to_encode = self._generate_payload( - subject, expires_delta, unique_identifier, "access" - ) - - jwt_encoded: str = jwt.encode( - to_encode, self.secret_key, algorithm=self.algorithm - ) - return jwt_encoded + to_encode = self._generate_payload(subject, expires_delta, unique_identifier, "access") + return self.jwt_backend.encode(to_encode, self.secret_key) def create_refresh_token( self, @@ -200,22 +177,12 @@ def create_refresh_token( ) -> str: expires_delta = expires_delta or self.refresh_expires_delta unique_identifier = unique_identifier or str(uuid4()) - to_encode = self._generate_payload( - subject, expires_delta, unique_identifier, "refresh" - ) - - jwt_encoded: str = jwt.encode( - to_encode, self.secret_key, algorithm=self.algorithm - ) - return jwt_encoded + to_encode = self._generate_payload(subject, expires_delta, unique_identifier, "refresh") + return self.jwt_backend.encode(to_encode, self.secret_key) @staticmethod - def set_access_cookie( - response: Response, access_token: str, expires_delta: Optional[timedelta] = None - ) -> None: - seconds_expires: Optional[int] = ( - int(expires_delta.total_seconds()) if expires_delta else None - ) + def set_access_cookie(response: Response, access_token: str, expires_delta: Optional[timedelta] = None) -> None: + seconds_expires: Optional[int] = int(expires_delta.total_seconds()) if expires_delta else None response.set_cookie( key="access_token_cookie", value=access_token, @@ -229,9 +196,7 @@ def set_refresh_cookie( refresh_token: str, expires_delta: Optional[timedelta] = None, ) -> None: - seconds_expires: Optional[int] = ( - int(expires_delta.total_seconds()) if expires_delta else None - ) + seconds_expires: Optional[int] = int(expires_delta.total_seconds()) if expires_delta else None response.set_cookie( key="refresh_token_cookie", value=refresh_token, @@ -241,15 +206,11 @@ def set_refresh_cookie( @staticmethod def unset_access_cookie(response: Response) -> None: - response.set_cookie( - key="access_token_cookie", value="", httponly=False, max_age=-1 - ) + response.set_cookie(key="access_token_cookie", value="", httponly=False, max_age=-1) @staticmethod def unset_refresh_cookie(response: Response) -> None: - response.set_cookie( - key="refresh_token_cookie", value="", httponly=True, max_age=-1 - ) + response.set_cookie(key="refresh_token_cookie", value="", httponly=True, max_age=-1) class JwtAccess(JwtAuthBase): @@ -261,7 +222,7 @@ def __init__( secret_key: str, places: Optional[Set[str]] = None, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): @@ -282,9 +243,7 @@ async def _get_credentials( payload = await self._get_payload(bearer, cookie) if payload: - return JwtAuthorizationCredentials( - payload["subject"], payload.get("jti", None) - ) + return JwtAuthorizationCredentials(payload["subject"], payload.get("jti", None)) return None @@ -293,7 +252,7 @@ def __init__( self, secret_key: str, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): @@ -317,7 +276,7 @@ def __init__( self, secret_key: str, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): @@ -342,7 +301,7 @@ def __init__( self, secret_key: str, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): @@ -372,7 +331,7 @@ def __init__( secret_key: str, places: Optional[Set[str]] = None, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): @@ -399,14 +358,12 @@ async def _get_credentials( if self.auto_error: raise HTTPException( status_code=HTTP_401_UNAUTHORIZED, - detail="Wrong token: 'type' is not 'refresh'", + detail="Invalid token: 'type' is not 'refresh'", ) else: return None - return JwtAuthorizationCredentials( - payload["subject"], payload.get("jti", None) - ) + return JwtAuthorizationCredentials(payload["subject"], payload.get("jti", None)) class JwtRefreshBearer(JwtRefresh): @@ -414,7 +371,7 @@ def __init__( self, secret_key: str, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): @@ -438,7 +395,7 @@ def __init__( self, secret_key: str, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): @@ -463,7 +420,7 @@ def __init__( self, secret_key: str, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): diff --git a/fastapi_jwt/jwt_backends/__init__.py b/fastapi_jwt/jwt_backends/__init__.py new file mode 100644 index 0000000..59dd097 --- /dev/null +++ b/fastapi_jwt/jwt_backends/__init__.py @@ -0,0 +1,4 @@ +from . import abstract_backend, authlib_backend, python_jose_backend # noqa: F401 +from .abstract_backend import AbstractJWTBackend # noqa: F401 +from .authlib_backend import AuthlibJWTBackend # noqa: F401 +from .python_jose_backend import PythonJoseJWTBackend # noqa: F401 diff --git a/fastapi_jwt/jwt_backends/abstract_backend.py b/fastapi_jwt/jwt_backends/abstract_backend.py new file mode 100644 index 0000000..c15f0c0 --- /dev/null +++ b/fastapi_jwt/jwt_backends/abstract_backend.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + + +class BackendException(Exception): # pragma: no cover + pass + + +class AbstractJWTBackend(ABC): # pragma: no cover + @abstractmethod + def __init__(self, algorithm: Optional[str] = None) -> None: + raise NotImplementedError + + @property + @abstractmethod + def algorithm(self) -> str: + raise NotImplementedError + + @abstractmethod + def encode(self, to_encode: Dict[str, Any], secret_key: str) -> str: + raise NotImplementedError + + @abstractmethod + def decode(self, token: str, secret_key: str) -> Optional[Dict[str, Any]]: + raise NotImplementedError diff --git a/fastapi_jwt/jwt_backends/authlib_backend.py b/fastapi_jwt/jwt_backends/authlib_backend.py new file mode 100644 index 0000000..780992e --- /dev/null +++ b/fastapi_jwt/jwt_backends/authlib_backend.py @@ -0,0 +1,47 @@ +from typing import Any, Dict, Optional + +try: + import authlib.jose as authlib_jose + import authlib.jose.errors as authlib_jose_errors +except ImportError: # pragma: no cover + authlib_jose = None + +from .abstract_backend import AbstractJWTBackend, BackendException + + +class AuthlibJWTBackend(AbstractJWTBackend): + def __init__(self, algorithm: Optional[str] = None) -> None: + assert authlib_jose is not None, "To use AuthlibJWTBackend, you need to install authlib" + + self._algorithm = algorithm or self.default_algorithm + # from https://github.com/lepture/authlib/blob/85f9ff/authlib/jose/__init__.py#L45 + assert ( + self._algorithm in authlib_jose.JsonWebSignature.ALGORITHMS_REGISTRY.keys() + ), f"{self._algorithm} algorithm is not supported by authlib" + self.jwt = authlib_jose.JsonWebToken(algorithms=[self._algorithm]) + + @property + def default_algorithm(self) -> str: + return "HS256" + + @property + def algorithm(self) -> str: + return self._algorithm + + def encode(self, to_encode: Dict[str, Any], secret_key: str) -> str: + token = self.jwt.encode(header={"alg": self.algorithm}, payload=to_encode, key=secret_key) + return token.decode() # convert to string + + def decode(self, token: str, secret_key: str) -> Optional[Dict[str, Any]]: + try: + payload = self.jwt.decode(token, secret_key) + payload.validate(leeway=10) + return dict(payload) + except authlib_jose_errors.ExpiredTokenError as e: + raise BackendException(f"Token time expired: {e}") + except ( + authlib_jose_errors.InvalidClaimError, + authlib_jose_errors.InvalidTokenError, + authlib_jose_errors.DecodeError, + ) as e: + raise BackendException(f"Invalid token: {e}") diff --git a/fastapi_jwt/jwt_backends/python_jose_backend.py b/fastapi_jwt/jwt_backends/python_jose_backend.py new file mode 100644 index 0000000..4abe9ad --- /dev/null +++ b/fastapi_jwt/jwt_backends/python_jose_backend.py @@ -0,0 +1,46 @@ +import warnings +from typing import Any, Dict, Optional + +try: + import jose + import jose.jwt +except ImportError: # pragma: no cover + jose = None # type: ignore + +from .abstract_backend import AbstractJWTBackend, BackendException + + +class PythonJoseJWTBackend(AbstractJWTBackend): + def __init__(self, algorithm: Optional[str] = None) -> None: + assert jose is not None, "To use PythonJoseJWTBackend, you need to install python-jose" + warnings.warn("PythonJoseJWTBackend is deprecated as python-jose library is not maintained anymore.") + + self._algorithm = algorithm or self.default_algorithm + assert ( + hasattr(jose.jwt.ALGORITHMS, self._algorithm) is True # type: ignore[attr-defined] + ), f"{algorithm} algorithm is not supported by python-jose library" + + @property + def default_algorithm(self) -> str: + return jose.jwt.ALGORITHMS.HS256 # type: ignore[attr-defined] + + @property + def algorithm(self) -> str: + return self._algorithm + + def encode(self, to_encode: Dict[str, Any], secret_key: str) -> str: + return jose.jwt.encode(to_encode, secret_key, algorithm=self._algorithm) + + def decode(self, token: str, secret_key: str) -> Optional[Dict[str, Any]]: + try: + payload: Dict[str, Any] = jose.jwt.decode( + token, + secret_key, + algorithms=[self._algorithm], + options={"leeway": 10}, + ) + return payload + except jose.jwt.ExpiredSignatureError as e: # type: ignore[attr-defined] + raise BackendException(f"Token time expired: {e}") + except jose.jwt.JWTError as e: # type: ignore[attr-defined] + raise BackendException(f"Invalid token: {e}") diff --git a/pyproject.toml b/pyproject.toml index c56f322..b754a8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ classifiers = [ dependencies = [ "fastapi >=0.50.0", - "python-jose[cryptography] >=3.3.0" ] @@ -37,7 +36,15 @@ documentation = "https://k4black.github.io/fastapi-jwt/" [project.optional-dependencies] +authlib = [ + "Authlib >=1.3.0" +] +python_jose = [ + "python-jose[cryptography] >=3.3.0" +] test = [ + "Authlib >=1.3.0", + "python-jose[cryptography] >=3.3.0", "httpx >=0.23.0,<1.0.0", "pytest >=7.0.0,<9.0.0", "pytest-cov >=4.0.0,<6.0.0", @@ -64,8 +71,7 @@ docs = [ [tool.setuptools.dynamic] version = {file = "VERSION"} - -[mypy] +[tool.mypy] ignore_missing_imports = true no_incremental = true disallow_untyped_defs = true diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0966f9b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +from typing import Type + +import pytest + +from fastapi_jwt.jwt_backends import AbstractJWTBackend, AuthlibJWTBackend, PythonJoseJWTBackend + + +@pytest.fixture(params=[PythonJoseJWTBackend, AuthlibJWTBackend]) +def jwt_backend(request: pytest.FixtureRequest) -> Type[AbstractJWTBackend]: + return request.param diff --git a/tests/mock_datetime_utils.py b/tests/mock_datetime_utils.py new file mode 100644 index 0000000..a35e576 --- /dev/null +++ b/tests/mock_datetime_utils.py @@ -0,0 +1,37 @@ +import datetime +import time + +from fastapi_jwt import AuthlibJWTBackend, PythonJoseJWTBackend + +_time = time.time +_now = datetime.datetime.now +_utcnow = datetime.datetime.utcnow + + +def create_datetime_mock(**timedelta_kwargs): + class _FakeDateTime(datetime.datetime): # pragma: no cover + @staticmethod + def now(**kwargs): + return _now() + datetime.timedelta(**timedelta_kwargs) + + @staticmethod + def utcnow(**kwargs): + return _utcnow() + datetime.timedelta(**timedelta_kwargs) + + return _FakeDateTime + + +def create_time_time_mock(**kwargs): + def _fake_time_time(): + return _time() + datetime.timedelta(**kwargs).total_seconds() + + return _fake_time_time + + +def mock_now_for_backend(mocker, jwt_backend, **kwargs): + if jwt_backend is AuthlibJWTBackend: + mocker.patch("authlib.jose.rfc7519.claims.time.time", create_time_time_mock(**kwargs)) + elif jwt_backend is PythonJoseJWTBackend: + mocker.patch("jose.jwt.datetime", create_datetime_mock(**kwargs)) + else: + raise Exception("Invalid Backend") diff --git a/tests/test_security_jwt_bearer.py b/tests/test_security_jwt_bearer.py index 104917c..c927a80 100644 --- a/tests/test_security_jwt_bearer.py +++ b/tests/test_security_jwt_bearer.py @@ -1,40 +1,43 @@ +from typing import Type + from fastapi import FastAPI, Security from fastapi.testclient import TestClient -from fastapi_jwt import JwtAccessBearer, JwtAuthorizationCredentials, JwtRefreshBearer - -app = FastAPI() - -access_security = JwtAccessBearer(secret_key="secret_key") -refresh_security = JwtRefreshBearer(secret_key="secret_key") +from fastapi_jwt import JwtAccessBearer, JwtAuthorizationCredentials, JwtRefreshBearer, force_jwt_backend +from fastapi_jwt.jwt_backends import AbstractJWTBackend -@app.post("/auth") -def auth(): - subject = {"username": "username", "role": "user"} +def create_example_client(jwt_backend: Type[AbstractJWTBackend]): + force_jwt_backend(jwt_backend) + app = FastAPI() - access_token = access_security.create_access_token(subject=subject) - refresh_token = access_security.create_refresh_token(subject=subject) + access_security = JwtAccessBearer(secret_key="secret_key") + refresh_security = JwtRefreshBearer(secret_key="secret_key") - return {"access_token": access_token, "refresh_token": refresh_token} + @app.post("/auth") + def auth(): + subject = {"username": "username", "role": "user"} + access_token = access_security.create_access_token(subject=subject) + refresh_token = access_security.create_refresh_token(subject=subject) -@app.post("/refresh") -def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): - access_token = refresh_security.create_access_token(subject=credentials.subject) - refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + return {"access_token": access_token, "refresh_token": refresh_token} - return {"access_token": access_token, "refresh_token": refresh_token} + @app.post("/refresh") + def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): + access_token = refresh_security.create_access_token(subject=credentials.subject) + refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.get("/users/me") -def read_current_user( - credentials: JwtAuthorizationCredentials = Security(access_security), -): - return {"username": credentials["username"], "role": credentials["role"]} + @app.get("/users/me") + def read_current_user( + credentials: JwtAuthorizationCredentials = Security(access_security), + ): + return {"username": credentials["username"], "role": credentials["role"]} + return TestClient(app) -client = TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -88,70 +91,72 @@ def read_current_user( } -def test_openapi_schema(): +def test_openapi_schema(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_auth(): +def test_security_jwt_auth(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.post("/auth") assert response.status_code == 200, response.text -def test_security_jwt_access_bearer(): +def test_security_jwt_access_bearer(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] - response = client.get( - "/users/me", headers={"Authorization": f"Bearer {access_token}"} - ) + response = client.get("/users/me", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 200, response.text assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_bearer_wrong(): - response = client.get( - "/users/me", headers={"Authorization": "Bearer wrong_access_token"} - ) +def test_security_jwt_access_bearer_wrong(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) + response = client.get("/users/me", headers={"Authorization": "Bearer wrong_access_token"}) assert response.status_code == 401, response.text -def test_security_jwt_access_bearer_no_credentials(): +def test_security_jwt_access_bearer_no_credentials(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.get("/users/me") assert response.status_code == 401, response.text assert response.json() == {"detail": "Credentials are not provided"} -def test_security_jwt_access_bearer_incorrect_scheme_credentials(): +def test_security_jwt_access_bearer_incorrect_scheme_credentials(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.get("/users/me", headers={"Authorization": "Basic notreally"}) assert response.status_code == 401, response.text assert response.json() == {"detail": "Credentials are not provided"} # assert response.json() == {"detail": "Invalid authentication credentials"} -def test_security_jwt_refresh_bearer(): +def test_security_jwt_refresh_bearer(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] - response = client.post( - "/refresh", headers={"Authorization": f"Bearer {refresh_token}"} - ) + response = client.post("/refresh", headers={"Authorization": f"Bearer {refresh_token}"}) assert response.status_code == 200, response.text -def test_security_jwt_refresh_bearer_wrong(): - response = client.post( - "/refresh", headers={"Authorization": "Bearer wrong_refresh_token"} - ) +def test_security_jwt_refresh_bearer_wrong(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) + response = client.post("/refresh", headers={"Authorization": "Bearer wrong_refresh_token"}) assert response.status_code == 401, response.text -def test_security_jwt_refresh_bearer_no_credentials(): +def test_security_jwt_refresh_bearer_no_credentials(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.post("/refresh") assert response.status_code == 401, response.text assert response.json() == {"detail": "Credentials are not provided"} -def test_security_jwt_refresh_bearer_incorrect_scheme_credentials(): +def test_security_jwt_refresh_bearer_incorrect_scheme_credentials(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.post("/refresh", headers={"Authorization": "Basic notreally"}) assert response.status_code == 401, response.text assert response.json() == {"detail": "Credentials are not provided"} diff --git a/tests/test_security_jwt_bearer_optional.py b/tests/test_security_jwt_bearer_optional.py index 5029ff6..67ffe94 100644 --- a/tests/test_security_jwt_bearer_optional.py +++ b/tests/test_security_jwt_bearer_optional.py @@ -1,49 +1,50 @@ -from typing import Optional +from typing import Optional, Type from fastapi import FastAPI, Security from fastapi.testclient import TestClient -from fastapi_jwt import JwtAccessBearer, JwtAuthorizationCredentials, JwtRefreshBearer +from fastapi_jwt import JwtAccessBearer, JwtAuthorizationCredentials, JwtRefreshBearer, force_jwt_backend +from fastapi_jwt.jwt_backends import AbstractJWTBackend -app = FastAPI() -access_security = JwtAccessBearer(secret_key="secret_key", auto_error=False) -refresh_security = JwtRefreshBearer(secret_key="secret_key", auto_error=False) +def create_example_client(jwt_backend: Type[AbstractJWTBackend]): + force_jwt_backend(jwt_backend) + app = FastAPI() + access_security = JwtAccessBearer(secret_key="secret_key", auto_error=False) + refresh_security = JwtRefreshBearer(secret_key="secret_key", auto_error=False) -@app.post("/auth") -def auth(): - subject = {"username": "username", "role": "user"} + @app.post("/auth") + def auth(): + subject = {"username": "username", "role": "user"} - access_token = access_security.create_access_token(subject=subject) - refresh_token = access_security.create_refresh_token(subject=subject) + access_token = access_security.create_access_token(subject=subject) + refresh_token = access_security.create_refresh_token(subject=subject) - return {"access_token": access_token, "refresh_token": refresh_token} + return {"access_token": access_token, "refresh_token": refresh_token} + @app.post("/refresh") + def refresh( + credentials: Optional[JwtAuthorizationCredentials] = Security(refresh_security), + ): + if credentials is None: + return {"msg": "Create an account first"} -@app.post("/refresh") -def refresh( - credentials: Optional[JwtAuthorizationCredentials] = Security(refresh_security), -): - if credentials is None: - return {"msg": "Create an account first"} + access_token = refresh_security.create_access_token(subject=credentials.subject) + refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) - access_token = refresh_security.create_access_token(subject=credentials.subject) - refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + return {"access_token": access_token, "refresh_token": refresh_token} - return {"access_token": access_token, "refresh_token": refresh_token} + @app.get("/users/me") + def read_current_user( + credentials: Optional[JwtAuthorizationCredentials] = Security(access_security), + ): + if credentials is None: + return {"msg": "Create an account first"} + return {"username": credentials["username"], "role": credentials["role"]} + return TestClient(app) -@app.get("/users/me") -def read_current_user( - credentials: Optional[JwtAuthorizationCredentials] = Security(access_security), -): - if credentials is None: - return {"msg": "Create an account first"} - return {"username": credentials["username"], "role": credentials["role"]} - - -client = TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -97,71 +98,73 @@ def read_current_user( } -def test_openapi_schema(): +def test_openapi_schema(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_auth(): +def test_security_jwt_auth(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.post("/auth") assert response.status_code == 200, response.text -def test_security_jwt_access_bearer(): +def test_security_jwt_access_bearer(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] - response = client.get( - "/users/me", headers={"Authorization": f"Bearer {access_token}"} - ) + response = client.get("/users/me", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 200, response.text assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_bearer_wrong(): - response = client.get( - "/users/me", headers={"Authorization": "Bearer wrong_access_token"} - ) +def test_security_jwt_access_bearer_wrong(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) + response = client.get("/users/me", headers={"Authorization": "Bearer wrong_access_token"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_access_bearer_no_credentials(): +def test_security_jwt_access_bearer_no_credentials(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.get("/users/me") assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_access_bearer_incorrect_scheme_credentials(): +def test_security_jwt_access_bearer_incorrect_scheme_credentials(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.get("/users/me", headers={"Authorization": "Basic notreally"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_bearer(): +def test_security_jwt_refresh_bearer(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] - response = client.post( - "/refresh", headers={"Authorization": f"Bearer {refresh_token}"} - ) + response = client.post("/refresh", headers={"Authorization": f"Bearer {refresh_token}"}) assert response.status_code == 200, response.text -def test_security_jwt_refresh_bearer_wrong(): - response = client.post( - "/refresh", headers={"Authorization": "Bearer wrong_refresh_token"} - ) +def test_security_jwt_refresh_bearer_wrong(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) + response = client.post("/refresh", headers={"Authorization": "Bearer wrong_refresh_token"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_bearer_no_credentials(): +def test_security_jwt_refresh_bearer_no_credentials(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.post("/refresh") assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_bearer_incorrect_scheme_credentials(): +def test_security_jwt_refresh_bearer_incorrect_scheme_credentials(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.post("/refresh", headers={"Authorization": "Basic notreally"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} diff --git a/tests/test_security_jwt_cookie.py b/tests/test_security_jwt_cookie.py index d506eba..51a0f99 100644 --- a/tests/test_security_jwt_cookie.py +++ b/tests/test_security_jwt_cookie.py @@ -1,40 +1,43 @@ +from typing import Type + from fastapi import FastAPI, Security from fastapi.testclient import TestClient -from fastapi_jwt import JwtAccessCookie, JwtAuthorizationCredentials, JwtRefreshCookie - -app = FastAPI() - -access_security = JwtAccessCookie(secret_key="secret_key") -refresh_security = JwtRefreshCookie(secret_key="secret_key") +from fastapi_jwt import JwtAccessCookie, JwtAuthorizationCredentials, JwtRefreshCookie, force_jwt_backend +from fastapi_jwt.jwt_backends import AbstractJWTBackend -@app.post("/auth") -def auth(): - subject = {"username": "username", "role": "user"} +def create_example_client(jwt_backend: Type[AbstractJWTBackend]): + force_jwt_backend(jwt_backend) + app = FastAPI() - access_token = access_security.create_access_token(subject=subject) - refresh_token = access_security.create_refresh_token(subject=subject) + access_security = JwtAccessCookie(secret_key="secret_key") + refresh_security = JwtRefreshCookie(secret_key="secret_key") - return {"access_token": access_token, "refresh_token": refresh_token} + @app.post("/auth") + def auth(): + subject = {"username": "username", "role": "user"} + access_token = access_security.create_access_token(subject=subject) + refresh_token = access_security.create_refresh_token(subject=subject) -@app.post("/refresh") -def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): - access_token = refresh_security.create_access_token(subject=credentials.subject) - refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + return {"access_token": access_token, "refresh_token": refresh_token} - return {"access_token": access_token, "refresh_token": refresh_token} + @app.post("/refresh") + def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): + access_token = refresh_security.create_access_token(subject=credentials.subject) + refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.get("/users/me") -def read_current_user( - credentials: JwtAuthorizationCredentials = Security(access_security), -): - return {"username": credentials["username"], "role": credentials["role"]} + @app.get("/users/me") + def read_current_user( + credentials: JwtAuthorizationCredentials = Security(access_security), + ): + return {"username": credentials["username"], "role": credentials["role"]} + return TestClient(app) -client = TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -96,18 +99,21 @@ def read_current_user( } -def test_openapi_schema(): +def test_openapi_schema(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_auth(): +def test_security_jwt_auth(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.post("/auth") assert response.status_code == 200, response.text -def test_security_jwt_access_cookie(): +def test_security_jwt_access_cookie(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get("/users/me", cookies={"access_token_cookie": access_token}) @@ -115,21 +121,22 @@ def test_security_jwt_access_cookie(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_cookie_wrong(): - response = client.get( - "/users/me", cookies={"access_token_cookie": "wrong_access_token_cookie"} - ) +def test_security_jwt_access_cookie_wrong(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) + response = client.get("/users/me", cookies={"access_token_cookie": "wrong_access_token_cookie"}) assert response.status_code == 401, response.text -def test_security_jwt_access_cookie_no_credentials(): +def test_security_jwt_access_cookie_no_credentials(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) client.cookies.clear() response = client.get("/users/me", cookies={}) assert response.status_code == 401, response.text assert response.json() == {"detail": "Credentials are not provided"} -def test_security_jwt_refresh_cookie(): +def test_security_jwt_refresh_cookie(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) client.cookies.clear() refresh_token = client.post("/auth").json()["refresh_token"] @@ -137,14 +144,14 @@ def test_security_jwt_refresh_cookie(): assert response.status_code == 200, response.text -def test_security_jwt_refresh_cookie_wrong(): - response = client.post( - "/refresh", cookies={"refresh_token_cookie": "wrong_refresh_token_cookie"} - ) +def test_security_jwt_refresh_cookie_wrong(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) + response = client.post("/refresh", cookies={"refresh_token_cookie": "wrong_refresh_token_cookie"}) assert response.status_code == 401, response.text -def test_security_jwt_refresh_cookie_no_credentials(): +def test_security_jwt_refresh_cookie_no_credentials(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) client.cookies.clear() response = client.post("/refresh", cookies={}) assert response.status_code == 401, response.text diff --git a/tests/test_security_jwt_cookie_optional.py b/tests/test_security_jwt_cookie_optional.py index 7cc2f51..7a03c91 100644 --- a/tests/test_security_jwt_cookie_optional.py +++ b/tests/test_security_jwt_cookie_optional.py @@ -1,50 +1,51 @@ -from typing import Optional +from typing import Optional, Type from fastapi import FastAPI, Security from fastapi.testclient import TestClient -from fastapi_jwt import JwtAccessCookie, JwtAuthorizationCredentials, JwtRefreshCookie +from fastapi_jwt import JwtAccessCookie, JwtAuthorizationCredentials, JwtRefreshCookie, force_jwt_backend +from fastapi_jwt.jwt_backends import AbstractJWTBackend -app = FastAPI() -access_security = JwtAccessCookie(secret_key="secret_key", auto_error=False) -refresh_security = JwtRefreshCookie(secret_key="secret_key", auto_error=False) +def create_example_client(jwt_backend: Type[AbstractJWTBackend]): + force_jwt_backend(jwt_backend) + app = FastAPI() + access_security = JwtAccessCookie(secret_key="secret_key", auto_error=False) + refresh_security = JwtRefreshCookie(secret_key="secret_key", auto_error=False) -@app.post("/auth") -def auth(): - subject = {"username": "username", "role": "user"} + @app.post("/auth") + def auth(): + subject = {"username": "username", "role": "user"} - access_token = access_security.create_access_token(subject=subject) - refresh_token = access_security.create_refresh_token(subject=subject) + access_token = access_security.create_access_token(subject=subject) + refresh_token = access_security.create_refresh_token(subject=subject) - return {"access_token": access_token, "refresh_token": refresh_token} + return {"access_token": access_token, "refresh_token": refresh_token} + @app.post("/refresh") + def refresh( + credentials: Optional[JwtAuthorizationCredentials] = Security(refresh_security), + ): + if credentials is None: + return {"msg": "Create an account first"} -@app.post("/refresh") -def refresh( - credentials: Optional[JwtAuthorizationCredentials] = Security(refresh_security), -): - if credentials is None: - return {"msg": "Create an account first"} + access_token = refresh_security.create_access_token(subject=credentials.subject) + refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) - access_token = refresh_security.create_access_token(subject=credentials.subject) - refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + return {"access_token": access_token, "refresh_token": refresh_token} - return {"access_token": access_token, "refresh_token": refresh_token} + @app.get("/users/me") + def read_current_user( + credentials: Optional[JwtAuthorizationCredentials] = Security(access_security), + ): + if credentials is None: + return {"msg": "Create an account first"} + return {"username": credentials["username"], "role": credentials["role"]} -@app.get("/users/me") -def read_current_user( - credentials: Optional[JwtAuthorizationCredentials] = Security(access_security), -): - if credentials is None: - return {"msg": "Create an account first"} + return TestClient(app) - return {"username": credentials["username"], "role": credentials["role"]} - - -client = TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -106,18 +107,21 @@ def read_current_user( } -def test_openapi_schema(): +def test_openapi_schema(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_auth(): +def test_security_jwt_auth(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.post("/auth") assert response.status_code == 200, response.text -def test_security_jwt_access_cookie(): +def test_security_jwt_access_cookie(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) client.cookies.clear() access_token = client.post("/auth").json()["access_token"] @@ -126,36 +130,37 @@ def test_security_jwt_access_cookie(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_cookie_wrong(): - response = client.get( - "/users/me", cookies={"access_token_cookie": "wrong_access_token_cookie"} - ) +def test_security_jwt_access_cookie_wrong(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) + response = client.get("/users/me", cookies={"access_token_cookie": "wrong_access_token_cookie"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_access_cookie_no_credentials(): +def test_security_jwt_access_cookie_no_credentials(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.get("/users/me", cookies={}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_cookie(): +def test_security_jwt_refresh_cookie(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] response = client.post("/refresh", cookies={"refresh_token_cookie": refresh_token}) assert response.status_code == 200, response.text -def test_security_jwt_refresh_cookie_wrong(): - response = client.post( - "/refresh", cookies={"refresh_token_cookie": "wrong_refresh_token_cookie"} - ) +def test_security_jwt_refresh_cookie_wrong(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) + response = client.post("/refresh", cookies={"refresh_token_cookie": "wrong_refresh_token_cookie"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_cookie_no_credentials(): +def test_security_jwt_refresh_cookie_no_credentials(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.post("/refresh", cookies={}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} diff --git a/tests/test_security_jwt_general.py b/tests/test_security_jwt_general.py index 62a160e..5e281dc 100644 --- a/tests/test_security_jwt_general.py +++ b/tests/test_security_jwt_general.py @@ -1,84 +1,62 @@ -import datetime -from typing import Set +from typing import Set, Type from uuid import uuid4 from fastapi import FastAPI, Security from fastapi.testclient import TestClient from pytest_mock import MockerFixture -from fastapi_jwt import JwtAccessBearer, JwtAuthorizationCredentials, JwtRefreshBearer +from fastapi_jwt import JwtAccessBearer, JwtAuthorizationCredentials, JwtRefreshBearer, force_jwt_backend +from fastapi_jwt.jwt_backends import AbstractJWTBackend -app = FastAPI() +from .mock_datetime_utils import mock_now_for_backend -access_security = JwtAccessBearer(secret_key="secret_key") -refresh_security = JwtRefreshBearer.from_other(access_security) +def create_example_client(jwt_backend: Type[AbstractJWTBackend]): + force_jwt_backend(jwt_backend) + app = FastAPI() -unique_identifiers_database: Set[str] = set() + access_security = JwtAccessBearer(secret_key="secret_key") + refresh_security = JwtRefreshBearer.from_other(access_security) + unique_identifiers_database: Set[str] = set() + @app.post("/auth") + def auth(): + subject = {"username": "username", "role": "user"} + unique_identifier = str(uuid4()) + unique_identifiers_database.add(unique_identifier) -@app.post("/auth") -def auth(): - subject = {"username": "username", "role": "user"} - unique_identifier = str(uuid4()) - unique_identifiers_database.add(unique_identifier) + access_token = access_security.create_access_token(subject=subject, unique_identifier=unique_identifier) + refresh_token = access_security.create_refresh_token(subject=subject) - access_token = access_security.create_access_token( - subject=subject, unique_identifier=unique_identifier - ) - refresh_token = access_security.create_refresh_token(subject=subject) + return {"access_token": access_token, "refresh_token": refresh_token} - return {"access_token": access_token, "refresh_token": refresh_token} + @app.post("/refresh") + def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): + unique_identifier = str(uuid4()) + unique_identifiers_database.add(unique_identifier) + access_token = refresh_security.create_access_token( + subject=credentials.subject, + unique_identifier=unique_identifier, + ) + refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) -@app.post("/refresh") -def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): - unique_identifier = str(uuid4()) - unique_identifiers_database.add(unique_identifier) + return {"access_token": access_token, "refresh_token": refresh_token} - access_token = refresh_security.create_access_token( - subject=credentials.subject, unique_identifier=unique_identifier, - ) - refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + @app.get("/users/me") + def read_current_user( + credentials: JwtAuthorizationCredentials = Security(access_security), + ): + return {"username": credentials["username"], "role": credentials["role"]} - return {"access_token": access_token, "refresh_token": refresh_token} + @app.get("/auth/meta") + def get_token_meta( + credentials: JwtAuthorizationCredentials = Security(access_security), + ): + return {"jti": credentials.jti} + return TestClient(app), unique_identifiers_database -@app.get("/users/me") -def read_current_user( - credentials: JwtAuthorizationCredentials = Security(access_security), -): - return {"username": credentials["username"], "role": credentials["role"]} - - -@app.get("/auth/meta") -def get_token_meta( - credentials: JwtAuthorizationCredentials = Security(access_security), -): - return {"jti": credentials.jti} - - -class _FakeDateTimeShort(datetime.datetime): # pragma: no cover - @staticmethod - def now(**kwargs): - return datetime.datetime.now() + datetime.timedelta(minutes=3) - - @staticmethod - def utcnow(**kwargs): - return datetime.datetime.utcnow() + datetime.timedelta(minutes=3) - - -class _FakeDateTimeLong(datetime.datetime): # pragma: no cover - @staticmethod - def now(**kwargs): - return datetime.datetime.now() + datetime.timedelta(days=42) - - @staticmethod - def utcnow(**kwargs): - return datetime.datetime.utcnow() + datetime.timedelta(days=42) - - -client = TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -145,137 +123,113 @@ def utcnow(**kwargs): } -def test_openapi_schema(): +def test_openapi_schema(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_access_token(): +def test_security_jwt_access_token(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] - response = client.get( - "/users/me", headers={"Authorization": f"Bearer {access_token}"} - ) + response = client.get("/users/me", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 200, response.text assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_token_wrong(): - response = client.get( - "/users/me", headers={"Authorization": "Bearer wrong_access_token"} - ) +def test_security_jwt_access_token_wrong(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) + response = client.get("/users/me", headers={"Authorization": "Bearer wrong_access_token"}) assert response.status_code == 401, response.text - assert response.json()["detail"].startswith("Wrong token:") + assert response.json()["detail"].startswith("Invalid token:") - response = client.get( - "/users/me", headers={"Authorization": "Bearer wrong.access.token"} - ) + response = client.get("/users/me", headers={"Authorization": "Bearer wrong.access.token"}) assert response.status_code == 401, response.text - assert response.json()["detail"].startswith("Wrong token:") + assert response.json()["detail"].startswith("Invalid token:") -def test_security_jwt_access_token_changed(): +def test_security_jwt_access_token_changed(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] access_token = access_token.split(".")[0] + ".wrong." + access_token.split(".")[-1] - response = client.get( - "/users/me", headers={"Authorization": f"Bearer {access_token}"} - ) + response = client.get("/users/me", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 401, response.text - assert response.json()["detail"].startswith("Wrong token:") + assert response.json()["detail"].startswith("Invalid token:") -def test_security_jwt_access_token_expiration(mocker: MockerFixture): +def test_security_jwt_access_token_expiration(mocker: MockerFixture, jwt_backend): + client, _ = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] - mocker.patch("jose.jwt.datetime", _FakeDateTimeShort) # 3 min left - - response = client.get( - "/users/me", headers={"Authorization": f"Bearer {access_token}"} - ) + mock_now_for_backend(mocker, jwt_backend, minutes=3) # 3 min left + response = client.get("/users/me", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 200, response.text - mocker.patch("jose.jwt.datetime", _FakeDateTimeLong) # 42 days left - - response = client.get( - "/users/me", headers={"Authorization": f"Bearer {access_token}"} - ) + mock_now_for_backend(mocker, jwt_backend, days=42) # 42 days left + response = client.get("/users/me", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 401, response.text - assert response.json()["detail"].startswith( - "Token time expired: Signature has expired" - ) + assert response.json()["detail"].startswith("Token time expired:") -def test_security_jwt_refresh_token(): +def test_security_jwt_refresh_token(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] - response = client.post( - "/refresh", headers={"Authorization": f"Bearer {refresh_token}"} - ) + response = client.post("/refresh", headers={"Authorization": f"Bearer {refresh_token}"}) assert response.status_code == 200, response.text -def test_security_jwt_refresh_token_wrong(): - response = client.post( - "/refresh", headers={"Authorization": "Bearer wrong_refresh_token"} - ) +def test_security_jwt_refresh_token_wrong(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) + response = client.post("/refresh", headers={"Authorization": "Bearer wrong_refresh_token"}) assert response.status_code == 401, response.text - assert response.json()["detail"].startswith("Wrong token:") + assert response.json()["detail"].startswith("Invalid token:") - response = client.post( - "/refresh", headers={"Authorization": "Bearer wrong.refresh.token"} - ) + response = client.post("/refresh", headers={"Authorization": "Bearer wrong.refresh.token"}) assert response.status_code == 401, response.text - assert response.json()["detail"].startswith("Wrong token:") + assert response.json()["detail"].startswith("Invalid token:") -def test_security_jwt_refresh_token_using_access_token(): +def test_security_jwt_refresh_token_using_access_token(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) tokens = client.post("/auth").json() access_token, refresh_token = tokens["access_token"], tokens["refresh_token"] assert access_token != refresh_token - response = client.post( - "/refresh", headers={"Authorization": f"Bearer {access_token}"} - ) + response = client.post("/refresh", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 401, response.text - assert response.json()["detail"].startswith("Wrong token: 'type' is not 'refresh'") + assert response.json()["detail"].startswith("Invalid token: 'type' is not 'refresh'") -def test_security_jwt_refresh_token_changed(): +def test_security_jwt_refresh_token_changed(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] - refresh_token = ( - refresh_token.split(".")[0] + ".wrong." + refresh_token.split(".")[-1] - ) + refresh_token = refresh_token.split(".")[0] + ".wrong." + refresh_token.split(".")[-1] - response = client.post( - "/refresh", headers={"Authorization": f"Bearer {refresh_token}"} - ) + response = client.post("/refresh", headers={"Authorization": f"Bearer {refresh_token}"}) assert response.status_code == 401, response.text - assert response.json()["detail"].startswith("Wrong token:") + assert response.json()["detail"].startswith("Invalid token:") -def test_security_jwt_refresh_token_expired(mocker: MockerFixture): +def test_security_jwt_refresh_token_expired(mocker: MockerFixture, jwt_backend): + client, _ = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] - mocker.patch("jose.jwt.datetime", _FakeDateTimeLong) # 42 days left - - response = client.post( - "/refresh", headers={"Authorization": f"Bearer {refresh_token}"} - ) + mock_now_for_backend(mocker, jwt_backend, days=42) # 42 days left + response = client.post("/refresh", headers={"Authorization": f"Bearer {refresh_token}"}) assert response.status_code == 401, response.text - assert response.json()["detail"].startswith( - "Token time expired: Signature has expired" - ) + assert response.json()["detail"].startswith("Token time expired:") -def test_security_jwt_custom_jti(): +def test_security_jwt_custom_jti(jwt_backend: Type[AbstractJWTBackend]): + client, unique_identifiers_database = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] - response = client.get( - "/auth/meta", headers={"Authorization": f"Bearer {access_token}"} - ) + response = client.get("/auth/meta", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 200, response.text assert response.json()["jti"] in unique_identifiers_database diff --git a/tests/test_security_jwt_general_optional.py b/tests/test_security_jwt_general_optional.py index ab32257..86b1f29 100644 --- a/tests/test_security_jwt_general_optional.py +++ b/tests/test_security_jwt_general_optional.py @@ -1,93 +1,71 @@ -import datetime -from typing import Optional, Set +from typing import Optional, Set, Type from uuid import uuid4 from fastapi import FastAPI, Security from fastapi.testclient import TestClient from pytest_mock import MockerFixture -from fastapi_jwt import JwtAccessBearer, JwtAuthorizationCredentials, JwtRefreshBearer +from fastapi_jwt import JwtAccessBearer, JwtAuthorizationCredentials, JwtRefreshBearer, force_jwt_backend +from fastapi_jwt.jwt_backends import AbstractJWTBackend -app = FastAPI() +from .mock_datetime_utils import mock_now_for_backend -access_security = JwtAccessBearer(secret_key="secret_key", auto_error=False) -refresh_security = JwtRefreshBearer.from_other(access_security) +def create_example_client(jwt_backend: Type[AbstractJWTBackend]): + force_jwt_backend(jwt_backend) + app = FastAPI() -unique_identifiers_database: Set[str] = set() + access_security = JwtAccessBearer(secret_key="secret_key", auto_error=False) + refresh_security = JwtRefreshBearer.from_other(access_security) + unique_identifiers_database: Set[str] = set() + @app.post("/auth") + def auth(): + subject = {"username": "username", "role": "user"} + unique_identifier = str(uuid4()) + unique_identifiers_database.add(unique_identifier) -@app.post("/auth") -def auth(): - subject = {"username": "username", "role": "user"} - unique_identifier = str(uuid4()) - unique_identifiers_database.add(unique_identifier) + access_token = access_security.create_access_token(subject=subject, unique_identifier=unique_identifier) + refresh_token = access_security.create_refresh_token(subject=subject) - access_token = access_security.create_access_token( - subject=subject, unique_identifier=unique_identifier - ) - refresh_token = access_security.create_refresh_token(subject=subject) + return {"access_token": access_token, "refresh_token": refresh_token} - return {"access_token": access_token, "refresh_token": refresh_token} + @app.post("/refresh") + def refresh( + credentials: Optional[JwtAuthorizationCredentials] = Security(refresh_security), + ): + if credentials is None: + return {"msg": "Create an account first"} + unique_identifier = str(uuid4()) + unique_identifiers_database.add(unique_identifier) -@app.post("/refresh") -def refresh( - credentials: Optional[JwtAuthorizationCredentials] = Security(refresh_security), -): - if credentials is None: - return {"msg": "Create an account first"} + access_token = refresh_security.create_access_token( + subject=credentials.subject, + unique_identifier=unique_identifier, + ) + refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) - unique_identifier = str(uuid4()) - unique_identifiers_database.add(unique_identifier) + return {"access_token": access_token, "refresh_token": refresh_token} - access_token = refresh_security.create_access_token( - subject=credentials.subject, unique_identifier=unique_identifier, - ) - refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + @app.get("/users/me") + def read_current_user( + credentials: Optional[JwtAuthorizationCredentials] = Security(access_security), + ): + if credentials is None: + return {"msg": "Create an account first"} + return {"username": credentials["username"], "role": credentials["role"]} - return {"access_token": access_token, "refresh_token": refresh_token} - - -@app.get("/users/me") -def read_current_user( - credentials: Optional[JwtAuthorizationCredentials] = Security(access_security), -): - if credentials is None: - return {"msg": "Create an account first"} - return {"username": credentials["username"], "role": credentials["role"]} - - -@app.get("/auth/meta") -def get_token_meta( + @app.get("/auth/meta") + def get_token_meta( credentials: JwtAuthorizationCredentials = Security(access_security), -): - if credentials is None: - return {"msg": "Create an account first"} - return {"jti": credentials.jti} - - -class _FakeDateTimeShort(datetime.datetime): # pragma: no cover - @staticmethod - def now(**kwargs): - return datetime.datetime.now() + datetime.timedelta(minutes=3) - - @staticmethod - def utcnow(**kwargs): - return datetime.datetime.utcnow() + datetime.timedelta(minutes=3) + ): + if credentials is None: + return {"msg": "Create an account first"} + return {"jti": credentials.jti} + return TestClient(app), unique_identifiers_database -class _FakeDateTimeLong(datetime.datetime): # pragma: no cover - @staticmethod - def now(**kwargs): - return datetime.datetime.now() + datetime.timedelta(days=42) - - @staticmethod - def utcnow(**kwargs): - return datetime.datetime.utcnow() + datetime.timedelta(days=42) - - -client = TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -154,135 +132,115 @@ def utcnow(**kwargs): } -def test_openapi_schema(): +def test_openapi_schema(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_access_token(): +def test_security_jwt_access_token(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] - response = client.get( - "/users/me", headers={"Authorization": f"Bearer {access_token}"} - ) + response = client.get("/users/me", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 200, response.text assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_token_wrong(): - response = client.get( - "/users/me", headers={"Authorization": "Bearer wrong_access_token"} - ) +def test_security_jwt_access_token_wrong(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) + response = client.get("/users/me", headers={"Authorization": "Bearer wrong_access_token"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} - response = client.get( - "/users/me", headers={"Authorization": "Bearer wrong.access.token"} - ) + response = client.get("/users/me", headers={"Authorization": "Bearer wrong.access.token"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_access_token_changed(): +def test_security_jwt_access_token_changed(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] access_token = access_token.split(".")[0] + ".wrong." + access_token.split(".")[-1] - response = client.get( - "/users/me", headers={"Authorization": f"Bearer {access_token}"} - ) + response = client.get("/users/me", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_access_token_expiration(mocker: MockerFixture): +def test_security_jwt_access_token_expiration(mocker: MockerFixture, jwt_backend): + client, _ = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] - mocker.patch("jose.jwt.datetime", _FakeDateTimeShort) # 3 min left - - response = client.get( - "/users/me", headers={"Authorization": f"Bearer {access_token}"} - ) + mock_now_for_backend(mocker, jwt_backend, minutes=3) # 3 min left + response = client.get("/users/me", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 200, response.text assert response.json() == {"username": "username", "role": "user"} - mocker.patch("jose.jwt.datetime", _FakeDateTimeLong) # 42 days left - - response = client.get( - "/users/me", headers={"Authorization": f"Bearer {access_token}"} - ) + mock_now_for_backend(mocker, jwt_backend, days=42) # 42 days left + response = client.get("/users/me", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_token(): +def test_security_jwt_refresh_token(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] - response = client.post( - "/refresh", headers={"Authorization": f"Bearer {refresh_token}"} - ) + response = client.post("/refresh", headers={"Authorization": f"Bearer {refresh_token}"}) assert response.status_code == 200, response.text assert "msg" not in response.json() -def test_security_jwt_refresh_token_wrong(): - response = client.post( - "/refresh", headers={"Authorization": "Bearer wrong_refresh_token"} - ) +def test_security_jwt_refresh_token_wrong(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) + response = client.post("/refresh", headers={"Authorization": "Bearer wrong_refresh_token"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} - response = client.post( - "/refresh", headers={"Authorization": "Bearer wrong.refresh.token"} - ) + response = client.post("/refresh", headers={"Authorization": "Bearer wrong.refresh.token"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_token_using_access_token(): +def test_security_jwt_refresh_token_using_access_token(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) tokens = client.post("/auth").json() access_token, refresh_token = tokens["access_token"], tokens["refresh_token"] assert access_token != refresh_token - response = client.post( - "/refresh", headers={"Authorization": f"Bearer {access_token}"} - ) + response = client.post("/refresh", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_token_changed(): +def test_security_jwt_refresh_token_changed(jwt_backend: Type[AbstractJWTBackend]): + client, _ = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] - refresh_token = ( - refresh_token.split(".")[0] + ".wrong." + refresh_token.split(".")[-1] - ) + refresh_token = refresh_token.split(".")[0] + ".wrong." + refresh_token.split(".")[-1] - response = client.post( - "/refresh", headers={"Authorization": f"Bearer {refresh_token}"} - ) + response = client.post("/refresh", headers={"Authorization": f"Bearer {refresh_token}"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_token_expired(mocker: MockerFixture): +def test_security_jwt_refresh_token_expired(mocker: MockerFixture, jwt_backend): + client, _ = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] - mocker.patch("jose.jwt.datetime", _FakeDateTimeLong) # 42 days left - - response = client.post( - "/refresh", headers={"Authorization": f"Bearer {refresh_token}"} - ) + mock_now_for_backend(mocker, jwt_backend, days=42) # 42 days left + response = client.post("/refresh", headers={"Authorization": f"Bearer {refresh_token}"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_custom_jti(): +def test_security_jwt_custom_jti(jwt_backend: Type[AbstractJWTBackend]): + client, unique_identifiers_database = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] - response = client.get( - "/auth/meta", headers={"Authorization": f"Bearer {access_token}"} - ) + response = client.get("/auth/meta", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 200, response.text assert response.json()["jti"] in unique_identifiers_database diff --git a/tests/test_security_jwt_multiple_places.py b/tests/test_security_jwt_multiple_places.py index af60bb8..d2a8937 100644 --- a/tests/test_security_jwt_multiple_places.py +++ b/tests/test_security_jwt_multiple_places.py @@ -1,40 +1,43 @@ +from typing import Type + from fastapi import FastAPI, Security from fastapi.testclient import TestClient -from fastapi_jwt import JwtAccessBearerCookie, JwtAuthorizationCredentials, JwtRefreshBearerCookie - -app = FastAPI() - -access_security = JwtAccessBearerCookie(secret_key="secret_key") -refresh_security = JwtRefreshBearerCookie(secret_key="secret_key") +from fastapi_jwt import JwtAccessBearerCookie, JwtAuthorizationCredentials, JwtRefreshBearerCookie, force_jwt_backend +from fastapi_jwt.jwt_backends import AbstractJWTBackend -@app.post("/auth") -def auth(): - subject = {"username": "username", "role": "user"} +def create_example_client(jwt_backend: Type[AbstractJWTBackend]): + force_jwt_backend(jwt_backend) + app = FastAPI() - access_token = access_security.create_access_token(subject=subject) - refresh_token = access_security.create_refresh_token(subject=subject) + access_security = JwtAccessBearerCookie(secret_key="secret_key") + refresh_security = JwtRefreshBearerCookie(secret_key="secret_key") - return {"access_token": access_token, "refresh_token": refresh_token} + @app.post("/auth") + def auth(): + subject = {"username": "username", "role": "user"} + access_token = access_security.create_access_token(subject=subject) + refresh_token = access_security.create_refresh_token(subject=subject) -@app.post("/refresh") -def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): - access_token = refresh_security.create_access_token(subject=credentials.subject) - refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + return {"access_token": access_token, "refresh_token": refresh_token} - return {"access_token": access_token, "refresh_token": refresh_token} + @app.post("/refresh") + def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): + access_token = refresh_security.create_access_token(subject=credentials.subject) + refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.get("/users/me") -def read_current_user( - credentials: JwtAuthorizationCredentials = Security(access_security), -): - return {"username": credentials["username"], "role": credentials["role"]} + @app.get("/users/me") + def read_current_user( + credentials: JwtAuthorizationCredentials = Security(access_security), + ): + return {"username": credentials["username"], "role": credentials["role"]} + return TestClient(app) -client = TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -98,13 +101,15 @@ def read_current_user( } -def test_openapi_schema(): +def test_openapi_schema(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_access_both_correct(): +def test_security_jwt_access_both_correct(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get( @@ -116,7 +121,8 @@ def test_security_jwt_access_both_correct(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_only_cookie(): +def test_security_jwt_access_only_cookie(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get("/users/me", cookies={"access_token_cookie": access_token}) @@ -124,17 +130,17 @@ def test_security_jwt_access_only_cookie(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_only_bearer(): +def test_security_jwt_access_only_bearer(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] - response = client.get( - "/users/me", headers={"Authorization": f"Bearer {access_token}"} - ) + response = client.get("/users/me", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 200, response.text assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_bearer_wrong_cookie_correct(): +def test_security_jwt_access_bearer_wrong_cookie_correct(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get( @@ -143,10 +149,11 @@ def test_security_jwt_access_bearer_wrong_cookie_correct(): cookies={"access_token_cookie": access_token}, ) assert response.status_code == 401, response.text - assert response.json()["detail"].startswith("Wrong token:") + assert response.json()["detail"].startswith("Invalid token:") -def test_security_jwt_access_bearer_correct_cookie_wrong(): +def test_security_jwt_access_bearer_correct_cookie_wrong(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get( @@ -158,20 +165,23 @@ def test_security_jwt_access_bearer_correct_cookie_wrong(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_both_no_credentials(): +def test_security_jwt_access_both_no_credentials(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.get("/users/me") assert response.status_code == 401, response.text assert response.json() == {"detail": "Credentials are not provided"} -def test_security_jwt_refresh_only_cookie(): +def test_security_jwt_refresh_only_cookie(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] response = client.post("/refresh", cookies={"refresh_token_cookie": refresh_token}) assert response.status_code == 200, response.text -def test_security_jwt_refresh_bearer_correct_cookie_wrong(): +def test_security_jwt_refresh_bearer_correct_cookie_wrong(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] response = client.post( @@ -182,7 +192,8 @@ def test_security_jwt_refresh_bearer_correct_cookie_wrong(): assert response.status_code == 200, response.text -def test_security_jwt_refresh_bearer_wrong_cookie_correct(): +def test_security_jwt_refresh_bearer_wrong_cookie_correct(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] response = client.post( @@ -191,20 +202,22 @@ def test_security_jwt_refresh_bearer_wrong_cookie_correct(): cookies={"refresh_token_cookie": refresh_token}, ) assert response.status_code == 401, response.text - assert response.json()["detail"].startswith("Wrong token:") + assert response.json()["detail"].startswith("Invalid token:") -def test_security_jwt_refresh_cookie_wrong_using_access_token(): +def test_security_jwt_refresh_cookie_wrong_using_access_token(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) tokens = client.post("/auth").json() access_token, refresh_token = tokens["access_token"], tokens["refresh_token"] assert access_token != refresh_token response = client.post("/refresh", cookies={"refresh_token_cookie": access_token}) assert response.status_code == 401, response.text - assert response.json()["detail"].startswith("Wrong token: 'type' is not 'refresh'") + assert response.json()["detail"].startswith("Invalid token: 'type' is not 'refresh'") -def test_security_jwt_refresh_both_no_credentials(): +def test_security_jwt_refresh_both_no_credentials(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.post("/refresh") assert response.status_code == 401, response.text assert response.json() == {"detail": "Credentials are not provided"} diff --git a/tests/test_security_jwt_set_cookie.py b/tests/test_security_jwt_set_cookie.py index 6057cbc..a86e68e 100644 --- a/tests/test_security_jwt_set_cookie.py +++ b/tests/test_security_jwt_set_cookie.py @@ -1,36 +1,40 @@ +from typing import Type + from fastapi import FastAPI, Response from fastapi.testclient import TestClient -from fastapi_jwt import JwtAccessCookie, JwtRefreshCookie - -app = FastAPI() +from fastapi_jwt import JwtAccessCookie, JwtRefreshCookie, force_jwt_backend +from fastapi_jwt.jwt_backends import AbstractJWTBackend -access_security = JwtAccessCookie(secret_key="secret_key") -refresh_security = JwtRefreshCookie(secret_key="secret_key") +def create_example_client(jwt_backend: Type[AbstractJWTBackend]): + force_jwt_backend(jwt_backend) + app = FastAPI() -@app.post("/auth") -def auth(response: Response): - subject = {"username": "username", "role": "user"} + access_security = JwtAccessCookie(secret_key="secret_key") + refresh_security = JwtRefreshCookie(secret_key="secret_key") - access_token = access_security.create_access_token(subject=subject) - refresh_token = access_security.create_refresh_token(subject=subject) + @app.post("/auth") + def auth(response: Response): + subject = {"username": "username", "role": "user"} - access_security.set_access_cookie(response, access_token) - refresh_security.set_refresh_cookie(response, refresh_token) + access_token = access_security.create_access_token(subject=subject) + refresh_token = access_security.create_refresh_token(subject=subject) - return {"access_token": access_token, "refresh_token": refresh_token} + access_security.set_access_cookie(response, access_token) + refresh_security.set_refresh_cookie(response, refresh_token) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.delete("/auth") -def logout(response: Response): - access_security.unset_access_cookie(response) - refresh_security.unset_refresh_cookie(response) + @app.delete("/auth") + def logout(response: Response): + access_security.unset_access_cookie(response) + refresh_security.unset_refresh_cookie(response) - return {"msg": "Successful logout"} + return {"msg": "Successful logout"} + return TestClient(app) -client = TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -62,13 +66,15 @@ def logout(response: Response): } -def test_openapi_schema(): +def test_openapi_schema(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_auth(): +def test_security_jwt_auth(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.post("/auth") assert response.status_code == 200, response.text @@ -78,17 +84,15 @@ def test_security_jwt_auth(): assert response.cookies["refresh_token_cookie"] == response.json()["refresh_token"] -def test_security_jwt_logout(): +def test_security_jwt_logout(jwt_backend: Type[AbstractJWTBackend]): + client = create_example_client(jwt_backend) response = client.delete("/auth") assert response.status_code == 200, response.text assert "access_token_cookie" in response.headers["set-cookie"] assert 'access_token_cookie=""; Max-Age=-1;' in response.headers["set-cookie"] assert "refresh_token_cookie" in response.headers["set-cookie"] - assert ( - 'refresh_token_cookie=""; HttpOnly; Max-Age=-1' - in response.headers["set-cookie"] - ) + assert 'refresh_token_cookie=""; HttpOnly; Max-Age=-1' in response.headers["set-cookie"] # assert "access_token_cookie" not in response.cookies # assert response.cookies["access_token_cookie"].max_age == -1 # assert "refresh_token_cookie" not in response.cookies