Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance BlacklistMixin with Generic Type for Accurate Type Inference #768

Merged
merged 2 commits into from
Dec 5, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions rest_framework_simplejwt/tokens.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type, TypeVar
from uuid import uuid4

from django.conf import settings
Expand All @@ -22,6 +22,8 @@
if TYPE_CHECKING:
from .backends import TokenBackend

T = TypeVar("T", bound="Token")

Comment on lines +25 to +26
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bounded context. Looks good to me.

AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)


Expand Down Expand Up @@ -229,7 +231,7 @@ def get_token_backend(self) -> "TokenBackend":
return self.token_backend


class BlacklistMixin:
class BlacklistMixin(Generic[T]):
"""
If the `rest_framework_simplejwt.token_blacklist` app was configured to be
used, tokens created from `BlacklistMixin` subclasses will insert
Expand Down Expand Up @@ -276,7 +278,7 @@ def blacklist(self) -> BlacklistedToken:
return BlacklistedToken.objects.get_or_create(token=token)

@classmethod
def for_user(cls, user: AuthUser) -> Token:
def for_user(cls: Type[T], user: AuthUser) -> T:
"""
Adds this token to the outstanding token list.
"""
Expand All @@ -296,7 +298,7 @@ def for_user(cls, user: AuthUser) -> Token:
return token


class SlidingToken(BlacklistMixin, Token):
class SlidingToken(BlacklistMixin["SlidingToken"], Token):
token_type = "sliding"
lifetime = api_settings.SLIDING_TOKEN_LIFETIME

Expand All @@ -317,7 +319,7 @@ class AccessToken(Token):
lifetime = api_settings.ACCESS_TOKEN_LIFETIME


class RefreshToken(BlacklistMixin, Token):
class RefreshToken(BlacklistMixin["RefreshToken"], Token):
token_type = "refresh"
lifetime = api_settings.REFRESH_TOKEN_LIFETIME
no_copy_claims = (
Expand Down
Loading