Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add specific "token expired" exceptions #830

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions rest_framework_simplejwt/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@

import jwt
from django.utils.translation import gettext_lazy as _
from jwt import InvalidAlgorithmError, InvalidTokenError, algorithms

from .exceptions import TokenBackendError
from jwt import (
ExpiredSignatureError,
InvalidAlgorithmError,
InvalidTokenError,
algorithms,
)

from .exceptions import TokenBackendError, TokenBackendExpiredToken
from .tokens import Token
from .utils import format_lazy

Expand Down Expand Up @@ -101,7 +106,7 @@ def get_verifying_key(self, token: Token) -> Optional[str]:
try:
return self.jwks_client.get_signing_key_from_jwt(token).key
except PyJWKClientError as ex:
raise TokenBackendError(_("Token is invalid or expired")) from ex
raise TokenBackendError(_("Token is invalid")) from ex

return self.verifying_key

Expand Down Expand Up @@ -150,5 +155,7 @@ def decode(self, token: Token, verify: bool = True) -> Dict[str, Any]:
)
except InvalidAlgorithmError as ex:
raise TokenBackendError(_("Invalid algorithm specified")) from ex
except ExpiredSignatureError as ex:
raise TokenBackendExpiredToken(_("Token is expired")) from ex
except InvalidTokenError as ex:
raise TokenBackendError(_("Token is invalid or expired")) from ex
raise TokenBackendError(_("Token is invalid")) from ex
8 changes: 8 additions & 0 deletions rest_framework_simplejwt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,18 @@ class TokenError(Exception):
pass


class ExpiredTokenError(TokenError):
pass


class TokenBackendError(Exception):
pass


class TokenBackendExpiredToken(TokenBackendError):
pass


class DetailDictMixin:
default_detail: str
default_code: str
Expand Down
11 changes: 9 additions & 2 deletions rest_framework_simplejwt/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from django.utils.module_loading import import_string
from django.utils.translation import gettext_lazy as _

from .exceptions import TokenBackendError, TokenError
from .exceptions import (
ExpiredTokenError,
TokenBackendError,
TokenBackendExpiredToken,
TokenError,
)
from .models import TokenUser
from .settings import api_settings
from .token_blacklist.models import BlacklistedToken, OutstandingToken
Expand Down Expand Up @@ -56,8 +61,10 @@ def __init__(self, token: Optional["Token"] = None, verify: bool = True) -> None
# Decode token
try:
self.payload = token_backend.decode(token, verify=verify)
except TokenBackendExpiredToken:
raise ExpiredTokenError(_("Token is expired"))
except TokenBackendError:
raise TokenError(_("Token is invalid or expired"))
raise TokenError(_("Token is invalid"))

if verify:
self.verify()
Expand Down
11 changes: 6 additions & 5 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from jwt import algorithms

from rest_framework_simplejwt.backends import JWK_CLIENT_AVAILABLE, TokenBackend
from rest_framework_simplejwt.exceptions import TokenBackendError
from rest_framework_simplejwt.exceptions import (
TokenBackendError,
TokenBackendExpiredToken,
)
from rest_framework_simplejwt.utils import aware_utcnow, datetime_to_epoch, make_utc
from tests.keys import (
ES256_PRIVATE_KEY,
Expand Down Expand Up @@ -191,7 +194,7 @@ def test_decode_with_expiry(self):
self.payload, backend.signing_key, algorithm=backend.algorithm
)

with self.assertRaises(TokenBackendError):
with self.assertRaises(TokenBackendExpiredToken):
backend.decode(expired_token)

def test_decode_with_invalid_sig(self):
Expand Down Expand Up @@ -346,9 +349,7 @@ def test_decode_jwk_missing_key_raises_tokenbackenderror(self):
"RS256", PRIVATE_KEY, PUBLIC_KEY, AUDIENCE, ISSUER, JWK_URL
)

with self.assertRaisesRegex(
TokenBackendError, "Token is invalid or expired"
):
with self.assertRaisesRegex(TokenBackendError, "Token is invalid"):
jwk_token_backend.decode(token)

def test_decode_when_algorithm_not_available(self):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def test_it_should_not_validate_if_token_invalid(self):
with self.assertRaises(TokenError) as e:
s.is_valid()

self.assertIn("invalid or expired", e.exception.args[0])
self.assertIn("expired", e.exception.args[0])

def test_it_should_raise_token_error_if_token_has_no_refresh_exp_claim(self):
token = SlidingToken()
Expand Down Expand Up @@ -314,7 +314,7 @@ def test_it_should_raise_token_error_if_token_invalid(self):
with self.assertRaises(TokenError) as e:
s.is_valid()

self.assertIn("invalid or expired", e.exception.args[0])
self.assertIn("expired", e.exception.args[0])

def test_it_should_raise_token_error_if_token_has_wrong_type(self):
token = RefreshToken()
Expand Down Expand Up @@ -480,7 +480,7 @@ def test_it_should_raise_token_error_if_token_invalid(self):
with self.assertRaises(TokenError) as e:
s.is_valid()

self.assertIn("invalid or expired", e.exception.args[0])
self.assertIn("expired", e.exception.args[0])

def test_it_should_not_raise_token_error_if_token_has_wrong_type(self):
token = RefreshToken()
Expand Down Expand Up @@ -525,7 +525,7 @@ def test_it_should_raise_token_error_if_token_invalid(self):
with self.assertRaises(TokenError) as e:
s.is_valid()

self.assertIn("invalid or expired", e.exception.args[0])
self.assertIn("expired", e.exception.args[0])

def test_it_should_raise_token_error_if_token_has_wrong_type(self):
token = RefreshToken()
Expand Down
8 changes: 6 additions & 2 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from django.test import TestCase
from jose import jwt

from rest_framework_simplejwt.exceptions import TokenBackendError, TokenError
from rest_framework_simplejwt.exceptions import (
ExpiredTokenError,
TokenBackendError,
TokenError,
)
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.state import token_backend
from rest_framework_simplejwt.tokens import (
Expand Down Expand Up @@ -157,7 +161,7 @@ def test_init_expired_token_given(self):
t = MyToken()
t.set_exp(lifetime=-timedelta(seconds=1))

with self.assertRaises(TokenError):
with self.assertRaises(ExpiredTokenError):
MyToken(str(t))

def test_init_no_type_token_given(self):
Expand Down
Loading