diff --git a/authx/_internal/_error.py b/authx/_internal/_error.py index 84145b63..616f4f13 100644 --- a/authx/_internal/_error.py +++ b/authx/_internal/_error.py @@ -9,21 +9,18 @@ class _ErrorHandler: """Base Handler for FastAPI handling AuthX exceptions""" - def __init__(self) -> None: - """Base Handler for FastAPI handling AuthX exceptions""" - - self.MSG_DEFAULT = "AuthX Error" - self.MSG_TOKEN_ERROR = "Token Error" - self.MSG_MISSING_TOKEN_ERROR = "Missing JWT in request" - self.MSG_MISSING_CSRF_ERROR = "Missing CSRF double submit token in request" - self.MSG_TOKEN_TYPE_ERROR = "Bad token type" - self.MSG_REVOKED_TOKEN_ERROR = "Invalid token" - self.MSG_TOKEN_REQUIRED_ERROR = "Token required" - self.MSG_FRESH_TOKEN_REQUIRED_ERROR = "Fresh token required" - self.MSG_ACCESS_TOKEN_REQUIRED_ERROR = "Access token required" - self.MSG_REFRESH_TOKEN_REQUIRED_ERROR = "Refresh token required" - self.MSG_CSRF_ERROR = "CSRF double submit does not match" - self.MSG_DECODE_JWT_ERROR = "Invalid Token" + MSG_DEFAULT = "AuthX Error" + MSG_TOKEN_ERROR = "Token Error" + MSG_MISSING_TOKEN_ERROR = "Missing JWT in request" + MSG_MISSING_CSRF_ERROR = "Missing CSRF double submit token in request" + MSG_TOKEN_TYPE_ERROR = "Bad token type" + MSG_REVOKED_TOKEN_ERROR = "Invalid token" + MSG_TOKEN_REQUIRED_ERROR = "Token required" + MSG_FRESH_TOKEN_REQUIRED_ERROR = "Fresh token required" + MSG_ACCESS_TOKEN_REQUIRED_ERROR = "Access token required" + MSG_REFRESH_TOKEN_REQUIRED_ERROR = "Refresh token required" + MSG_CSRF_ERROR = "CSRF double submit does not match" + MSG_DECODE_JWT_ERROR = "Invalid Token" def _error_handler( self, diff --git a/authx/main.py b/authx/main.py index aa270f9b..3d44f4aa 100644 --- a/authx/main.py +++ b/authx/main.py @@ -272,11 +272,14 @@ async def _get_token_from_request( return None raise e - async def get_access_token_from_request(self, request: Request) -> RequestToken: + async def get_access_token_from_request( + self, request: Request, locations: Optional[TokenLocations] = None + ) -> RequestToken: """Dependency to retrieve access token from request Args: request (Request): Request to retrieve access token from + locations (Optional[TokenLocations], optional): Locations to retrieve token from. Defaults to None. Raises: MissingTokenError: When no `access` token is available in request @@ -284,13 +287,18 @@ async def get_access_token_from_request(self, request: Request) -> RequestToken: Returns: RequestToken: Request Token instance for `access` token type """ - return await self._get_token_from_request(request, optional=False) + return await self._get_token_from_request( + request, optional=False, locations=locations + ) - async def get_refresh_token_from_request(self, request: Request) -> RequestToken: + async def get_refresh_token_from_request( + self, request: Request, locations: Optional[TokenLocations] = None + ) -> RequestToken: """Dependency to retrieve refresh token from request Args: request (Request): Request to retrieve refresh token from + locations (Optional[TokenLocations], optional): Locations to retrieve token from. Defaults to None. Raises: MissingTokenError: When no `refresh` token is available in request @@ -298,7 +306,9 @@ async def get_refresh_token_from_request(self, request: Request) -> RequestToken Returns: RequestToken: Request Token instance for `refresh` token type """ - return await self._get_token_from_request(request, refresh=True, optional=False) + return await self._get_token_from_request( + request, refresh=True, optional=False, locations=locations + ) async def _auth_required( self, @@ -307,6 +317,7 @@ async def _auth_required( verify_type: bool = True, verify_fresh: bool = False, verify_csrf: Optional[bool] = None, + locations: Optional[TokenLocations] = None, ) -> TokenPayload: if type == "access": method = self.get_access_token_from_request @@ -321,6 +332,7 @@ async def _auth_required( request_token = await method( request=request, + locations=locations, ) if self.is_token_in_blocklist(request_token.token): @@ -515,6 +527,7 @@ def token_required( verify_type: bool = True, verify_fresh: bool = False, verify_csrf: Optional[bool] = None, + locations: Optional[TokenLocations] = None, ) -> Callable[[Request], TokenPayload]: """Dependency to enforce valid token availability in request @@ -523,6 +536,7 @@ def token_required( verify_type (bool, optional): Apply type verification. Defaults to True. verify_fresh (bool, optional): Require token freshness. Defaults to False. verify_csrf (Optional[bool], optional): Enable CSRF verification. Defaults to None. + locations (Optional[TokenLocations], optional): Locations to retrieve token from. Defaults to None. Returns: Callable[[Request], TokenPayload]: Dependency for Valid token Payload retrieval @@ -535,6 +549,7 @@ async def _auth_required(request: Request): verify_csrf=verify_csrf, verify_type=verify_type, verify_fresh=verify_fresh, + locations=locations, ) return _auth_required