diff --git a/fastapi_jwt/__init__.py b/fastapi_jwt/__init__.py index 25585c5..4b3e65c 100644 --- a/fastapi_jwt/__init__.py +++ b/fastapi_jwt/__init__.py @@ -1 +1,2 @@ from .jwt import * # noqa: F401, F403 +from .jwt_backends import * # noqa: F401, F403 \ No newline at end of file diff --git a/fastapi_jwt/jwt.py b/fastapi_jwt/jwt.py index d9041ed..d598dd3 100644 --- a/fastapi_jwt/jwt.py +++ b/fastapi_jwt/jwt.py @@ -8,11 +8,21 @@ from fastapi.responses import Response from fastapi.security import APIKeyCookie, HTTPBearer from starlette.status import HTTP_401_UNAUTHORIZED +from .jwt_backends import AuthlibJWTBackend, PythonJoseJWTBackend -try: - from jose import jwt -except ImportError: # pragma: nocover - jwt = None # type: ignore[assignment] + +DEFAULT_JWT_BACKEND = None + + +def define_default_jwt_backend(cls): + global DEFAULT_JWT_BACKEND + DEFAULT_JWT_BACKEND = cls + + +if AuthlibJWTBackend is not None: + define_default_jwt_backend(AuthlibJWTBackend) +elif PythonJoseJWTBackend is not None: + define_default_jwt_backend(PythonJoseJWTBackend) def utcnow(): @@ -72,28 +82,26 @@ def __init__( secret_key: str, places: Optional[Set[str]] = None, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): - assert jwt is not None, "python-jose must be installed to use JwtAuth" + self.jwt_backend = DEFAULT_JWT_BACKEND(algorithm) + self.secret_key = secret_key if places: assert places.issubset( {"header", "cookie"} ), "only 'header'/'cookie' are supported" - algorithm = algorithm.upper() - assert ( - hasattr(jwt.ALGORITHMS, algorithm) is True # type: ignore[attr-defined] - ), f"{algorithm} algorithm is not supported by python-jose library" - - self.secret_key = secret_key self.places = places or {"header"} self.auto_error = auto_error - self.algorithm = algorithm self.access_expires_delta = access_expires_delta or timedelta(minutes=15) self.refresh_expires_delta = refresh_expires_delta or timedelta(days=31) + @property + def algorithm(self): + return self.jwt_backend.algorithm + @classmethod def from_other( cls, @@ -112,30 +120,6 @@ def from_other( refresh_expires_delta=refresh_expires_delta or other.refresh_expires_delta, ) - def _decode(self, token: str) -> Optional[Dict[str, Any]]: - try: - payload: Dict[str, Any] = jwt.decode( - token, - self.secret_key, - algorithms=[self.algorithm], - options={"leeway": 10}, - ) - return payload - except jwt.ExpiredSignatureError as e: # type: ignore[attr-defined] - if self.auto_error: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, detail=f"Token time expired: {e}" - ) - else: - return None - except jwt.JWTError as e: # type: ignore[attr-defined] - if self.auto_error: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, detail=f"Wrong token: {e}" - ) - else: - return None - def _generate_payload( self, subject: Dict[str, Any], @@ -144,7 +128,6 @@ def _generate_payload( token_type: str, ) -> Dict[str, Any]: now = utcnow() - return { "subject": subject.copy(), # main subject "type": token_type, # 'access' or 'refresh' token @@ -172,8 +155,7 @@ async def _get_payload( return None # Try to decode jwt token. auto_error on error - payload = self._decode(token) - return payload + return self.jwt_backend.decode(token, self.secret_key, self.auto_error) def create_access_token( self, @@ -186,11 +168,7 @@ def create_access_token( to_encode = self._generate_payload( subject, expires_delta, unique_identifier, "access" ) - - jwt_encoded: str = jwt.encode( - to_encode, self.secret_key, algorithm=self.algorithm - ) - return jwt_encoded + return self.jwt_backend.encode(to_encode, self.secret_key) def create_refresh_token( self, @@ -203,11 +181,7 @@ def create_refresh_token( to_encode = self._generate_payload( subject, expires_delta, unique_identifier, "refresh" ) - - jwt_encoded: str = jwt.encode( - to_encode, self.secret_key, algorithm=self.algorithm - ) - return jwt_encoded + return self.jwt_backend.encode(to_encode, self.secret_key) @staticmethod def set_access_cookie( @@ -261,7 +235,7 @@ def __init__( secret_key: str, places: Optional[Set[str]] = None, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): @@ -293,7 +267,7 @@ def __init__( self, secret_key: str, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): @@ -317,7 +291,7 @@ def __init__( self, secret_key: str, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): @@ -342,7 +316,7 @@ def __init__( self, secret_key: str, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): @@ -372,7 +346,7 @@ def __init__( secret_key: str, places: Optional[Set[str]] = None, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): @@ -414,7 +388,7 @@ def __init__( self, secret_key: str, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): @@ -438,7 +412,7 @@ def __init__( self, secret_key: str, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): @@ -463,7 +437,7 @@ def __init__( self, secret_key: str, auto_error: bool = True, - algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] + algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, ): diff --git a/fastapi_jwt/jwt_backends/__init__.py b/fastapi_jwt/jwt_backends/__init__.py new file mode 100644 index 0000000..310f02c --- /dev/null +++ b/fastapi_jwt/jwt_backends/__init__.py @@ -0,0 +1,9 @@ +try: + from .authlib_backend import AuthlibJWTBackend +except ImportError: + AuthlibJWTBackend = None + +try: + from .python_jose_backend import PythonJoseJWTBackend +except ImportError: + PythonJoseJWTBackend = None diff --git a/fastapi_jwt/jwt_backends/abstract_backend.py b/fastapi_jwt/jwt_backends/abstract_backend.py new file mode 100644 index 0000000..ca337ae --- /dev/null +++ b/fastapi_jwt/jwt_backends/abstract_backend.py @@ -0,0 +1,31 @@ +from abc import ABCMeta, abstractmethod, abstractproperty +from typing import Any, Dict, Optional, Self + + + +class AbstractJWTBackend(metaclass=ABCMeta): + + # simple "SingletonArgs" implementation to keep a JWTBackend per algorithm + _instances = {} + + def __new__(cls, algorithm) -> Self: + instance_key = (cls, algorithm) + if instance_key not in cls._instances: + cls._instances[instance_key] = super(AbstractJWTBackend, cls).__new__(cls) + return cls._instances[instance_key] + + @abstractmethod + def __init__(self, algorithm) -> None: + pass + + @abstractproperty + def default_algorithm(self) -> str: + pass + + @abstractmethod + def encode(self, to_encode, secret_key) -> str: + pass + + @abstractmethod + def decode(self, token, secret_key, auto_error) -> Optional[Dict[str, Any]]: + pass diff --git a/fastapi_jwt/jwt_backends/authlib_backend.py b/fastapi_jwt/jwt_backends/authlib_backend.py new file mode 100644 index 0000000..95b051f --- /dev/null +++ b/fastapi_jwt/jwt_backends/authlib_backend.py @@ -0,0 +1,51 @@ +from fastapi import HTTPException +from typing import Any, Dict, Optional +from starlette.status import HTTP_401_UNAUTHORIZED + +from authlib.jose import JsonWebSignature, JsonWebToken +from authlib.jose.errors import ( + DecodeError, ExpiredTokenError, InvalidClaimError, InvalidTokenError +) +from .abstract_backend import AbstractJWTBackend + + +class AuthlibJWTBackend(AbstractJWTBackend): + + def __init__(self, algorithm) -> None: + self.algorithm = algorithm if algorithm is not None else self.default_algorithm + # from https://github.com/lepture/authlib/blob/85f9ff/authlib/jose/__init__.py#L45 + valid_algorithms = list(JsonWebSignature.ALGORITHMS_REGISTRY.keys()) + assert ( + self.algorithm in valid_algorithms + ), f"{self.algorithm} algorithm is not supported by authlib" + self.jwt = JsonWebToken(algorithms=[self.algorithm]) + + @property + def default_algorithm(self) -> str: + return "HS256" + + def encode(self, to_encode, secret_key) -> str: + token = self.jwt.encode(header={"alg": self.algorithm}, payload=to_encode, key=secret_key) + return token.decode() # convert to string + + def decode(self, token, secret_key, auto_error) -> Optional[Dict[str, Any]]: + try: + payload = self.jwt.decode(token, secret_key) + payload.validate(leeway=10) + return dict(payload) + except ExpiredTokenError as e: # type: ignore[attr-defined] + if auto_error: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, detail=f"Token time expired: {e}" + ) + else: + return None + except (InvalidClaimError, + InvalidTokenError, + DecodeError) as e: # type: ignore[attr-defined] + if auto_error: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, detail=f"Wrong token: {e}" + ) + else: + return None \ No newline at end of file diff --git a/fastapi_jwt/jwt_backends/python_jose_backend.py b/fastapi_jwt/jwt_backends/python_jose_backend.py new file mode 100644 index 0000000..21f191c --- /dev/null +++ b/fastapi_jwt/jwt_backends/python_jose_backend.py @@ -0,0 +1,47 @@ +from fastapi import HTTPException +from typing import Any, Dict, Optional +from starlette.status import HTTP_401_UNAUTHORIZED + +from jose import jwt + +from .abstract_backend import AbstractJWTBackend + + +class PythonJoseJWTBackend(AbstractJWTBackend): + + def __init__(self, algorithm) -> None: + self.algorithm = algorithm if algorithm is not None else self.default_algorithm + assert ( + hasattr(jwt.ALGORITHMS, self.algorithm) is True # type: ignore[attr-defined] + ), f"{algorithm} algorithm is not supported by python-jose library" + + @property + def default_algorithm(self) -> str: + return jwt.ALGORITHMS.HS256 + + def encode(self, to_encode, secret_key) -> str: + return jwt.encode(to_encode, secret_key, algorithm=self.algorithm) + + def decode(self, token, secret_key, auto_error) -> Optional[Dict[str, Any]]: + try: + payload: Dict[str, Any] = jwt.decode( + token, + secret_key, + algorithms=[self.algorithm], + options={"leeway": 10}, + ) + return payload + except jwt.ExpiredSignatureError as e: # type: ignore[attr-defined] + if auto_error: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, detail=f"Token time expired: {e}" + ) + else: + return None + except jwt.JWTError as e: # type: ignore[attr-defined] + if auto_error: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, detail=f"Wrong token: {e}" + ) + else: + return None \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 28e0bd1..03b47f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ classifiers = [ dependencies = [ "fastapi >=0.50.0", - "python-jose[cryptography] >=3.3.0" ] @@ -37,7 +36,15 @@ documentation = "https://k4black.github.io/fastapi-jwt/" [project.optional-dependencies] +authlib = [ + "Authlib >=1.3.0" +] +python_jose = [ + "python-jose[cryptography] >=3.3.0" +] test = [ + "Authlib >=1.3.0", + "python-jose[cryptography] >=3.3.0", "httpx >=0.23.0,<1.0.0", "pytest >=7.0.0,<9.0.0", "pytest-cov >=4.0.0,<5.0.0", diff --git a/tests/mock_datetime_utils.py b/tests/mock_datetime_utils.py new file mode 100644 index 0000000..f1f4c7f --- /dev/null +++ b/tests/mock_datetime_utils.py @@ -0,0 +1,39 @@ +import datetime +import time + +from fastapi_jwt import AuthlibJWTBackend, PythonJoseJWTBackend + + +_time = time.time +_now = datetime.datetime.now +_utcnow = datetime.datetime.utcnow + + +def create_datetime_mock(**timedelta_kwargs): + + class _FakeDateTime(datetime.datetime): # pragma: no cover + @staticmethod + def now(**kwargs): + return _now() + datetime.timedelta(**timedelta_kwargs) + + @staticmethod + def utcnow(**kwargs): + return _utcnow() + datetime.timedelta(**timedelta_kwargs) + + return _FakeDateTime + + +def create_time_time_mock(**kwargs): + def _fake_time_time(): + return _time() + datetime.timedelta(**kwargs).total_seconds() + + return _fake_time_time + + +def mock_now_for_backend(mocker, jwt_backend, **kwargs): + if jwt_backend is AuthlibJWTBackend: + mocker.patch("authlib.jose.rfc7519.claims.time.time", create_time_time_mock(**kwargs)) + elif jwt_backend is PythonJoseJWTBackend: + mocker.patch("jose.jwt.datetime", create_datetime_mock(**kwargs)) + else: + raise Exception("Invalid Backend") \ No newline at end of file diff --git a/tests/test_security_jwt_bearer.py b/tests/test_security_jwt_bearer.py index 104917c..99e048a 100644 --- a/tests/test_security_jwt_bearer.py +++ b/tests/test_security_jwt_bearer.py @@ -1,40 +1,46 @@ +import pytest from fastapi import FastAPI, Security from fastapi.testclient import TestClient from fastapi_jwt import JwtAccessBearer, JwtAuthorizationCredentials, JwtRefreshBearer +from fastapi_jwt.jwt import AuthlibJWTBackend, PythonJoseJWTBackend, define_default_jwt_backend -app = FastAPI() -access_security = JwtAccessBearer(secret_key="secret_key") -refresh_security = JwtRefreshBearer(secret_key="secret_key") +def create_example_client(jwt_backend): + define_default_jwt_backend(jwt_backend) + app = FastAPI() + access_security = JwtAccessBearer(secret_key="secret_key") + refresh_security = JwtRefreshBearer(secret_key="secret_key") -@app.post("/auth") -def auth(): - subject = {"username": "username", "role": "user"} - access_token = access_security.create_access_token(subject=subject) - refresh_token = access_security.create_refresh_token(subject=subject) + @app.post("/auth") + def auth(): + subject = {"username": "username", "role": "user"} - return {"access_token": access_token, "refresh_token": refresh_token} + access_token = access_security.create_access_token(subject=subject) + refresh_token = access_security.create_refresh_token(subject=subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.post("/refresh") -def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): - access_token = refresh_security.create_access_token(subject=credentials.subject) - refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) - return {"access_token": access_token, "refresh_token": refresh_token} + @app.post("/refresh") + def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): + access_token = refresh_security.create_access_token(subject=credentials.subject) + refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.get("/users/me") -def read_current_user( - credentials: JwtAuthorizationCredentials = Security(access_security), -): - return {"username": credentials["username"], "role": credentials["role"]} + @app.get("/users/me") + def read_current_user( + credentials: JwtAuthorizationCredentials = Security(access_security), + ): + return {"username": credentials["username"], "role": credentials["role"]} + + + return TestClient(app) -client = TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -88,18 +94,24 @@ def read_current_user( } -def test_openapi_schema(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_openapi_schema(jwt_backend): + client = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_auth(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_auth(jwt_backend): + client = create_example_client(jwt_backend) response = client.post("/auth") assert response.status_code == 200, response.text -def test_security_jwt_access_bearer(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_bearer(jwt_backend): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get( @@ -109,27 +121,35 @@ def test_security_jwt_access_bearer(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_bearer_wrong(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_bearer_wrong(jwt_backend): + client = create_example_client(jwt_backend) response = client.get( "/users/me", headers={"Authorization": "Bearer wrong_access_token"} ) assert response.status_code == 401, response.text -def test_security_jwt_access_bearer_no_credentials(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_bearer_no_credentials(jwt_backend): + client = create_example_client(jwt_backend) response = client.get("/users/me") assert response.status_code == 401, response.text assert response.json() == {"detail": "Credentials are not provided"} -def test_security_jwt_access_bearer_incorrect_scheme_credentials(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_bearer_incorrect_scheme_credentials(jwt_backend): + client = create_example_client(jwt_backend) response = client.get("/users/me", headers={"Authorization": "Basic notreally"}) assert response.status_code == 401, response.text assert response.json() == {"detail": "Credentials are not provided"} # assert response.json() == {"detail": "Invalid authentication credentials"} -def test_security_jwt_refresh_bearer(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_bearer(jwt_backend): + client = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] response = client.post( @@ -138,20 +158,26 @@ def test_security_jwt_refresh_bearer(): assert response.status_code == 200, response.text -def test_security_jwt_refresh_bearer_wrong(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_bearer_wrong(jwt_backend): + client = create_example_client(jwt_backend) response = client.post( "/refresh", headers={"Authorization": "Bearer wrong_refresh_token"} ) assert response.status_code == 401, response.text -def test_security_jwt_refresh_bearer_no_credentials(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_bearer_no_credentials(jwt_backend): + client = create_example_client(jwt_backend) response = client.post("/refresh") assert response.status_code == 401, response.text assert response.json() == {"detail": "Credentials are not provided"} -def test_security_jwt_refresh_bearer_incorrect_scheme_credentials(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_bearer_incorrect_scheme_credentials(jwt_backend): + client = create_example_client(jwt_backend) response = client.post("/refresh", headers={"Authorization": "Basic notreally"}) assert response.status_code == 401, response.text assert response.json() == {"detail": "Credentials are not provided"} diff --git a/tests/test_security_jwt_bearer_optional.py b/tests/test_security_jwt_bearer_optional.py index 5029ff6..a63d443 100644 --- a/tests/test_security_jwt_bearer_optional.py +++ b/tests/test_security_jwt_bearer_optional.py @@ -1,49 +1,56 @@ +from collections import namedtuple from typing import Optional +import pytest from fastapi import FastAPI, Security from fastapi.testclient import TestClient from fastapi_jwt import JwtAccessBearer, JwtAuthorizationCredentials, JwtRefreshBearer +from fastapi_jwt.jwt import AuthlibJWTBackend, PythonJoseJWTBackend, define_default_jwt_backend -app = FastAPI() -access_security = JwtAccessBearer(secret_key="secret_key", auto_error=False) -refresh_security = JwtRefreshBearer(secret_key="secret_key", auto_error=False) +def create_example_client(jwt_backend): + define_default_jwt_backend(jwt_backend) + app = FastAPI() + access_security = JwtAccessBearer(secret_key="secret_key", auto_error=False) + refresh_security = JwtRefreshBearer(secret_key="secret_key", auto_error=False) -@app.post("/auth") -def auth(): - subject = {"username": "username", "role": "user"} - access_token = access_security.create_access_token(subject=subject) - refresh_token = access_security.create_refresh_token(subject=subject) + @app.post("/auth") + def auth(): + subject = {"username": "username", "role": "user"} - return {"access_token": access_token, "refresh_token": refresh_token} + access_token = access_security.create_access_token(subject=subject) + refresh_token = access_security.create_refresh_token(subject=subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.post("/refresh") -def refresh( - credentials: Optional[JwtAuthorizationCredentials] = Security(refresh_security), -): - if credentials is None: - return {"msg": "Create an account first"} - access_token = refresh_security.create_access_token(subject=credentials.subject) - refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + @app.post("/refresh") + def refresh( + credentials: Optional[JwtAuthorizationCredentials] = Security(refresh_security), + ): + if credentials is None: + return {"msg": "Create an account first"} - return {"access_token": access_token, "refresh_token": refresh_token} + access_token = refresh_security.create_access_token(subject=credentials.subject) + refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.get("/users/me") -def read_current_user( - credentials: Optional[JwtAuthorizationCredentials] = Security(access_security), -): - if credentials is None: - return {"msg": "Create an account first"} - return {"username": credentials["username"], "role": credentials["role"]} + @app.get("/users/me") + def read_current_user( + credentials: Optional[JwtAuthorizationCredentials] = Security(access_security), + ): + if credentials is None: + return {"msg": "Create an account first"} + return {"username": credentials["username"], "role": credentials["role"]} + + + return TestClient(app) -client = TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -97,18 +104,24 @@ def read_current_user( } -def test_openapi_schema(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_openapi_schema(jwt_backend): + client = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_auth(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_auth(jwt_backend): + client = create_example_client(jwt_backend) response = client.post("/auth") assert response.status_code == 200, response.text -def test_security_jwt_access_bearer(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_bearer(jwt_backend): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get( @@ -118,7 +131,9 @@ def test_security_jwt_access_bearer(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_bearer_wrong(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_bearer_wrong(jwt_backend): + client = create_example_client(jwt_backend) response = client.get( "/users/me", headers={"Authorization": "Bearer wrong_access_token"} ) @@ -126,19 +141,25 @@ def test_security_jwt_access_bearer_wrong(): assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_access_bearer_no_credentials(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_bearer_no_credentials(jwt_backend): + client = create_example_client(jwt_backend) response = client.get("/users/me") assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_access_bearer_incorrect_scheme_credentials(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_bearer_incorrect_scheme_credentials(jwt_backend): + client = create_example_client(jwt_backend) response = client.get("/users/me", headers={"Authorization": "Basic notreally"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_bearer(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_bearer(jwt_backend): + client = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] response = client.post( @@ -147,7 +168,9 @@ def test_security_jwt_refresh_bearer(): assert response.status_code == 200, response.text -def test_security_jwt_refresh_bearer_wrong(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_bearer_wrong(jwt_backend): + client = create_example_client(jwt_backend) response = client.post( "/refresh", headers={"Authorization": "Bearer wrong_refresh_token"} ) @@ -155,13 +178,17 @@ def test_security_jwt_refresh_bearer_wrong(): assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_bearer_no_credentials(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_bearer_no_credentials(jwt_backend): + client = create_example_client(jwt_backend) response = client.post("/refresh") assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_bearer_incorrect_scheme_credentials(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_bearer_incorrect_scheme_credentials(jwt_backend): + client = create_example_client(jwt_backend) response = client.post("/refresh", headers={"Authorization": "Basic notreally"}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} diff --git a/tests/test_security_jwt_cookie.py b/tests/test_security_jwt_cookie.py index d506eba..5dd5128 100644 --- a/tests/test_security_jwt_cookie.py +++ b/tests/test_security_jwt_cookie.py @@ -1,40 +1,46 @@ +import pytest from fastapi import FastAPI, Security from fastapi.testclient import TestClient from fastapi_jwt import JwtAccessCookie, JwtAuthorizationCredentials, JwtRefreshCookie +from fastapi_jwt.jwt import AuthlibJWTBackend, PythonJoseJWTBackend, define_default_jwt_backend -app = FastAPI() -access_security = JwtAccessCookie(secret_key="secret_key") -refresh_security = JwtRefreshCookie(secret_key="secret_key") +def create_example_client(jwt_backend): + define_default_jwt_backend(jwt_backend) + app = FastAPI() + access_security = JwtAccessCookie(secret_key="secret_key") + refresh_security = JwtRefreshCookie(secret_key="secret_key") -@app.post("/auth") -def auth(): - subject = {"username": "username", "role": "user"} - access_token = access_security.create_access_token(subject=subject) - refresh_token = access_security.create_refresh_token(subject=subject) + @app.post("/auth") + def auth(): + subject = {"username": "username", "role": "user"} - return {"access_token": access_token, "refresh_token": refresh_token} + access_token = access_security.create_access_token(subject=subject) + refresh_token = access_security.create_refresh_token(subject=subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.post("/refresh") -def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): - access_token = refresh_security.create_access_token(subject=credentials.subject) - refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) - return {"access_token": access_token, "refresh_token": refresh_token} + @app.post("/refresh") + def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): + access_token = refresh_security.create_access_token(subject=credentials.subject) + refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.get("/users/me") -def read_current_user( - credentials: JwtAuthorizationCredentials = Security(access_security), -): - return {"username": credentials["username"], "role": credentials["role"]} + @app.get("/users/me") + def read_current_user( + credentials: JwtAuthorizationCredentials = Security(access_security), + ): + return {"username": credentials["username"], "role": credentials["role"]} + + + return TestClient(app) -client = TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -96,18 +102,24 @@ def read_current_user( } -def test_openapi_schema(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_openapi_schema(jwt_backend): + client = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_auth(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_auth(jwt_backend): + client = create_example_client(jwt_backend) response = client.post("/auth") assert response.status_code == 200, response.text -def test_security_jwt_access_cookie(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_cookie(jwt_backend): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get("/users/me", cookies={"access_token_cookie": access_token}) @@ -115,21 +127,27 @@ def test_security_jwt_access_cookie(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_cookie_wrong(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_cookie_wrong(jwt_backend): + client = create_example_client(jwt_backend) response = client.get( "/users/me", cookies={"access_token_cookie": "wrong_access_token_cookie"} ) assert response.status_code == 401, response.text -def test_security_jwt_access_cookie_no_credentials(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_cookie_no_credentials(jwt_backend): + client = create_example_client(jwt_backend) client.cookies.clear() response = client.get("/users/me", cookies={}) assert response.status_code == 401, response.text assert response.json() == {"detail": "Credentials are not provided"} -def test_security_jwt_refresh_cookie(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_cookie(jwt_backend): + client = create_example_client(jwt_backend) client.cookies.clear() refresh_token = client.post("/auth").json()["refresh_token"] @@ -137,14 +155,18 @@ def test_security_jwt_refresh_cookie(): assert response.status_code == 200, response.text -def test_security_jwt_refresh_cookie_wrong(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_cookie_wrong(jwt_backend): + client = create_example_client(jwt_backend) response = client.post( "/refresh", cookies={"refresh_token_cookie": "wrong_refresh_token_cookie"} ) assert response.status_code == 401, response.text -def test_security_jwt_refresh_cookie_no_credentials(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_cookie_no_credentials(jwt_backend): + client = create_example_client(jwt_backend) client.cookies.clear() response = client.post("/refresh", cookies={}) assert response.status_code == 401, response.text diff --git a/tests/test_security_jwt_cookie_optional.py b/tests/test_security_jwt_cookie_optional.py index 7cc2f51..2b6d93c 100644 --- a/tests/test_security_jwt_cookie_optional.py +++ b/tests/test_security_jwt_cookie_optional.py @@ -1,50 +1,56 @@ +import pytest from typing import Optional from fastapi import FastAPI, Security from fastapi.testclient import TestClient from fastapi_jwt import JwtAccessCookie, JwtAuthorizationCredentials, JwtRefreshCookie +from fastapi_jwt.jwt import AuthlibJWTBackend, PythonJoseJWTBackend, define_default_jwt_backend -app = FastAPI() -access_security = JwtAccessCookie(secret_key="secret_key", auto_error=False) -refresh_security = JwtRefreshCookie(secret_key="secret_key", auto_error=False) +def create_example_client(jwt_backend): + define_default_jwt_backend(jwt_backend) + app = FastAPI() + access_security = JwtAccessCookie(secret_key="secret_key", auto_error=False) + refresh_security = JwtRefreshCookie(secret_key="secret_key", auto_error=False) -@app.post("/auth") -def auth(): - subject = {"username": "username", "role": "user"} - access_token = access_security.create_access_token(subject=subject) - refresh_token = access_security.create_refresh_token(subject=subject) + @app.post("/auth") + def auth(): + subject = {"username": "username", "role": "user"} - return {"access_token": access_token, "refresh_token": refresh_token} + access_token = access_security.create_access_token(subject=subject) + refresh_token = access_security.create_refresh_token(subject=subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.post("/refresh") -def refresh( - credentials: Optional[JwtAuthorizationCredentials] = Security(refresh_security), -): - if credentials is None: - return {"msg": "Create an account first"} - access_token = refresh_security.create_access_token(subject=credentials.subject) - refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + @app.post("/refresh") + def refresh( + credentials: Optional[JwtAuthorizationCredentials] = Security(refresh_security), + ): + if credentials is None: + return {"msg": "Create an account first"} - return {"access_token": access_token, "refresh_token": refresh_token} + access_token = refresh_security.create_access_token(subject=credentials.subject) + refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.get("/users/me") -def read_current_user( - credentials: Optional[JwtAuthorizationCredentials] = Security(access_security), -): - if credentials is None: - return {"msg": "Create an account first"} - return {"username": credentials["username"], "role": credentials["role"]} + @app.get("/users/me") + def read_current_user( + credentials: Optional[JwtAuthorizationCredentials] = Security(access_security), + ): + if credentials is None: + return {"msg": "Create an account first"} + return {"username": credentials["username"], "role": credentials["role"]} + + + return TestClient(app) -client = TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -106,18 +112,24 @@ def read_current_user( } -def test_openapi_schema(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_openapi_schema(jwt_backend): + client = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_auth(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_auth(jwt_backend): + client = create_example_client(jwt_backend) response = client.post("/auth") assert response.status_code == 200, response.text -def test_security_jwt_access_cookie(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_cookie(jwt_backend): + client = create_example_client(jwt_backend) client.cookies.clear() access_token = client.post("/auth").json()["access_token"] @@ -126,7 +138,9 @@ def test_security_jwt_access_cookie(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_cookie_wrong(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_cookie_wrong(jwt_backend): + client = create_example_client(jwt_backend) response = client.get( "/users/me", cookies={"access_token_cookie": "wrong_access_token_cookie"} ) @@ -134,20 +148,26 @@ def test_security_jwt_access_cookie_wrong(): assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_access_cookie_no_credentials(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_cookie_no_credentials(jwt_backend): + client = create_example_client(jwt_backend) response = client.get("/users/me", cookies={}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_cookie(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_cookie(jwt_backend): + client = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] response = client.post("/refresh", cookies={"refresh_token_cookie": refresh_token}) assert response.status_code == 200, response.text -def test_security_jwt_refresh_cookie_wrong(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_cookie_wrong(jwt_backend): + client = create_example_client(jwt_backend) response = client.post( "/refresh", cookies={"refresh_token_cookie": "wrong_refresh_token_cookie"} ) @@ -155,7 +175,9 @@ def test_security_jwt_refresh_cookie_wrong(): assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_cookie_no_credentials(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_cookie_no_credentials(jwt_backend): + client = create_example_client(jwt_backend) response = client.post("/refresh", cookies={}) assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} diff --git a/tests/test_security_jwt_general.py b/tests/test_security_jwt_general.py index 62a160e..718be46 100644 --- a/tests/test_security_jwt_general.py +++ b/tests/test_security_jwt_general.py @@ -1,85 +1,69 @@ -import datetime from typing import Set from uuid import uuid4 +import pytest from fastapi import FastAPI, Security from fastapi.testclient import TestClient from pytest_mock import MockerFixture from fastapi_jwt import JwtAccessBearer, JwtAuthorizationCredentials, JwtRefreshBearer +from fastapi_jwt.jwt import AuthlibJWTBackend, PythonJoseJWTBackend, define_default_jwt_backend +from .mock_datetime_utils import mock_now_for_backend -app = FastAPI() -access_security = JwtAccessBearer(secret_key="secret_key") -refresh_security = JwtRefreshBearer.from_other(access_security) +def create_example_client(jwt_backend): + define_default_jwt_backend(jwt_backend) + app = FastAPI() + access_security = JwtAccessBearer(secret_key="secret_key") + refresh_security = JwtRefreshBearer.from_other(access_security) + unique_identifiers_database: Set[str] = set() -unique_identifiers_database: Set[str] = set() + @app.post("/auth") + def auth(): + subject = {"username": "username", "role": "user"} + unique_identifier = str(uuid4()) + unique_identifiers_database.add(unique_identifier) -@app.post("/auth") -def auth(): - subject = {"username": "username", "role": "user"} - unique_identifier = str(uuid4()) - unique_identifiers_database.add(unique_identifier) + access_token = access_security.create_access_token( + subject=subject, unique_identifier=unique_identifier + ) + refresh_token = access_security.create_refresh_token(subject=subject) - access_token = access_security.create_access_token( - subject=subject, unique_identifier=unique_identifier - ) - refresh_token = access_security.create_refresh_token(subject=subject) - - return {"access_token": access_token, "refresh_token": refresh_token} - - -@app.post("/refresh") -def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): - unique_identifier = str(uuid4()) - unique_identifiers_database.add(unique_identifier) - - access_token = refresh_security.create_access_token( - subject=credentials.subject, unique_identifier=unique_identifier, - ) - refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) - - return {"access_token": access_token, "refresh_token": refresh_token} + return {"access_token": access_token, "refresh_token": refresh_token} -@app.get("/users/me") -def read_current_user( - credentials: JwtAuthorizationCredentials = Security(access_security), -): - return {"username": credentials["username"], "role": credentials["role"]} + @app.post("/refresh") + def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): + unique_identifier = str(uuid4()) + unique_identifiers_database.add(unique_identifier) + access_token = refresh_security.create_access_token( + subject=credentials.subject, unique_identifier=unique_identifier, + ) + refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) -@app.get("/auth/meta") -def get_token_meta( - credentials: JwtAuthorizationCredentials = Security(access_security), -): - return {"jti": credentials.jti} + return {"access_token": access_token, "refresh_token": refresh_token} -class _FakeDateTimeShort(datetime.datetime): # pragma: no cover - @staticmethod - def now(**kwargs): - return datetime.datetime.now() + datetime.timedelta(minutes=3) + @app.get("/users/me") + def read_current_user( + credentials: JwtAuthorizationCredentials = Security(access_security), + ): + return {"username": credentials["username"], "role": credentials["role"]} - @staticmethod - def utcnow(**kwargs): - return datetime.datetime.utcnow() + datetime.timedelta(minutes=3) + @app.get("/auth/meta") + def get_token_meta( + credentials: JwtAuthorizationCredentials = Security(access_security), + ): + return {"jti": credentials.jti} -class _FakeDateTimeLong(datetime.datetime): # pragma: no cover - @staticmethod - def now(**kwargs): - return datetime.datetime.now() + datetime.timedelta(days=42) - @staticmethod - def utcnow(**kwargs): - return datetime.datetime.utcnow() + datetime.timedelta(days=42) + return TestClient(app), unique_identifiers_database -client = TestClient(app) - openapi_schema = { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, @@ -145,13 +129,17 @@ def utcnow(**kwargs): } -def test_openapi_schema(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_openapi_schema(jwt_backend): + client, _ = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_access_token(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_token(jwt_backend): + client, _ = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get( @@ -161,7 +149,9 @@ def test_security_jwt_access_token(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_token_wrong(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_token_wrong(jwt_backend): + client, _ = create_example_client(jwt_backend) response = client.get( "/users/me", headers={"Authorization": "Bearer wrong_access_token"} ) @@ -175,7 +165,9 @@ def test_security_jwt_access_token_wrong(): assert response.json()["detail"].startswith("Wrong token:") -def test_security_jwt_access_token_changed(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_token_changed(jwt_backend): + client, _ = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] access_token = access_token.split(".")[0] + ".wrong." + access_token.split(".")[-1] @@ -187,28 +179,28 @@ def test_security_jwt_access_token_changed(): assert response.json()["detail"].startswith("Wrong token:") -def test_security_jwt_access_token_expiration(mocker: MockerFixture): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_token_expiration(mocker: MockerFixture, jwt_backend): + client, _ = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] - mocker.patch("jose.jwt.datetime", _FakeDateTimeShort) # 3 min left - + mock_now_for_backend(mocker, jwt_backend, minutes=3) # 3 min left response = client.get( "/users/me", headers={"Authorization": f"Bearer {access_token}"} ) assert response.status_code == 200, response.text - mocker.patch("jose.jwt.datetime", _FakeDateTimeLong) # 42 days left - + mock_now_for_backend(mocker, jwt_backend, days=42) # 42 days left response = client.get( "/users/me", headers={"Authorization": f"Bearer {access_token}"} ) assert response.status_code == 401, response.text - assert response.json()["detail"].startswith( - "Token time expired: Signature has expired" - ) + assert response.json()["detail"].startswith("Token time expired:") -def test_security_jwt_refresh_token(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_token(jwt_backend): + client, _ = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] response = client.post( @@ -217,7 +209,9 @@ def test_security_jwt_refresh_token(): assert response.status_code == 200, response.text -def test_security_jwt_refresh_token_wrong(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_token_wrong(jwt_backend): + client, _ = create_example_client(jwt_backend) response = client.post( "/refresh", headers={"Authorization": "Bearer wrong_refresh_token"} ) @@ -231,7 +225,9 @@ def test_security_jwt_refresh_token_wrong(): assert response.json()["detail"].startswith("Wrong token:") -def test_security_jwt_refresh_token_using_access_token(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_token_using_access_token(jwt_backend): + client, _ = create_example_client(jwt_backend) tokens = client.post("/auth").json() access_token, refresh_token = tokens["access_token"], tokens["refresh_token"] assert access_token != refresh_token @@ -243,7 +239,9 @@ def test_security_jwt_refresh_token_using_access_token(): assert response.json()["detail"].startswith("Wrong token: 'type' is not 'refresh'") -def test_security_jwt_refresh_token_changed(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_token_changed(jwt_backend): + client, _ = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] refresh_token = ( @@ -257,21 +255,22 @@ def test_security_jwt_refresh_token_changed(): assert response.json()["detail"].startswith("Wrong token:") -def test_security_jwt_refresh_token_expired(mocker: MockerFixture): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_token_expired(mocker: MockerFixture, jwt_backend): + client, _ = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] - mocker.patch("jose.jwt.datetime", _FakeDateTimeLong) # 42 days left - + mock_now_for_backend(mocker, jwt_backend, days=42) # 42 days left response = client.post( "/refresh", headers={"Authorization": f"Bearer {refresh_token}"} ) assert response.status_code == 401, response.text - assert response.json()["detail"].startswith( - "Token time expired: Signature has expired" - ) + assert response.json()["detail"].startswith("Token time expired:") -def test_security_jwt_custom_jti(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_custom_jti(jwt_backend): + client, unique_identifiers_database = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get( diff --git a/tests/test_security_jwt_general_optional.py b/tests/test_security_jwt_general_optional.py index ab32257..456f81b 100644 --- a/tests/test_security_jwt_general_optional.py +++ b/tests/test_security_jwt_general_optional.py @@ -1,93 +1,77 @@ -import datetime from typing import Optional, Set from uuid import uuid4 +import pytest from fastapi import FastAPI, Security from fastapi.testclient import TestClient from pytest_mock import MockerFixture from fastapi_jwt import JwtAccessBearer, JwtAuthorizationCredentials, JwtRefreshBearer +from fastapi_jwt.jwt import AuthlibJWTBackend, PythonJoseJWTBackend, define_default_jwt_backend +from .mock_datetime_utils import mock_now_for_backend -app = FastAPI() -access_security = JwtAccessBearer(secret_key="secret_key", auto_error=False) -refresh_security = JwtRefreshBearer.from_other(access_security) +def create_example_client(jwt_backend): + define_default_jwt_backend(jwt_backend) + app = FastAPI() + access_security = JwtAccessBearer(secret_key="secret_key", auto_error=False) + refresh_security = JwtRefreshBearer.from_other(access_security) + unique_identifiers_database: Set[str] = set() -unique_identifiers_database: Set[str] = set() + @app.post("/auth") + def auth(): + subject = {"username": "username", "role": "user"} + unique_identifier = str(uuid4()) + unique_identifiers_database.add(unique_identifier) -@app.post("/auth") -def auth(): - subject = {"username": "username", "role": "user"} - unique_identifier = str(uuid4()) - unique_identifiers_database.add(unique_identifier) - - access_token = access_security.create_access_token( - subject=subject, unique_identifier=unique_identifier - ) - refresh_token = access_security.create_refresh_token(subject=subject) - - return {"access_token": access_token, "refresh_token": refresh_token} + access_token = access_security.create_access_token( + subject=subject, unique_identifier=unique_identifier + ) + refresh_token = access_security.create_refresh_token(subject=subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.post("/refresh") -def refresh( - credentials: Optional[JwtAuthorizationCredentials] = Security(refresh_security), -): - if credentials is None: - return {"msg": "Create an account first"} - - unique_identifier = str(uuid4()) - unique_identifiers_database.add(unique_identifier) - - access_token = refresh_security.create_access_token( - subject=credentials.subject, unique_identifier=unique_identifier, - ) - refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) - return {"access_token": access_token, "refresh_token": refresh_token} + @app.post("/refresh") + def refresh( + credentials: Optional[JwtAuthorizationCredentials] = Security(refresh_security), + ): + if credentials is None: + return {"msg": "Create an account first"} + unique_identifier = str(uuid4()) + unique_identifiers_database.add(unique_identifier) -@app.get("/users/me") -def read_current_user( - credentials: Optional[JwtAuthorizationCredentials] = Security(access_security), -): - if credentials is None: - return {"msg": "Create an account first"} - return {"username": credentials["username"], "role": credentials["role"]} + access_token = refresh_security.create_access_token( + subject=credentials.subject, unique_identifier=unique_identifier, + ) + refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.get("/auth/meta") -def get_token_meta( - credentials: JwtAuthorizationCredentials = Security(access_security), -): - if credentials is None: - return {"msg": "Create an account first"} - return {"jti": credentials.jti} + @app.get("/users/me") + def read_current_user( + credentials: Optional[JwtAuthorizationCredentials] = Security(access_security), + ): + if credentials is None: + return {"msg": "Create an account first"} + return {"username": credentials["username"], "role": credentials["role"]} -class _FakeDateTimeShort(datetime.datetime): # pragma: no cover - @staticmethod - def now(**kwargs): - return datetime.datetime.now() + datetime.timedelta(minutes=3) - @staticmethod - def utcnow(**kwargs): - return datetime.datetime.utcnow() + datetime.timedelta(minutes=3) + @app.get("/auth/meta") + def get_token_meta( + credentials: JwtAuthorizationCredentials = Security(access_security), + ): + if credentials is None: + return {"msg": "Create an account first"} + return {"jti": credentials.jti} -class _FakeDateTimeLong(datetime.datetime): # pragma: no cover - @staticmethod - def now(**kwargs): - return datetime.datetime.now() + datetime.timedelta(days=42) + return TestClient(app), unique_identifiers_database - @staticmethod - def utcnow(**kwargs): - return datetime.datetime.utcnow() + datetime.timedelta(days=42) - - -client = TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -154,13 +138,17 @@ def utcnow(**kwargs): } -def test_openapi_schema(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_openapi_schema(jwt_backend): + client, _ = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_access_token(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_token(jwt_backend): + client, _ = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get( @@ -170,7 +158,9 @@ def test_security_jwt_access_token(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_token_wrong(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_token_wrong(jwt_backend): + client, _ = create_example_client(jwt_backend) response = client.get( "/users/me", headers={"Authorization": "Bearer wrong_access_token"} ) @@ -184,7 +174,9 @@ def test_security_jwt_access_token_wrong(): assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_access_token_changed(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_token_changed(jwt_backend): + client, _ = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] access_token = access_token.split(".")[0] + ".wrong." + access_token.split(".")[-1] @@ -196,19 +188,19 @@ def test_security_jwt_access_token_changed(): assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_access_token_expiration(mocker: MockerFixture): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_token_expiration(mocker: MockerFixture, jwt_backend): + client, _ = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] - mocker.patch("jose.jwt.datetime", _FakeDateTimeShort) # 3 min left - + mock_now_for_backend(mocker, jwt_backend, minutes=3) # 3 min left response = client.get( "/users/me", headers={"Authorization": f"Bearer {access_token}"} ) assert response.status_code == 200, response.text assert response.json() == {"username": "username", "role": "user"} - mocker.patch("jose.jwt.datetime", _FakeDateTimeLong) # 42 days left - + mock_now_for_backend(mocker, jwt_backend, days=42) # 42 days left response = client.get( "/users/me", headers={"Authorization": f"Bearer {access_token}"} ) @@ -216,7 +208,9 @@ def test_security_jwt_access_token_expiration(mocker: MockerFixture): assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_token(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_token(jwt_backend): + client, _ = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] response = client.post( @@ -226,7 +220,9 @@ def test_security_jwt_refresh_token(): assert "msg" not in response.json() -def test_security_jwt_refresh_token_wrong(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_token_wrong(jwt_backend): + client, _ = create_example_client(jwt_backend) response = client.post( "/refresh", headers={"Authorization": "Bearer wrong_refresh_token"} ) @@ -240,7 +236,9 @@ def test_security_jwt_refresh_token_wrong(): assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_token_using_access_token(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_token_using_access_token(jwt_backend): + client, _ = create_example_client(jwt_backend) tokens = client.post("/auth").json() access_token, refresh_token = tokens["access_token"], tokens["refresh_token"] assert access_token != refresh_token @@ -252,7 +250,9 @@ def test_security_jwt_refresh_token_using_access_token(): assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_token_changed(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_token_changed(jwt_backend): + client, _ = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] refresh_token = ( @@ -266,11 +266,12 @@ def test_security_jwt_refresh_token_changed(): assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_refresh_token_expired(mocker: MockerFixture): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_token_expired(mocker: MockerFixture, jwt_backend): + client, _ = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] - mocker.patch("jose.jwt.datetime", _FakeDateTimeLong) # 42 days left - + mock_now_for_backend(mocker, jwt_backend, days=42) # 42 days left response = client.post( "/refresh", headers={"Authorization": f"Bearer {refresh_token}"} ) @@ -278,7 +279,9 @@ def test_security_jwt_refresh_token_expired(mocker: MockerFixture): assert response.json() == {"msg": "Create an account first"} -def test_security_jwt_custom_jti(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_custom_jti(jwt_backend): + client, unique_identifiers_database = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get( diff --git a/tests/test_security_jwt_multiple_places.py b/tests/test_security_jwt_multiple_places.py index af60bb8..1e0d21a 100644 --- a/tests/test_security_jwt_multiple_places.py +++ b/tests/test_security_jwt_multiple_places.py @@ -1,40 +1,46 @@ +import pytest from fastapi import FastAPI, Security from fastapi.testclient import TestClient from fastapi_jwt import JwtAccessBearerCookie, JwtAuthorizationCredentials, JwtRefreshBearerCookie +from fastapi_jwt.jwt import AuthlibJWTBackend, PythonJoseJWTBackend, define_default_jwt_backend -app = FastAPI() -access_security = JwtAccessBearerCookie(secret_key="secret_key") -refresh_security = JwtRefreshBearerCookie(secret_key="secret_key") +def create_example_client(jwt_backend): + define_default_jwt_backend(jwt_backend) + app = FastAPI() + access_security = JwtAccessBearerCookie(secret_key="secret_key") + refresh_security = JwtRefreshBearerCookie(secret_key="secret_key") -@app.post("/auth") -def auth(): - subject = {"username": "username", "role": "user"} - access_token = access_security.create_access_token(subject=subject) - refresh_token = access_security.create_refresh_token(subject=subject) + @app.post("/auth") + def auth(): + subject = {"username": "username", "role": "user"} - return {"access_token": access_token, "refresh_token": refresh_token} + access_token = access_security.create_access_token(subject=subject) + refresh_token = access_security.create_refresh_token(subject=subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.post("/refresh") -def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): - access_token = refresh_security.create_access_token(subject=credentials.subject) - refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) - return {"access_token": access_token, "refresh_token": refresh_token} + @app.post("/refresh") + def refresh(credentials: JwtAuthorizationCredentials = Security(refresh_security)): + access_token = refresh_security.create_access_token(subject=credentials.subject) + refresh_token = refresh_security.create_refresh_token(subject=credentials.subject) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.get("/users/me") -def read_current_user( - credentials: JwtAuthorizationCredentials = Security(access_security), -): - return {"username": credentials["username"], "role": credentials["role"]} + @app.get("/users/me") + def read_current_user( + credentials: JwtAuthorizationCredentials = Security(access_security), + ): + return {"username": credentials["username"], "role": credentials["role"]} + + + return TestClient(app) -client = TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -98,13 +104,17 @@ def read_current_user( } -def test_openapi_schema(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_openapi_schema(jwt_backend): + client = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_access_both_correct(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_both_correct(jwt_backend): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get( @@ -116,7 +126,9 @@ def test_security_jwt_access_both_correct(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_only_cookie(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_only_cookie(jwt_backend): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get("/users/me", cookies={"access_token_cookie": access_token}) @@ -124,7 +136,9 @@ def test_security_jwt_access_only_cookie(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_only_bearer(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_only_bearer(jwt_backend): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get( @@ -134,7 +148,9 @@ def test_security_jwt_access_only_bearer(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_bearer_wrong_cookie_correct(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_bearer_wrong_cookie_correct(jwt_backend): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get( @@ -146,7 +162,9 @@ def test_security_jwt_access_bearer_wrong_cookie_correct(): assert response.json()["detail"].startswith("Wrong token:") -def test_security_jwt_access_bearer_correct_cookie_wrong(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_bearer_correct_cookie_wrong(jwt_backend): + client = create_example_client(jwt_backend) access_token = client.post("/auth").json()["access_token"] response = client.get( @@ -158,20 +176,26 @@ def test_security_jwt_access_bearer_correct_cookie_wrong(): assert response.json() == {"username": "username", "role": "user"} -def test_security_jwt_access_both_no_credentials(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_access_both_no_credentials(jwt_backend): + client = create_example_client(jwt_backend) response = client.get("/users/me") assert response.status_code == 401, response.text assert response.json() == {"detail": "Credentials are not provided"} -def test_security_jwt_refresh_only_cookie(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_only_cookie(jwt_backend): + client = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] response = client.post("/refresh", cookies={"refresh_token_cookie": refresh_token}) assert response.status_code == 200, response.text -def test_security_jwt_refresh_bearer_correct_cookie_wrong(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_bearer_correct_cookie_wrong(jwt_backend): + client = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] response = client.post( @@ -182,7 +206,9 @@ def test_security_jwt_refresh_bearer_correct_cookie_wrong(): assert response.status_code == 200, response.text -def test_security_jwt_refresh_bearer_wrong_cookie_correct(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_bearer_wrong_cookie_correct(jwt_backend): + client = create_example_client(jwt_backend) refresh_token = client.post("/auth").json()["refresh_token"] response = client.post( @@ -194,7 +220,9 @@ def test_security_jwt_refresh_bearer_wrong_cookie_correct(): assert response.json()["detail"].startswith("Wrong token:") -def test_security_jwt_refresh_cookie_wrong_using_access_token(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_cookie_wrong_using_access_token(jwt_backend): + client = create_example_client(jwt_backend) tokens = client.post("/auth").json() access_token, refresh_token = tokens["access_token"], tokens["refresh_token"] assert access_token != refresh_token @@ -204,7 +232,9 @@ def test_security_jwt_refresh_cookie_wrong_using_access_token(): assert response.json()["detail"].startswith("Wrong token: 'type' is not 'refresh'") -def test_security_jwt_refresh_both_no_credentials(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_refresh_both_no_credentials(jwt_backend): + client = create_example_client(jwt_backend) response = client.post("/refresh") assert response.status_code == 401, response.text assert response.json() == {"detail": "Credentials are not provided"} diff --git a/tests/test_security_jwt_set_cookie.py b/tests/test_security_jwt_set_cookie.py index 6057cbc..91849f2 100644 --- a/tests/test_security_jwt_set_cookie.py +++ b/tests/test_security_jwt_set_cookie.py @@ -1,36 +1,41 @@ +import pytest from fastapi import FastAPI, Response from fastapi.testclient import TestClient from fastapi_jwt import JwtAccessCookie, JwtRefreshCookie +from fastapi_jwt.jwt import AuthlibJWTBackend, PythonJoseJWTBackend, define_default_jwt_backend -app = FastAPI() -access_security = JwtAccessCookie(secret_key="secret_key") -refresh_security = JwtRefreshCookie(secret_key="secret_key") +def create_example_client(jwt_backend): + define_default_jwt_backend(jwt_backend) + app = FastAPI() + access_security = JwtAccessCookie(secret_key="secret_key") + refresh_security = JwtRefreshCookie(secret_key="secret_key") -@app.post("/auth") -def auth(response: Response): - subject = {"username": "username", "role": "user"} - access_token = access_security.create_access_token(subject=subject) - refresh_token = access_security.create_refresh_token(subject=subject) + @app.post("/auth") + def auth(response: Response): + subject = {"username": "username", "role": "user"} - access_security.set_access_cookie(response, access_token) - refresh_security.set_refresh_cookie(response, refresh_token) + access_token = access_security.create_access_token(subject=subject) + refresh_token = access_security.create_refresh_token(subject=subject) - return {"access_token": access_token, "refresh_token": refresh_token} + access_security.set_access_cookie(response, access_token) + refresh_security.set_refresh_cookie(response, refresh_token) + return {"access_token": access_token, "refresh_token": refresh_token} -@app.delete("/auth") -def logout(response: Response): - access_security.unset_access_cookie(response) - refresh_security.unset_refresh_cookie(response) - return {"msg": "Successful logout"} + @app.delete("/auth") + def logout(response: Response): + access_security.unset_access_cookie(response) + refresh_security.unset_refresh_cookie(response) + return {"msg": "Successful logout"} -client = TestClient(app) + + return TestClient(app) openapi_schema = { "openapi": "3.1.0", @@ -62,13 +67,17 @@ def logout(response: Response): } -def test_openapi_schema(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_openapi_schema(jwt_backend): + client = create_example_client(jwt_backend) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema -def test_security_jwt_auth(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_auth(jwt_backend): + client = create_example_client(jwt_backend) response = client.post("/auth") assert response.status_code == 200, response.text @@ -78,7 +87,9 @@ def test_security_jwt_auth(): assert response.cookies["refresh_token_cookie"] == response.json()["refresh_token"] -def test_security_jwt_logout(): +@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend]) +def test_security_jwt_logout(jwt_backend): + client = create_example_client(jwt_backend) response = client.delete("/auth") assert response.status_code == 200, response.text