diff --git a/rest_framework_simplejwt/tokens.py b/rest_framework_simplejwt/tokens.py index b207ef27c..8dff4e01f 100644 --- a/rest_framework_simplejwt/tokens.py +++ b/rest_framework_simplejwt/tokens.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type, TypeVar from uuid import uuid4 from django.conf import settings @@ -22,6 +22,8 @@ if TYPE_CHECKING: from .backends import TokenBackend +T = TypeVar("T", bound="Token") + AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser) @@ -229,7 +231,7 @@ def get_token_backend(self) -> "TokenBackend": return self.token_backend -class BlacklistMixin: +class BlacklistMixin(Generic[T]): """ If the `rest_framework_simplejwt.token_blacklist` app was configured to be used, tokens created from `BlacklistMixin` subclasses will insert @@ -276,7 +278,7 @@ def blacklist(self) -> BlacklistedToken: return BlacklistedToken.objects.get_or_create(token=token) @classmethod - def for_user(cls, user: AuthUser) -> Token: + def for_user(cls: Type[T], user: AuthUser) -> T: """ Adds this token to the outstanding token list. """ @@ -296,7 +298,7 @@ def for_user(cls, user: AuthUser) -> Token: return token -class SlidingToken(BlacklistMixin, Token): +class SlidingToken(BlacklistMixin["SlidingToken"], Token): token_type = "sliding" lifetime = api_settings.SLIDING_TOKEN_LIFETIME @@ -317,7 +319,7 @@ class AccessToken(Token): lifetime = api_settings.ACCESS_TOKEN_LIFETIME -class RefreshToken(BlacklistMixin, Token): +class RefreshToken(BlacklistMixin["RefreshToken"], Token): token_type = "refresh" lifetime = api_settings.REFRESH_TOKEN_LIFETIME no_copy_claims = (