Skip to content

Commit

Permalink
Improve testing (#688)
Browse files Browse the repository at this point in the history
* Support `override_api_settings` as decorator

* Update test_authentication

* black formatting  test_authentication

* Use drf status instead of literal status

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_integration

* Update test_serializers

* Update test_integration

* Update test_token_blacklist

* Update test_tokens

* Update test_views

* add `setUpTestData` to `TestToken`

* fix typo `self` should be `cls`

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
kiraware and pre-commit-ci[bot] authored Jun 21, 2023
1 parent c65036c commit d2cd59d
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 155 deletions.
98 changes: 51 additions & 47 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,19 @@ def test_get_header(self):
)
self.assertEqual(self.backend.get_header(request), self.fake_header)

# Should work with the x_access_token
with override_api_settings(AUTH_HEADER_NAME="HTTP_X_ACCESS_TOKEN"):
# Should pull correct header off request when using X_ACCESS_TOKEN
request = self.factory.get(
"/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header
)
self.assertEqual(self.backend.get_header(request), self.fake_header)

# Should work for unicode headers when using
request = self.factory.get(
"/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header.decode("utf-8")
)
self.assertEqual(self.backend.get_header(request), self.fake_header)
@override_api_settings(AUTH_HEADER_NAME="HTTP_X_ACCESS_TOKEN")
def test_get_header_x_access_token(self):
# Should pull correct header off request when using X_ACCESS_TOKEN
request = self.factory.get("/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header)
self.assertEqual(self.backend.get_header(request), self.fake_header)

# Should work for unicode headers when using
request = self.factory.get(
"/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header.decode("utf-8")
)
self.assertEqual(self.backend.get_header(request), self.fake_header)

def test_get_raw_token(self):
# Should return None if header lacks correct type keyword
with override_api_settings(AUTH_HEADER_TYPES="JWT"):
reload(authentication)
self.assertIsNone(self.backend.get_raw_token(self.fake_header))
reload(authentication)

# Should return None if an empty AUTHORIZATION header is sent
Expand All @@ -74,14 +68,21 @@ def test_get_raw_token(self):
# Otherwise, should return unvalidated token in header
self.assertEqual(self.backend.get_raw_token(self.fake_header), self.fake_token)

@override_api_settings(AUTH_HEADER_TYPES="JWT")
def test_get_raw_token_incorrect_header_keyword(self):
# Should return None if header lacks correct type keyword
# AUTH_HEADER_TYPES is "JWT", but header is "Bearer"
reload(authentication)
self.assertIsNone(self.backend.get_raw_token(self.fake_header))

@override_api_settings(AUTH_HEADER_TYPES=("JWT", "Bearer"))
def test_get_raw_token_multi_header_keyword(self):
# Should return token if header has one of many valid token types
with override_api_settings(AUTH_HEADER_TYPES=("JWT", "Bearer")):
reload(authentication)
self.assertEqual(
self.backend.get_raw_token(self.fake_header),
self.fake_token,
)
reload(authentication)
self.assertEqual(
self.backend.get_raw_token(self.fake_header),
self.fake_token,
)

def test_get_validated_token(self):
# Should raise InvalidToken if token not valid
Expand All @@ -96,36 +97,39 @@ def test_get_validated_token(self):
self.backend.get_validated_token(str(token)).payload, token.payload
)

@override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",),
)
def test_get_validated_token_reject_unknown_token(self):
# Should not accept tokens not included in AUTH_TOKEN_CLASSES
sliding_token = SlidingToken()
with override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
):
with self.assertRaises(InvalidToken) as e:
self.backend.get_validated_token(str(sliding_token))

messages = e.exception.detail["messages"]
self.assertEqual(1, len(messages))
self.assertEqual(
{
"token_class": "AccessToken",
"token_type": "access",
"message": "Token has wrong type",
},
messages[0],
)
with self.assertRaises(InvalidToken) as e:
self.backend.get_validated_token(str(sliding_token))

messages = e.exception.detail["messages"]
self.assertEqual(1, len(messages))
self.assertEqual(
{
"token_class": "AccessToken",
"token_type": "access",
"message": "Token has wrong type",
},
messages[0],
)

