Skip to content

Commit

Permalink
♻️ Refactor error message constants (#565)
Browse files Browse the repository at this point in the history
* ♻️ Refactor error message constants

* ✨ Add optional parameter for token locations
  • Loading branch information
yezz123 committed Apr 4, 2024
1 parent 31e6573 commit 083a579
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 19 deletions.
27 changes: 12 additions & 15 deletions authx/_internal/_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,18 @@
class _ErrorHandler:
"""Base Handler for FastAPI handling AuthX exceptions"""

def __init__(self) -> None:
"""Base Handler for FastAPI handling AuthX exceptions"""

self.MSG_DEFAULT = "AuthX Error"
self.MSG_TOKEN_ERROR = "Token Error"
self.MSG_MISSING_TOKEN_ERROR = "Missing JWT in request"
self.MSG_MISSING_CSRF_ERROR = "Missing CSRF double submit token in request"
self.MSG_TOKEN_TYPE_ERROR = "Bad token type"
self.MSG_REVOKED_TOKEN_ERROR = "Invalid token"
self.MSG_TOKEN_REQUIRED_ERROR = "Token required"
self.MSG_FRESH_TOKEN_REQUIRED_ERROR = "Fresh token required"
self.MSG_ACCESS_TOKEN_REQUIRED_ERROR = "Access token required"
self.MSG_REFRESH_TOKEN_REQUIRED_ERROR = "Refresh token required"
self.MSG_CSRF_ERROR = "CSRF double submit does not match"
self.MSG_DECODE_JWT_ERROR = "Invalid Token"
MSG_DEFAULT = "AuthX Error"
MSG_TOKEN_ERROR = "Token Error"
MSG_MISSING_TOKEN_ERROR = "Missing JWT in request"
MSG_MISSING_CSRF_ERROR = "Missing CSRF double submit token in request"
MSG_TOKEN_TYPE_ERROR = "Bad token type"
MSG_REVOKED_TOKEN_ERROR = "Invalid token"
MSG_TOKEN_REQUIRED_ERROR = "Token required"
MSG_FRESH_TOKEN_REQUIRED_ERROR = "Fresh token required"
MSG_ACCESS_TOKEN_REQUIRED_ERROR = "Access token required"
MSG_REFRESH_TOKEN_REQUIRED_ERROR = "Refresh token required"
MSG_CSRF_ERROR = "CSRF double submit does not match"
MSG_DECODE_JWT_ERROR = "Invalid Token"

def _error_handler(
self,
Expand Down
23 changes: 19 additions & 4 deletions authx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,33 +272,43 @@ async def _get_token_from_request(
return None
raise e

async def get_access_token_from_request(self, request: Request) -> RequestToken:
async def get_access_token_from_request(
self, request: Request, locations: Optional[TokenLocations] = None
) -> RequestToken:
"""Dependency to retrieve access token from request
Args:
request (Request): Request to retrieve access token from
locations (Optional[TokenLocations], optional): Locations to retrieve token from. Defaults to None.
Raises:
MissingTokenError: When no `access` token is available in request
Returns:
RequestToken: Request Token instance for `access` token type
"""
return await self._get_token_from_request(request, optional=False)
return await self._get_token_from_request(
request, optional=False, locations=locations
)

async def get_refresh_token_from_request(self, request: Request) -> RequestToken:
async def get_refresh_token_from_request(
self, request: Request, locations: Optional[TokenLocations] = None
) -> RequestToken:
"""Dependency to retrieve refresh token from request
Args:
request (Request): Request to retrieve refresh token from
locations (Optional[TokenLocations], optional): Locations to retrieve token from. Defaults to None.
Raises:
MissingTokenError: When no `refresh` token is available in request
Returns:
RequestToken: Request Token instance for `refresh` token type
"""
return await self._get_token_from_request(request, refresh=True, optional=False)
return await self._get_token_from_request(
request, refresh=True, optional=False, locations=locations
)

async def _auth_required(
self,
Expand All @@ -307,6 +317,7 @@ async def _auth_required(
verify_type: bool = True,
verify_fresh: bool = False,
verify_csrf: Optional[bool] = None,
locations: Optional[TokenLocations] = None,
) -> TokenPayload:
if type == "access":
method = self.get_access_token_from_request
Expand All @@ -321,6 +332,7 @@ async def _auth_required(

request_token = await method(
request=request,
locations=locations,
)

if self.is_token_in_blocklist(request_token.token):
Expand Down Expand Up @@ -515,6 +527,7 @@ def token_required(
verify_type: bool = True,
verify_fresh: bool = False,
verify_csrf: Optional[bool] = None,
locations: Optional[TokenLocations] = None,
) -> Callable[[Request], TokenPayload]:
"""Dependency to enforce valid token availability in request
Expand All @@ -523,6 +536,7 @@ def token_required(
verify_type (bool, optional): Apply type verification. Defaults to True.
verify_fresh (bool, optional): Require token freshness. Defaults to False.
verify_csrf (Optional[bool], optional): Enable CSRF verification. Defaults to None.
locations (Optional[TokenLocations], optional): Locations to retrieve token from. Defaults to None.
Returns:
Callable[[Request], TokenPayload]: Dependency for Valid token Payload retrieval
Expand All @@ -535,6 +549,7 @@ async def _auth_required(request: Request):
verify_csrf=verify_csrf,
verify_type=verify_type,
verify_fresh=verify_fresh,
locations=locations,
)

return _auth_required
Expand Down

0 comments on commit 083a579

Please sign in to comment.