diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 917f3993d..9fa645c2c 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -40,25 +40,19 @@ def test_get_header(self): ) self.assertEqual(self.backend.get_header(request), self.fake_header) - # Should work with the x_access_token - with override_api_settings(AUTH_HEADER_NAME="HTTP_X_ACCESS_TOKEN"): - # Should pull correct header off request when using X_ACCESS_TOKEN - request = self.factory.get( - "/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header - ) - self.assertEqual(self.backend.get_header(request), self.fake_header) - - # Should work for unicode headers when using - request = self.factory.get( - "/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header.decode("utf-8") - ) - self.assertEqual(self.backend.get_header(request), self.fake_header) + @override_api_settings(AUTH_HEADER_NAME="HTTP_X_ACCESS_TOKEN") + def test_get_header_x_access_token(self): + # Should pull correct header off request when using X_ACCESS_TOKEN + request = self.factory.get("/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header) + self.assertEqual(self.backend.get_header(request), self.fake_header) + + # Should work for unicode headers when using + request = self.factory.get( + "/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header.decode("utf-8") + ) + self.assertEqual(self.backend.get_header(request), self.fake_header) def test_get_raw_token(self): - # Should return None if header lacks correct type keyword - with override_api_settings(AUTH_HEADER_TYPES="JWT"): - reload(authentication) - self.assertIsNone(self.backend.get_raw_token(self.fake_header)) reload(authentication) # Should return None if an empty AUTHORIZATION header is sent @@ -74,14 +68,21 @@ def test_get_raw_token(self): # Otherwise, should return unvalidated token in header self.assertEqual(self.backend.get_raw_token(self.fake_header), self.fake_token) + @override_api_settings(AUTH_HEADER_TYPES="JWT") + def test_get_raw_token_incorrect_header_keyword(self): + # Should return None if header lacks correct type keyword + # AUTH_HEADER_TYPES is "JWT", but header is "Bearer" + reload(authentication) + self.assertIsNone(self.backend.get_raw_token(self.fake_header)) + + @override_api_settings(AUTH_HEADER_TYPES=("JWT", "Bearer")) + def test_get_raw_token_multi_header_keyword(self): # Should return token if header has one of many valid token types - with override_api_settings(AUTH_HEADER_TYPES=("JWT", "Bearer")): - reload(authentication) - self.assertEqual( - self.backend.get_raw_token(self.fake_header), - self.fake_token, - ) reload(authentication) + self.assertEqual( + self.backend.get_raw_token(self.fake_header), + self.fake_token, + ) def test_get_validated_token(self): # Should raise InvalidToken if token not valid @@ -96,36 +97,39 @@ def test_get_validated_token(self): self.backend.get_validated_token(str(token)).payload, token.payload ) + @override_api_settings( + AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",), + ) + def test_get_validated_token_reject_unknown_token(self): # Should not accept tokens not included in AUTH_TOKEN_CLASSES sliding_token = SlidingToken() - with override_api_settings( - AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",) - ): - with self.assertRaises(InvalidToken) as e: - self.backend.get_validated_token(str(sliding_token)) - - messages = e.exception.detail["messages"] - self.assertEqual(1, len(messages)) - self.assertEqual( - { - "token_class": "AccessToken", - "token_type": "access", - "message": "Token has wrong type", - }, - messages[0], - ) + with self.assertRaises(InvalidToken) as e: + self.backend.get_validated_token(str(sliding_token)) + + messages = e.exception.detail["messages"] + self.assertEqual(1, len(messages)) + self.assertEqual( + { + "token_class": "AccessToken", + "token_type": "access", + "message": "Token has wrong type", + }, + messages[0], + ) + @override_api_settings( + AUTH_TOKEN_CLASSES=( + "rest_framework_simplejwt.tokens.AccessToken", + "rest_framework_simplejwt.tokens.SlidingToken", + ), + ) + def test_get_validated_token_accept_known_token(self): # Should accept tokens included in AUTH_TOKEN_CLASSES access_token = AccessToken() sliding_token = SlidingToken() - with override_api_settings( - AUTH_TOKEN_CLASSES=( - "rest_framework_simplejwt.tokens.AccessToken", - "rest_framework_simplejwt.tokens.SlidingToken", - ) - ): - self.backend.get_validated_token(str(access_token)) - self.backend.get_validated_token(str(sliding_token)) + + self.backend.get_validated_token(str(access_token)) + self.backend.get_validated_token(str(sliding_token)) def test_get_user(self): payload = {"some_other_id": "foo"} diff --git a/tests/test_integration.py b/tests/test_integration.py index 54786d87a..beee3d552 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -2,6 +2,7 @@ from django.contrib.auth import get_user_model from django.urls import reverse +from rest_framework.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED from rest_framework_simplejwt.settings import api_settings from rest_framework_simplejwt.tokens import AccessToken @@ -26,7 +27,7 @@ def setUp(self): def test_no_authorization(self): res = self.view_get() - self.assertEqual(res.status_code, 401) + self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) self.assertIn("credentials were not provided", res.data["detail"]) def test_wrong_auth_type(self): @@ -43,9 +44,12 @@ def test_wrong_auth_type(self): res = self.view_get() - self.assertEqual(res.status_code, 401) + self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) self.assertIn("credentials were not provided", res.data["detail"]) + @override_api_settings( + AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",), + ) def test_expired_token(self): old_lifetime = AccessToken.lifetime AccessToken.lifetime = timedelta(seconds=0) @@ -63,14 +67,14 @@ def test_expired_token(self): access = res.data["access"] self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) - with override_api_settings( - AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",) - ): - res = self.view_get() + res = self.view_get() - self.assertEqual(res.status_code, 401) + self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) self.assertEqual("token_not_valid", res.data["code"]) + @override_api_settings( + AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.SlidingToken",), + ) def test_user_can_get_sliding_token_and_use_it(self): res = self.client.post( reverse("token_obtain_sliding"), @@ -83,14 +87,14 @@ def test_user_can_get_sliding_token_and_use_it(self): token = res.data["token"] self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], token) - with override_api_settings( - AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.SlidingToken",) - ): - res = self.view_get() + res = self.view_get() - self.assertEqual(res.status_code, 200) + self.assertEqual(res.status_code, HTTP_200_OK) self.assertEqual(res.data["foo"], "bar") + @override_api_settings( + AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",), + ) def test_user_can_get_access_and_refresh_tokens_and_use_them(self): res = self.client.post( reverse("token_obtain_pair"), @@ -105,12 +109,9 @@ def test_user_can_get_access_and_refresh_tokens_and_use_them(self): self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) - with override_api_settings( - AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",) - ): - res = self.view_get() + res = self.view_get() - self.assertEqual(res.status_code, 200) + self.assertEqual(res.status_code, HTTP_200_OK) self.assertEqual(res.data["foo"], "bar") res = self.client.post( @@ -122,10 +123,7 @@ def test_user_can_get_access_and_refresh_tokens_and_use_them(self): self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) - with override_api_settings( - AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",) - ): - res = self.view_get() + res = self.view_get() - self.assertEqual(res.status_code, 200) + self.assertEqual(res.status_code, HTTP_200_OK) self.assertEqual(res.data["foo"], "bar") diff --git a/tests/test_serializers.py b/tests/test_serializers.py index 530f689cd..6db0e3998 100644 --- a/tests/test_serializers.py +++ b/tests/test_serializers.py @@ -285,6 +285,10 @@ def test_it_should_return_access_token_if_everything_ok(self): access["exp"], datetime_to_epoch(now + api_settings.ACCESS_TOKEN_LIFETIME) ) + @override_api_settings( + ROTATE_REFRESH_TOKENS=True, + BLACKLIST_AFTER_ROTATION=False, + ) def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self): refresh = RefreshToken() @@ -298,14 +302,9 @@ def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self): now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2 - with override_api_settings( - ROTATE_REFRESH_TOKENS=True, BLACKLIST_AFTER_ROTATION=False - ): - with patch( - "rest_framework_simplejwt.tokens.aware_utcnow" - ) as fake_aware_utcnow: - fake_aware_utcnow.return_value = now - self.assertTrue(ser.is_valid()) + with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow: + fake_aware_utcnow.return_value = now + self.assertTrue(ser.is_valid()) access = AccessToken(ser.validated_data["access"]) new_refresh = RefreshToken(ser.validated_data["refresh"]) @@ -324,6 +323,10 @@ def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self): datetime_to_epoch(now + api_settings.REFRESH_TOKEN_LIFETIME), ) + @override_api_settings( + ROTATE_REFRESH_TOKENS=True, + BLACKLIST_AFTER_ROTATION=True, + ) def test_it_should_blacklist_refresh_token_if_tokens_should_be_rotated_and_blacklisted( self, ): @@ -342,14 +345,9 @@ def test_it_should_blacklist_refresh_token_if_tokens_should_be_rotated_and_black now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2 - with override_api_settings( - ROTATE_REFRESH_TOKENS=True, BLACKLIST_AFTER_ROTATION=True - ): - with patch( - "rest_framework_simplejwt.tokens.aware_utcnow" - ) as fake_aware_utcnow: - fake_aware_utcnow.return_value = now - self.assertTrue(ser.is_valid()) + with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow: + fake_aware_utcnow.return_value = now + self.assertTrue(ser.is_valid()) access = AccessToken(ser.validated_data["access"]) new_refresh = RefreshToken(ser.validated_data["refresh"]) diff --git a/tests/test_token_blacklist.py b/tests/test_token_blacklist.py index 67ea7fada..824808145 100644 --- a/tests/test_token_blacklist.py +++ b/tests/test_token_blacklist.py @@ -237,25 +237,25 @@ def setUp(self): super().setUp() + @override_api_settings(BLACKLIST_AFTER_ROTATION=True) def test_token_verify_serializer_should_honour_blacklist_if_blacklisting_enabled( self, ): - with override_api_settings(BLACKLIST_AFTER_ROTATION=True): - refresh_token = RefreshToken.for_user(self.user) - refresh_token.blacklist() + refresh_token = RefreshToken.for_user(self.user) + refresh_token.blacklist() - serializer = TokenVerifySerializer(data={"token": str(refresh_token)}) - self.assertFalse(serializer.is_valid()) + serializer = TokenVerifySerializer(data={"token": str(refresh_token)}) + self.assertFalse(serializer.is_valid()) + @override_api_settings(BLACKLIST_AFTER_ROTATION=False) def test_token_verify_serializer_should_not_honour_blacklist_if_blacklisting_not_enabled( self, ): - with override_api_settings(BLACKLIST_AFTER_ROTATION=False): - refresh_token = RefreshToken.for_user(self.user) - refresh_token.blacklist() + refresh_token = RefreshToken.for_user(self.user) + refresh_token.blacklist() - serializer = TokenVerifySerializer(data={"token": str(refresh_token)}) - self.assertTrue(serializer.is_valid()) + serializer = TokenVerifySerializer(data={"token": str(refresh_token)}) + self.assertTrue(serializer.is_valid()) class TestBigAutoFieldIDMigration(MigrationTestCase): diff --git a/tests/test_tokens.py b/tests/test_tokens.py index bc06997d9..9f81a1a76 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -31,6 +31,14 @@ class TestToken(TestCase): def setUp(self): self.token = MyToken() + @classmethod + def setUpTestData(cls): + cls.username = "test_user" + cls.user = User.objects.create_user( + username=cls.username, + password="test_password", + ) + def test_init_no_token_type_or_lifetime(self): class MyTestToken(Token): pass @@ -225,14 +233,14 @@ def test_set_jti(self): self.assertIn("jti", token) self.assertNotEqual(old_jti, token["jti"]) + @override_api_settings(JTI_CLAIM=None) def test_optional_jti(self): - with override_api_settings(JTI_CLAIM=None): - token = MyToken() + token = MyToken() self.assertNotIn("jti", token) + @override_api_settings(TOKEN_TYPE_CLAIM=None) def test_optional_type_token(self): - with override_api_settings(TOKEN_TYPE_CLAIM=None): - token = MyToken() + token = MyToken() self.assertNotIn("type", token) def test_set_exp(self): @@ -355,25 +363,19 @@ def test_check_token_if_wrong_type_leeway(self): token.token_backend.leeway = 0 def test_for_user(self): - username = "test_user" - user = User.objects.create_user( - username=username, - password="test_password", - ) + token = MyToken.for_user(self.user) - token = MyToken.for_user(user) - - user_id = getattr(user, api_settings.USER_ID_FIELD) + user_id = getattr(self.user, api_settings.USER_ID_FIELD) if not isinstance(user_id, int): user_id = str(user_id) self.assertEqual(token[api_settings.USER_ID_CLAIM], user_id) + @override_api_settings(USER_ID_FIELD="username") + def test_for_user_with_username(self): # Test with non-int user id - with override_api_settings(USER_ID_FIELD="username"): - token = MyToken.for_user(user) - - self.assertEqual(token[api_settings.USER_ID_CLAIM], username) + token = MyToken.for_user(self.user) + self.assertEqual(token[api_settings.USER_ID_CLAIM], self.username) def test_get_token_backend(self): token = MyToken() diff --git a/tests/test_views.py b/tests/test_views.py index 8a9e16f26..b1fc80113 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,5 +1,4 @@ from datetime import timedelta -from importlib import reload from unittest.mock import patch from django.contrib.auth import get_user_model @@ -93,20 +92,18 @@ def test_update_last_login(self): user = User.objects.get(username=self.username) self.assertEqual(user.last_login, None) + @override_api_settings(UPDATE_LAST_LOGIN=True) + def test_update_last_login_updated(self): # verify last_login is updated - with override_api_settings(UPDATE_LAST_LOGIN=True): - reload(serializers) - self.view_post( - data={ - User.USERNAME_FIELD: self.username, - "password": self.password, - } - ) - user = User.objects.get(username=self.username) - self.assertIsNotNone(user.last_login) - self.assertGreaterEqual(timezone.now(), user.last_login) - - reload(serializers) + self.view_post( + data={ + User.USERNAME_FIELD: self.username, + "password": self.password, + } + ) + user = User.objects.get(username=self.username) + self.assertIsNotNone(user.last_login) + self.assertGreaterEqual(timezone.now(), user.last_login) class TestTokenRefreshView(APIViewTestCase): @@ -233,20 +230,18 @@ def test_update_last_login(self): user = User.objects.get(username=self.username) self.assertEqual(user.last_login, None) + @override_api_settings(UPDATE_LAST_LOGIN=True) + def test_update_last_login_updated(self): # verify last_login is updated - with override_api_settings(UPDATE_LAST_LOGIN=True): - reload(serializers) - self.view_post( - data={ - User.USERNAME_FIELD: self.username, - "password": self.password, - } - ) - user = User.objects.get(username=self.username) - self.assertIsNotNone(user.last_login) - self.assertGreaterEqual(timezone.now(), user.last_login) - - reload(serializers) + self.view_post( + data={ + User.USERNAME_FIELD: self.username, + "password": self.password, + } + ) + user = User.objects.get(username=self.username) + self.assertIsNotNone(user.last_login) + self.assertGreaterEqual(timezone.now(), user.last_login) class TestTokenRefreshSlidingView(APIViewTestCase): diff --git a/tests/utils.py b/tests/utils.py index c3d710a2e..9e4c2c4cb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -62,23 +62,25 @@ def override_api_settings(**settings): except AttributeError: pass - yield - - for k in settings.keys(): - # Delete temporary settings - api_settings.user_settings.pop(k) - - # Restore saved settings - try: - api_settings.user_settings[k] = old_settings[k] - except KeyError: - pass - - # Delete any cached settings - try: - delattr(api_settings, k) - except AttributeError: - pass + try: + yield + + finally: + for k in settings.keys(): + # Delete temporary settings + api_settings.user_settings.pop(k) + + # Restore saved settings + try: + api_settings.user_settings[k] = old_settings[k] + except KeyError: + pass + + # Delete any cached settings + try: + delattr(api_settings, k) + except AttributeError: + pass class MigrationTestCase(TransactionTestCase):