@override_api_settings(
AUTH_TOKEN_CLASSES=(
"rest_framework_simplejwt.tokens.AccessToken",
"rest_framework_simplejwt.tokens.SlidingToken",
),
)
def test_get_validated_token_accept_known_token(self):
# Should accept tokens included in AUTH_TOKEN_CLASSES
access_token = AccessToken()
sliding_token = SlidingToken()
with override_api_settings(
AUTH_TOKEN_CLASSES=(
"rest_framework_simplejwt.tokens.AccessToken",
"rest_framework_simplejwt.tokens.SlidingToken",
)
):
self.backend.get_validated_token(str(access_token))
self.backend.get_validated_token(str(sliding_token))

self.backend.get_validated_token(str(access_token))
self.backend.get_validated_token(str(sliding_token))

def test_get_user(self):
payload = {"some_other_id": "foo"}
Expand Down
42 changes: 20 additions & 22 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from django.contrib.auth import get_user_model
from django.urls import reverse
from rest_framework.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED

from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import AccessToken
Expand All @@ -26,7 +27,7 @@ def setUp(self):
def test_no_authorization(self):
res = self.view_get()

self.assertEqual(res.status_code, 401)
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED)
self.assertIn("credentials were not provided", res.data["detail"])

def test_wrong_auth_type(self):
Expand All @@ -43,9 +44,12 @@ def test_wrong_auth_type(self):

res = self.view_get()

self.assertEqual(res.status_code, 401)
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED)
self.assertIn("credentials were not provided", res.data["detail"])

@override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",),
)
def test_expired_token(self):
old_lifetime = AccessToken.lifetime
AccessToken.lifetime = timedelta(seconds=0)
Expand All @@ -63,14 +67,14 @@ def test_expired_token(self):
access = res.data["access"]
self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access)

with override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
):
res = self.view_get()
res = self.view_get()

self.assertEqual(res.status_code, 401)
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED)
self.assertEqual("token_not_valid", res.data["code"])

@override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.SlidingToken",),
)
def test_user_can_get_sliding_token_and_use_it(self):
res = self.client.post(
reverse("token_obtain_sliding"),
Expand All @@ -83,14 +87,14 @@ def test_user_can_get_sliding_token_and_use_it(self):
token = res.data["token"]
self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], token)

with override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.SlidingToken",)
):
res = self.view_get()
res = self.view_get()

self.assertEqual(res.status_code, 200)
self.assertEqual(res.status_code, HTTP_200_OK)
self.assertEqual(res.data["foo"], "bar")

@override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",),
)
def test_user_can_get_access_and_refresh_tokens_and_use_them(self):
res = self.client.post(
reverse("token_obtain_pair"),
Expand All @@ -105,12 +109,9 @@ def test_user_can_get_access_and_refresh_tokens_and_use_them(self):

self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access)

with override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
):
res = self.view_get()
res = self.view_get()

self.assertEqual(res.status_code, 200)
self.assertEqual(res.status_code, HTTP_200_OK)
self.assertEqual(res.data["foo"], "bar")

res = self.client.post(
Expand All @@ -122,10 +123,7 @@ def test_user_can_get_access_and_refresh_tokens_and_use_them(self):

self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access)

with override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
):
res = self.view_get()
res = self.view_get()

self.assertEqual(res.status_code, 200)
self.assertEqual(res.status_code, HTTP_200_OK)
self.assertEqual(res.data["foo"], "bar")
30 changes: 14 additions & 16 deletions tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ def test_it_should_return_access_token_if_everything_ok(self):
access["exp"], datetime_to_epoch(now + api_settings.ACCESS_TOKEN_LIFETIME)
)

@override_api_settings(
ROTATE_REFRESH_TOKENS=True,
BLACKLIST_AFTER_ROTATION=False,
)
def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):
refresh = RefreshToken()

Expand All @@ -298,14 +302,9 @@ def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):

now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

with override_api_settings(
ROTATE_REFRESH_TOKENS=True, BLACKLIST_AFTER_ROTATION=False
):
with patch(
"rest_framework_simplejwt.tokens.aware_utcnow"
) as fake_aware_utcnow:
fake_aware_utcnow.return_value = now
self.assertTrue(ser.is_valid())
with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
fake_aware_utcnow.return_value = now
self.assertTrue(ser.is_valid())

access = AccessToken(ser.validated_data["access"])
new_refresh = RefreshToken(ser.validated_data["refresh"])
Expand All @@ -324,6 +323,10 @@ def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):
datetime_to_epoch(now + api_settings.REFRESH_TOKEN_LIFETIME),
)

@override_api_settings(
ROTATE_REFRESH_TOKENS=True,
BLACKLIST_AFTER_ROTATION=True,
)
def test_it_should_blacklist_refresh_token_if_tokens_should_be_rotated_and_blacklisted(
self,
):
Expand All @@ -342,14 +345,9 @@ def test_it_should_blacklist_refresh_token_if_tokens_should_be_rotated_and_black

now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

with override_api_settings(
ROTATE_REFRESH_TOKENS=True, BLACKLIST_AFTER_ROTATION=True
):
with patch(
"rest_framework_simplejwt.tokens.aware_utcnow"
) as fake_aware_utcnow:
fake_aware_utcnow.return_value = now
self.assertTrue(ser.is_valid())
with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
fake_aware_utcnow.return_value = now
self.assertTrue(ser.is_valid())

access = AccessToken(ser.validated_data["access"])
new_refresh = RefreshToken(ser.validated_data["refresh"])
Expand Down
20 changes: 10 additions & 10 deletions tests/test_token_blacklist.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,25 +237,25 @@ def setUp(self):

super().setUp()

@override_api_settings(BLACKLIST_AFTER_ROTATION=True)
def test_token_verify_serializer_should_honour_blacklist_if_blacklisting_enabled(
self,
):
with override_api_settings(BLACKLIST_AFTER_ROTATION=True):
refresh_token = RefreshToken.for_user(self.user)
refresh_token.blacklist()
refresh_token = RefreshToken.for_user(self.user)
refresh_token.blacklist()

serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
self.assertFalse(serializer.is_valid())
serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
self.assertFalse(serializer.is_valid())

@override_api_settings(BLACKLIST_AFTER_ROTATION=False)
def test_token_verify_serializer_should_not_honour_blacklist_if_blacklisting_not_enabled(
self,
):
with override_api_settings(BLACKLIST_AFTER_ROTATION=False):
refresh_token = RefreshToken.for_user(self.user)
refresh_token.blacklist()
refresh_token = RefreshToken.for_user(self.user)
refresh_token.blacklist()

serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
self.assertTrue(serializer.is_valid())
serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
self.assertTrue(serializer.is_valid())


class TestBigAutoFieldIDMigration(MigrationTestCase):
Expand Down
34 changes: 18 additions & 16 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ class TestToken(TestCase):
def setUp(self):
self.token = MyToken()

@classmethod
def setUpTestData(cls):
cls.username = "test_user"
cls.user = User.objects.create_user(
username=cls.username,
password="test_password",
)

def test_init_no_token_type_or_lifetime(self):
class MyTestToken(Token):
pass
Expand Down Expand Up @@ -225,14 +233,14 @@ def test_set_jti(self):
self.assertIn("jti", token)
self.assertNotEqual(old_jti, token["jti"])

@override_api_settings(JTI_CLAIM=None)
def test_optional_jti(self):
with override_api_settings(JTI_CLAIM=None):
token = MyToken()
token = MyToken()
self.assertNotIn("jti", token)

@override_api_settings(TOKEN_TYPE_CLAIM=None)
def test_optional_type_token(self):
with override_api_settings(TOKEN_TYPE_CLAIM=None):
token = MyToken()
token = MyToken()
self.assertNotIn("type", token)

def test_set_exp(self):
Expand Down Expand Up @@ -355,25 +363,19 @@ def test_check_token_if_wrong_type_leeway(self):
token.token_backend.leeway = 0

def test_for_user(self):
username = "test_user"
user = User.objects.create_user(
username=username,
password="test_password",
)
token = MyToken.for_user(self.user)

token = MyToken.for_user(user)

user_id = getattr(user, api_settings.USER_ID_FIELD)
user_id = getattr(self.user, api_settings.USER_ID_FIELD)
if not isinstance(user_id, int):
user_id = str(user_id)

self.assertEqual(token[api_settings.USER_ID_CLAIM], user_id)

@override_api_settings(USER_ID_FIELD="username")
def test_for_user_with_username(self):
# Test with non-int user id
with override_api_settings(USER_ID_FIELD="username"):
token = MyToken.for_user(user)

self.assertEqual(token[api_settings.USER_ID_CLAIM], username)
token = MyToken.for_user(self.user)
self.assertEqual(token[api_settings.USER_ID_CLAIM], self.username)

def test_get_token_backend(self):
token = MyToken()
Expand Down
Loading

0 comments on commit d2cd59d

Please sign in to comment.