diff --git a/api/environments/identities/tests/__init__.py b/api/environments/identities/tests/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/api/environments/identities/traits/tests/__init__.py b/api/environments/identities/traits/tests/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/api/environments/identities/traits/tests/test_views.py b/api/environments/identities/traits/tests/test_views.py deleted file mode 100644 index 1f7f68c89c3d..000000000000 --- a/api/environments/identities/traits/tests/test_views.py +++ /dev/null @@ -1,802 +0,0 @@ -import json -from unittest import mock -from unittest.case import TestCase - -import pytest -from core.constants import INTEGER, STRING -from django.test import override_settings -from django.urls import reverse -from rest_framework import status -from rest_framework.test import APIClient, APITestCase - -from environments.identities.models import Identity -from environments.identities.traits.constants import ( - TRAIT_STRING_VALUE_MAX_LENGTH, -) -from environments.identities.traits.models import Trait -from environments.models import Environment, EnvironmentAPIKey -from organisations.models import Organisation, OrganisationRole -from projects.models import Project -from util.tests import Helper - - -class SDKTraitsTest(APITestCase): - JSON = "application/json" - - def setUp(self) -> None: - self.organisation = Organisation.objects.create(name="Test organisation") - project = Project.objects.create( - name="Test project", organisation=self.organisation, enable_dynamo_db=True - ) - self.environment = Environment.objects.create( - name="Test environment", project=project - ) - self.identity = Identity.objects.create( - identifier="test-user", environment=self.environment - ) - self.client.credentials(HTTP_X_ENVIRONMENT_KEY=self.environment.api_key) - self.trait_key = "trait_key" - self.trait_value = "trait_value" - - def test_can_set_trait_for_an_identity(self): - # Given - url = reverse("api-v1:sdk-traits-list") - - # When - res = self.client.post( - url, data=self._generate_json_trait_data(), content_type=self.JSON - ) - - # Then - assert res.status_code == status.HTTP_200_OK - - # and - assert Trait.objects.filter( - identity=self.identity, trait_key=self.trait_key - ).exists() - - def test_cannot_set_trait_for_an_identity_for_organisations_without_persistence( - self, - ): - # Given - url = reverse("api-v1:sdk-traits-list") - - # an organisation that is configured to not store traits - self.organisation.persist_trait_data = False - self.organisation.save() - - # When - response = self.client.post( - url, data=self._generate_json_trait_data(), content_type=self.JSON - ) - - # Then - # the request fails - assert response.status_code == status.HTTP_403_FORBIDDEN - response_json = response.json() - assert response_json["detail"] == ( - "Organisation is not authorised to store traits." - ) - - # and no traits are stored - assert Trait.objects.count() == 0 - - def test_can_set_trait_with_boolean_value_for_an_identity(self): - # Given - url = reverse("api-v1:sdk-traits-list") - trait_value = True - - # When - res = self.client.post( - url, - data=self._generate_json_trait_data(trait_value=trait_value), - content_type=self.JSON, - ) - - # Then - assert res.status_code == status.HTTP_200_OK - - # and - assert ( - Trait.objects.get( - identity=self.identity, trait_key=self.trait_key - ).get_trait_value() - == trait_value - ) - - def test_can_set_trait_with_identity_value_for_an_identity(self): - # Given - url = reverse("api-v1:sdk-traits-list") - trait_value = 12 - - # When - res = self.client.post( - url, - data=self._generate_json_trait_data(trait_value=trait_value), - content_type=self.JSON, - ) - - # Then - assert res.status_code == status.HTTP_200_OK - - # and - assert ( - Trait.objects.get( - identity=self.identity, trait_key=self.trait_key - ).get_trait_value() - == trait_value - ) - - def test_can_set_trait_with_float_value_for_an_identity(self): - # Given - url = reverse("api-v1:sdk-traits-list") - float_trait_value = 10.5 - - # When - res = self.client.post( - url, - data=self._generate_json_trait_data(trait_value=float_trait_value), - content_type=self.JSON, - ) - - # Then - assert res.status_code == status.HTTP_200_OK - - # and - assert ( - Trait.objects.get( - identity=self.identity, trait_key=self.trait_key - ).get_trait_value() - == float_trait_value - ) - - def test_add_trait_creates_identity_if_it_doesnt_exist(self): - # Given - url = reverse("api-v1:sdk-traits-list") - identifier = "new-identity" - - # When - res = self.client.post( - url, - data=self._generate_json_trait_data(identifier=identifier), - content_type=self.JSON, - ) - - # Then - assert res.status_code == status.HTTP_200_OK - - # and - assert Identity.objects.filter( - identifier=identifier, environment=self.environment - ).exists() - - # and - assert Trait.objects.filter( - identity__identifier=identifier, trait_key=self.trait_key - ).exists() - - def test_trait_is_updated_if_already_exists(self): - # Given - url = reverse("api-v1:sdk-traits-list") - trait = Trait.objects.create( - trait_key=self.trait_key, - value_type=STRING, - string_value=self.trait_value, - identity=self.identity, - ) - new_value = "Some new value" - - # When - self.client.post( - url, - data=self._generate_json_trait_data(trait_value=new_value), - content_type=self.JSON, - ) - - # Then - trait.refresh_from_db() - assert trait.get_trait_value() == new_value - - def test_increment_value_increments_trait_value_if_value_positive_integer(self): - # Given - initial_value = 2 - increment_by = 2 - - url = reverse("api-v1:sdk-traits-increment-value") - trait = Trait.objects.create( - identity=self.identity, - trait_key=self.trait_key, - value_type=INTEGER, - integer_value=initial_value, - ) - data = { - "trait_key": self.trait_key, - "identifier": self.identity.identifier, - "increment_by": increment_by, - } - - # When - self.client.post(url, data=data) - - # Then - trait.refresh_from_db() - assert trait.get_trait_value() == initial_value + increment_by - - def test_increment_value_decrements_trait_value_if_value_negative_integer(self): - # Given - initial_value = 2 - increment_by = -2 - - url = reverse("api-v1:sdk-traits-increment-value") - trait = Trait.objects.create( - identity=self.identity, - trait_key=self.trait_key, - value_type=INTEGER, - integer_value=initial_value, - ) - data = { - "trait_key": self.trait_key, - "identifier": self.identity.identifier, - "increment_by": increment_by, - } - - # When - self.client.post(url, data=data) - - # Then - trait.refresh_from_db() - assert trait.get_trait_value() == initial_value + increment_by - - def test_increment_value_initialises_trait_with_a_value_of_zero_if_it_doesnt_exist( - self, - ): - # Given - increment_by = 1 - - url = reverse("api-v1:sdk-traits-increment-value") - data = { - "trait_key": self.trait_key, - "identifier": self.identity.identifier, - "increment_by": increment_by, - } - - # When - self.client.post(url, data=data) - - # Then - trait = Trait.objects.get(trait_key=self.trait_key, identity=self.identity) - assert trait.get_trait_value() == increment_by - - def test_increment_value_returns_400_if_trait_value_not_integer(self): - # Given - url = reverse("api-v1:sdk-traits-increment-value") - Trait.objects.create( - identity=self.identity, - trait_key=self.trait_key, - value_type=STRING, - string_value="str", - ) - data = { - "trait_key": self.trait_key, - "identifier": self.identity.identifier, - "increment_by": 2, - } - - # When - res = self.client.post(url, data=data) - - # Then - assert res.status_code == status.HTTP_400_BAD_REQUEST - - def test_set_trait_with_too_long_string_value_returns_400(self): - # Given - url = reverse("api-v1:sdk-traits-list") - trait_value = "t" * (TRAIT_STRING_VALUE_MAX_LENGTH + 1) - - # When - res = self.client.post( - url, - data=self._generate_json_trait_data(trait_value=trait_value), - content_type=self.JSON, - ) - - # Then - assert res.status_code == status.HTTP_400_BAD_REQUEST - assert ( - f"Value string is too long. Must be less than {TRAIT_STRING_VALUE_MAX_LENGTH} character" - == res.json()["trait_value"][0] - ) - - def test_can_set_trait_with_bad_value_for_an_identity(self): - # Given - url = reverse("api-v1:sdk-traits-list") - bad_trait_value = {"foo": "bar"} - - # When - res = self.client.post( - url, - data=self._generate_json_trait_data(trait_value=bad_trait_value), - content_type=self.JSON, - ) - - # Then - assert res.status_code == status.HTTP_200_OK - - # and - assert Trait.objects.get( - identity=self.identity, trait_key=self.trait_key - ).get_trait_value() == str(bad_trait_value) - - def test_bulk_create_traits(self): - # Given - num_traits = 20 - url = reverse("api-v1:sdk-traits-bulk-create") - traits = [ - self._generate_trait_data(trait_key=f"trait_{i}", identifier="user_{i}") - for i in range(num_traits) - ] - identifiers = [trait["identity"]["identifier"] for trait in traits] - - # When - response = self.client.put( - url, data=json.dumps(traits), content_type="application/json" - ) - - # Then - assert response.status_code == status.HTTP_200_OK - assert ( - Trait.objects.filter(identity__identifier__in=identifiers).count() - == num_traits - ) - - def test_bulk_create_traits_when_bad_trait_value_sent_then_trait_value_stringified( - self, - ): - # Given - num_traits = 5 - url = reverse("api-v1:sdk-traits-bulk-create") - traits = [ - self._generate_trait_data(trait_key=f"trait_{i}") for i in range(num_traits) - ] - - # add some bad data to test - bad_trait_key = "trait_999" - bad_trait_value = {"foo": "bar"} - traits.append( - { - "trait_value": bad_trait_value, - "trait_key": bad_trait_key, - "identity": {"identifier": self.identity.identifier}, - } - ) - - # When - response = self.client.put( - url, data=json.dumps(traits), content_type="application/json" - ) - - # Then - assert response.status_code == status.HTTP_200_OK - assert Trait.objects.filter(identity=self.identity).count() == num_traits + 1 - - # and - assert Trait.objects.get( - identity=self.identity, trait_key=bad_trait_key - ).get_trait_value() == str(bad_trait_value) - - def test_sending_null_value_in_bulk_create_deletes_trait_for_identity(self): - # Given - url = reverse("api-v1:sdk-traits-bulk-create") - trait_to_delete = Trait.objects.create( - trait_key=self.trait_key, - value_type=STRING, - string_value=self.trait_value, - identity=self.identity, - ) - trait_key_to_keep = "another_trait_key" - trait_to_keep = Trait.objects.create( - trait_key=trait_key_to_keep, - value_type=STRING, - string_value="value is irrelevant", - identity=self.identity, - ) - data = [ - { - "identity": {"identifier": self.identity.identifier}, - "trait_key": self.trait_key, - "trait_value": None, - } - ] - - # When - response = self.client.put( - url, data=json.dumps(data), content_type="application/json" - ) - - # Then - # the request is successful - assert response.status_code == status.HTTP_200_OK - - # and the trait is deleted - assert not Trait.objects.filter(id=trait_to_delete.id).exists() - - # but the trait missing from the request is left untouched - assert Trait.objects.filter(id=trait_to_keep.id).exists() - - def test_bulk_create_traits_when_float_value_sent_then_trait_value_correct(self): - # Given - url = reverse("api-v1:sdk-traits-bulk-create") - traits = [] - - # add float value trait - float_trait_key = "float_key_999" - float_trait_value = 45.88 - traits.append( - { - "trait_value": float_trait_value, - "trait_key": float_trait_key, - "identity": {"identifier": self.identity.identifier}, - } - ) - - # When - response = self.client.put( - url, data=json.dumps(traits), content_type="application/json" - ) - - # Then - assert response.status_code == status.HTTP_200_OK - assert Trait.objects.filter(identity=self.identity).count() == 1 - - # and - assert ( - Trait.objects.get( - identity=self.identity, trait_key=float_trait_key - ).get_trait_value() - == float_trait_value - ) - - @override_settings(EDGE_API_URL="http://localhost") - @mock.patch("environments.identities.traits.views.forward_trait_request") - def test_post_trait_calls_forward_trait_request_with_correct_arguments( - self, mocked_forward_trait_request - ): - # Given - url = reverse("api-v1:sdk-traits-list") - data = self._generate_json_trait_data() - - # When - self.client.post(url, data=data, content_type=self.JSON) - - # Then - args, kwargs = mocked_forward_trait_request.delay.call_args_list[0] - assert args == () - assert kwargs["args"][0] == "POST" - assert kwargs["args"][1].get("X-Environment-Key") == self.environment.api_key - assert kwargs["args"][2] == self.environment.project.id - assert kwargs["args"][3] == json.loads(data) - - @override_settings(EDGE_API_URL="http://localhost") - @mock.patch("environments.identities.traits.views.forward_trait_request") - def test_increment_value_calls_forward_trait_request_with_correct_arguments( - self, mocked_forward_trait_request - ): - # Given - url = reverse("api-v1:sdk-traits-increment-value") - data = { - "trait_key": self.trait_key, - "identifier": self.identity.identifier, - "increment_by": 1, - } - - # When - self.client.post(url, data=data) - - # Then - args, kwargs = mocked_forward_trait_request.delay.call_args_list[0] - assert args == () - assert kwargs["args"][0] == "POST" - assert kwargs["args"][1].get("X-Environment-Key") == self.environment.api_key - assert kwargs["args"][2] == self.environment.project.id - - # and the structure of payload was correct - assert kwargs["args"][3]["identity"]["identifier"] == data["identifier"] - assert kwargs["args"][3]["trait_key"] == data["trait_key"] - assert kwargs["args"][3]["trait_value"] - - @override_settings(EDGE_API_URL="http://localhost") - @mock.patch("environments.identities.traits.views.forward_trait_requests") - def test_bulk_create_traits_calls_forward_trait_request_with_correct_arguments( - self, mocked_forward_trait_requests - ): - # Given - url = reverse("api-v1:sdk-traits-bulk-create") - request_data = [ - { - "identity": {"identifier": "test_user_123"}, - "trait_key": "key", - "trait_value": "value", - }, - { - "identity": {"identifier": "test_user_123"}, - "trait_key": "key1", - "trait_value": "value1", - }, - ] - - # When - self.client.put( - url, data=json.dumps(request_data), content_type="application/json" - ) - - # Then - - # Then - args, kwargs = mocked_forward_trait_requests.delay.call_args_list[0] - assert args == () - assert kwargs["args"][0] == "PUT" - assert kwargs["args"][1].get("X-Environment-Key") == self.environment.api_key - assert kwargs["args"][2] == self.environment.project.id - assert kwargs["args"][3] == request_data - - def test_create_trait_returns_403_if_client_cannot_set_traits(self): - # Given - url = reverse("api-v1:sdk-traits-list") - data = { - "identity": {"identifier": self.identity.identifier}, - "trait_key": "foo", - "trait_value": "bar", - } - - self.environment.allow_client_traits = False - self.environment.save() - - # When - response = self.client.post( - url, data=json.dumps(data), content_type="application/json" - ) - - # Then - assert response.status_code == status.HTTP_400_BAD_REQUEST - - def test_server_key_can_create_trait_if_not_allow_client_traits(self): - # Given - url = reverse("api-v1:sdk-traits-list") - data = { - "identity": {"identifier": self.identity.identifier}, - "trait_key": "foo", - "trait_value": "bar", - } - - server_api_key = EnvironmentAPIKey.objects.create(environment=self.environment) - self.client.credentials(HTTP_X_ENVIRONMENT_KEY=server_api_key.key) - - self.environment.allow_client_traits = False - self.environment.save() - - # When - response = self.client.post( - url, data=json.dumps(data), content_type="application/json" - ) - - # Then - assert response.status_code == status.HTTP_200_OK - - def test_bulk_create_traits_returns_403_if_client_cannot_set_traits(self): - # Given - url = reverse("api-v1:sdk-traits-bulk-create") - data = [ - { - "identity": {"identifier": self.identity.identifier}, - "trait_key": "foo", - "trait_value": "bar", - } - ] - - self.environment.allow_client_traits = False - self.environment.save() - - # When - response = self.client.put( - url, data=json.dumps(data), content_type="application/json" - ) - - # Then - assert response.status_code == status.HTTP_400_BAD_REQUEST - - def test_server_key_can_bulk_create_traits_if_not_allow_client_traits(self): - # Given - url = reverse("api-v1:sdk-traits-bulk-create") - data = [ - { - "identity": {"identifier": self.identity.identifier}, - "trait_key": "foo", - "trait_value": "bar", - } - ] - - server_api_key = EnvironmentAPIKey.objects.create(environment=self.environment) - self.client.credentials(HTTP_X_ENVIRONMENT_KEY=server_api_key.key) - - self.environment.allow_client_traits = False - self.environment.save() - - # When - response = self.client.put( - url, data=json.dumps(data), content_type="application/json" - ) - - # Then - assert response.status_code == status.HTTP_200_OK - - def _generate_trait_data(self, identifier=None, trait_key=None, trait_value=None): - identifier = identifier or self.identity.identifier - trait_key = trait_key or self.trait_key - trait_value = trait_value or self.trait_value - - return { - "identity": {"identifier": identifier}, - "trait_key": trait_key, - "trait_value": trait_value, - } - - def _generate_json_trait_data( - self, identifier=None, trait_key=None, trait_value=None - ): - return json.dumps(self._generate_trait_data(identifier, trait_key, trait_value)) - - -@pytest.mark.django_db -class TraitViewSetTestCase(TestCase): - def setUp(self) -> None: - self.client = APIClient() - user = Helper.create_ffadminuser() - self.client.force_authenticate(user=user) - - organisation = Organisation.objects.create(name="Test org") - user.add_organisation(organisation, OrganisationRole.ADMIN) - - self.project = Project.objects.create( - name="Test project", organisation=organisation - ) - self.environment = Environment.objects.create( - name="Test environment", project=self.project - ) - self.identity = Identity.objects.create( - identifier="test-user", environment=self.environment - ) - - def test_delete_trait_only_deletes_single_trait_if_query_param_not_provided(self): - # Given - trait_key = "trait_key" - trait_value = "trait_value" - identity_2 = Identity.objects.create( - identifier="test-user-2", environment=self.environment - ) - - trait = Trait.objects.create( - identity=self.identity, - trait_key=trait_key, - value_type=STRING, - string_value=trait_value, - ) - trait_2 = Trait.objects.create( - identity=identity_2, - trait_key=trait_key, - value_type=STRING, - string_value=trait_value, - ) - - url = reverse( - "api-v1:environments:identities-traits-detail", - args=[self.environment.api_key, self.identity.id, trait.id], - ) - - # When - self.client.delete(url) - - # Then - assert not Trait.objects.filter(pk=trait.id).exists() - - # and - assert Trait.objects.filter(pk=trait_2.id).exists() - - def test_delete_trait_deletes_all_traits_if_query_param_provided(self): - # Given - trait_key = "trait_key" - trait_value = "trait_value" - identity_2 = Identity.objects.create( - identifier="test-user-2", environment=self.environment - ) - - trait = Trait.objects.create( - identity=self.identity, - trait_key=trait_key, - value_type=STRING, - string_value=trait_value, - ) - trait_2 = Trait.objects.create( - identity=identity_2, - trait_key=trait_key, - value_type=STRING, - string_value=trait_value, - ) - - base_url = reverse( - "api-v1:environments:identities-traits-detail", - args=[self.environment.api_key, self.identity.id, trait.id], - ) - url = base_url + "?deleteAllMatchingTraits=true" - - # When - self.client.delete(url) - - # Then - assert not Trait.objects.filter(pk=trait.id).exists() - - # and - assert not Trait.objects.filter(pk=trait_2.id).exists() - - def test_delete_trait_only_deletes_traits_in_current_environment(self): - # Given - environment_2 = Environment.objects.create( - name="Test environment", project=self.project - ) - trait_key = "trait_key" - trait_value = "trait_value" - identity_2 = Identity.objects.create( - identifier="test-user-2", environment=environment_2 - ) - - trait = Trait.objects.create( - identity=self.identity, - trait_key=trait_key, - value_type=STRING, - string_value=trait_value, - ) - trait_2 = Trait.objects.create( - identity=identity_2, - trait_key=trait_key, - value_type=STRING, - string_value=trait_value, - ) - - base_url = reverse( - "api-v1:environments:identities-traits-detail", - args=[self.environment.api_key, self.identity.id, trait.id], - ) - url = base_url + "?deleteAllMatchingTraits=true" - - # When - self.client.delete(url) - - # Then - assert not Trait.objects.filter(pk=trait.id).exists() - - # and - assert Trait.objects.filter(pk=trait_2.id).exists() - - -def test_set_trait_for_an_identity_is_not_throttled_by_user_throttle( - settings, identity, environment, api_client -): - # Given - settings.REST_FRAMEWORK = {"DEFAULT_THROTTLE_RATES": {"user": "1/minute"}} - - api_client.credentials(HTTP_X_ENVIRONMENT_KEY=environment.api_key) - - url = reverse("api-v1:sdk-traits-list") - data = { - "identity": {"identifier": identity.identifier}, - "trait_key": "key", - "trait_value": "value", - } - - # When - for _ in range(10): - res = api_client.post( - url, data=json.dumps(data), content_type="application/json" - ) - - # Then - assert res.status_code == status.HTTP_200_OK diff --git a/api/environments/permissions/tests/__init__.py b/api/environments/permissions/tests/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/api/environments/tests/__init__.py b/api/environments/tests/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/api/environments/tests/test_models.py b/api/environments/tests/test_models.py deleted file mode 100644 index b288174e7e33..000000000000 --- a/api/environments/tests/test_models.py +++ /dev/null @@ -1,383 +0,0 @@ -from copy import copy -from datetime import timedelta -from unittest import mock - -import pytest -from core.constants import STRING -from django.test import TestCase, override_settings -from django.utils import timezone - -from audit.models import AuditLog -from audit.related_object_type import RelatedObjectType -from environments.identities.models import Identity -from environments.models import ( - Environment, - EnvironmentAPIKey, - environment_cache, -) -from features.feature_types import MULTIVARIATE -from features.models import Feature, FeatureState -from features.multivariate.models import MultivariateFeatureOption -from organisations.models import Organisation -from projects.models import Project - - -@pytest.mark.django_db -class EnvironmentTestCase(TestCase): - def setUp(self): - self.organisation = Organisation.objects.create(name="Test Org") - self.project = Project.objects.create( - name="Test Project", organisation=self.organisation - ) - self.feature = Feature.objects.create(name="Test Feature", project=self.project) - # The environment is initialised in a non-saved state as we want to test the save - # functionality. - self.environment = Environment(name="Test Environment", project=self.project) - - def test_environment_should_be_created_with_feature_states(self): - # Given - set up data - - # When - self.environment.save() - - # Then - feature_states = FeatureState.objects.filter(environment=self.environment) - assert hasattr(self.environment, "api_key") - assert feature_states.count() == 1 - - def test_on_creation_save_feature_states_get_created(self): - # These should be no feature states before saving - self.assertEqual(FeatureState.objects.count(), 0) - - self.environment.save() - - # On the first save a new feature state should be created - self.assertEqual(FeatureState.objects.count(), 1) - - def test_on_update_save_feature_states_get_updated_not_created(self): - self.environment.save() - - self.feature.default_enabled = True - self.feature.save() - self.environment.save() - - self.assertEqual(FeatureState.objects.count(), 1) - - def test_on_creation_save_feature_is_created_with_the_correct_default(self): - self.environment.save() - self.assertFalse(FeatureState.objects.get().enabled) - - def test_clone_does_not_modify_the_original_instance(self): - # Given - self.environment.save() - - # When - clone = self.environment.clone(name="Cloned env") - - # Then - self.assertNotEqual(clone.name, self.environment.name) - self.assertNotEqual(clone.api_key, self.environment.api_key) - - def test_clone_save_creates_feature_states(self): - # Given - self.environment.save() - - # When - clone = self.environment.clone(name="Cloned env") - - # Then - feature_states = FeatureState.objects.filter(environment=clone) - assert feature_states.count() == 1 - - def test_clone_does_not_modify_source_feature_state(self): - # Given - self.environment.save() - source_feature_state_before_clone = FeatureState.objects.filter( - environment=self.environment - ).first() - - # When - self.environment.clone(name="Cloned env") - source_feature_state_after_clone = FeatureState.objects.filter( - environment=self.environment - ).first() - - # Then - assert source_feature_state_before_clone == source_feature_state_after_clone - - def test_clone_does_not_create_identity(self): - # Given - self.environment.save() - Identity.objects.create( - environment=self.environment, identifier="test_identity" - ) - # When - clone = self.environment.clone(name="Cloned env") - - # Then - assert clone.identities.count() == 0 - - def test_clone_clones_the_feature_states(self): - # Given - self.environment.save() - - # Enable the feature in the source environment - self.environment.feature_states.update(enabled=True) - - # When - clone = self.environment.clone(name="Cloned env") - - # Then - assert clone.feature_states.first().enabled is True - - def test_clone_clones_multivariate_feature_state_values(self): - # Given - self.environment.save() - - mv_feature = Feature.objects.create( - type=MULTIVARIATE, - name="mv_feature", - initial_value="foo", - project=self.project, - ) - variant_1 = MultivariateFeatureOption.objects.create( - feature=mv_feature, - default_percentage_allocation=10, - type=STRING, - string_value="bar", - ) - - # When - clone = self.environment.clone(name="Cloned env") - - # Then - cloned_mv_feature_state = clone.feature_states.get(feature=mv_feature) - assert cloned_mv_feature_state.multivariate_feature_state_values.count() == 1 - - original_mv_fs_value = FeatureState.objects.get( - environment=self.environment, feature=mv_feature - ).multivariate_feature_state_values.first() - cloned_mv_fs_value = ( - cloned_mv_feature_state.multivariate_feature_state_values.first() - ) - - assert original_mv_fs_value != cloned_mv_fs_value - assert ( - original_mv_fs_value.multivariate_feature_option - == cloned_mv_fs_value.multivariate_feature_option - == variant_1 - ) - assert ( - original_mv_fs_value.percentage_allocation - == cloned_mv_fs_value.percentage_allocation - == 10 - ) - - @mock.patch("environments.models.environment_cache") - def test_get_from_cache_stores_environment_in_cache_on_success(self, mock_cache): - # Given - self.environment.save() - mock_cache.get.return_value = None - - # When - environment = Environment.get_from_cache(self.environment.api_key) - - # Then - assert environment == self.environment - mock_cache.set.assert_called_with( - self.environment.api_key, self.environment, timeout=60 - ) - - def test_get_from_cache_returns_None_if_no_matching_environment(self): - # Given - api_key = "no-matching-env" - - # When - env = Environment.get_from_cache(api_key) - - # Then - assert env is None - - def test_get_from_cache_accepts_environment_api_key_model_key(self): - # Given - self.environment.save() - api_key = EnvironmentAPIKey.objects.create( - name="Some key", environment=self.environment - ) - - # When - environment_from_cache = Environment.get_from_cache(api_key=api_key.key) - - # Then - assert environment_from_cache == self.environment - - def test_get_from_cache_with_null_environment_key_returns_null(self): - # Given - self.environment.save() - - # When - environment = Environment.get_from_cache(None) - - # Then - assert environment is None - - @override_settings( - CACHE_BAD_ENVIRONMENTS_SECONDS=60, CACHE_BAD_ENVIRONMENTS_AFTER_FAILURES=1 - ) - def test_get_from_cache_does_not_hit_database_if_api_key_in_bad_env_cache(self): - # Given - api_key = "bad-key" - - # When - with self.assertNumQueries(1): - [Environment.get_from_cache(api_key) for _ in range(10)] - - -def test_environment_api_key_model_is_valid_is_true_for_non_expired_active_key( - environment, -): - assert ( - EnvironmentAPIKey.objects.create( - environment=environment, - key="ser.random_key", - name="test_key", - ).is_valid - is True - ) - - -def test_environment_api_key_model_is_valid_is_true_for_non_expired_active_key_with_expired_date_in_future( - environment, -): - assert ( - EnvironmentAPIKey.objects.create( - environment=environment, - key="ser.random_key", - name="test_key", - expires_at=timezone.now() + timedelta(days=5), - ).is_valid - is True - ) - - -def test_environment_api_key_model_is_valid_is_false_for_expired_active_key( - environment, -): - assert ( - EnvironmentAPIKey.objects.create( - environment=environment, - key="ser.random_key", - name="test_key", - expires_at=timezone.now() - timedelta(seconds=1), - ).is_valid - is False - ) - - -def test_environment_api_key_model_is_valid_is_false_for_non_expired_inactive_key( - environment, -): - assert ( - EnvironmentAPIKey.objects.create( - environment=environment, key="ser.random_key", name="test_key", active=False - ).is_valid - is False - ) - - -def test_existence_of_multiple_environment_api_keys_does_not_break_get_from_cache( - environment, -): - # Given - environment_api_keys = [ - EnvironmentAPIKey.objects.create(environment=environment, name=f"test_key_{i}") - for i in range(2) - ] - - # When - retrieved_environments = [ - Environment.get_from_cache(environment.api_key), - *[ - Environment.get_from_cache(environment_api_key.key) - for environment_api_key in environment_api_keys - ], - ] - - # Then - assert all( - retrieved_environment == environment - for retrieved_environment in retrieved_environments - ) - - -def test_get_from_cache_sets_the_cache_correctly_with_environment_api_key( - environment, environment_api_key, mocker -): - # When - returned_environment = Environment.get_from_cache(environment_api_key.key) - - # Then - assert returned_environment == environment - - # and - assert environment == environment_cache.get(environment_api_key.key) - - -def test_updated_at_gets_updated_when_environment_audit_log_created(environment): - # When - audit_log = AuditLog.objects.create( - environment=environment, project=environment.project, log="random_audit_log" - ) - - # Then - environment.refresh_from_db() - assert environment.updated_at == audit_log.created_date - - -def test_updated_at_gets_updated_when_project_audit_log_created(environment): - # When - audit_log = AuditLog.objects.create( - project=environment.project, log="random_audit_log" - ) - environment.refresh_from_db() - # Then - assert environment.updated_at == audit_log.created_date - - -def test_change_request_audit_logs_does_not_update_updated_at(environment): - # Given - updated_at_before_audit_log = environment.updated_at - - # When - audit_log = AuditLog.objects.create( - environment=environment, - log="random_test", - related_object_type=RelatedObjectType.CHANGE_REQUEST.name, - ) - - # Then - assert environment.updated_at == updated_at_before_audit_log - assert environment.updated_at != audit_log.created_date - - -def test_save_environment_clears_environment_cache(mocker, project): - # Given - mock_environment_cache = mocker.patch("environments.models.environment_cache") - environment = Environment.objects.create(name="test environment", project=project) - - # perform an update of the name to verify basic functionality - environment.name = "updated" - environment.save() - - # and update the api key to verify that the original api key is used to clear cache - old_key = copy(environment.api_key) - new_key = "some-new-key" - environment.api_key = new_key - - # When - environment.save() - - # Then - mock_calls = mock_environment_cache.delete.mock_calls - assert len(mock_calls) == 2 - assert mock_calls[0][1][0] == mock_calls[1][1][0] == old_key diff --git a/api/features/feature_segments/tests/__init__.py b/api/features/feature_segments/tests/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/api/environments/dynamodb/tests/test_dynamo_environment_wrapper.py b/api/tests/unit/environments/dynamodb/test_unit_dynamo_environment_wrapper.py similarity index 100% rename from api/environments/dynamodb/tests/test_dynamo_environment_wrapper.py rename to api/tests/unit/environments/dynamodb/test_unit_dynamo_environment_wrapper.py diff --git a/api/environments/dynamodb/tests/test_dynamodb_environment_api_key_wrapper.py b/api/tests/unit/environments/dynamodb/test_unit_dynamodb_environment_api_key_wrapper.py similarity index 100% rename from api/environments/dynamodb/tests/test_dynamodb_environment_api_key_wrapper.py rename to api/tests/unit/environments/dynamodb/test_unit_dynamodb_environment_api_key_wrapper.py diff --git a/api/environments/dynamodb/tests/test_dynamodb_identity_wrapper.py b/api/tests/unit/environments/dynamodb/test_unit_dynamodb_identity_wrapper.py similarity index 100% rename from api/environments/dynamodb/tests/test_dynamodb_identity_wrapper.py rename to api/tests/unit/environments/dynamodb/test_unit_dynamodb_identity_wrapper.py diff --git a/api/environments/dynamodb/tests/test_migrator.py b/api/tests/unit/environments/dynamodb/test_unit_migrator.py similarity index 100% rename from api/environments/dynamodb/tests/test_migrator.py rename to api/tests/unit/environments/dynamodb/test_unit_migrator.py diff --git a/api/environments/dynamodb/tests/types/test_dynamodb_project_metadata.py b/api/tests/unit/environments/dynamodb/types/test_unit_dynamodb_project_metadata.py similarity index 100% rename from api/environments/dynamodb/tests/types/test_dynamodb_project_metadata.py rename to api/tests/unit/environments/dynamodb/types/test_unit_dynamodb_project_metadata.py diff --git a/api/environments/identities/tests/helpers.py b/api/tests/unit/environments/identities/helpers.py similarity index 100% rename from api/environments/identities/tests/helpers.py rename to api/tests/unit/environments/identities/helpers.py diff --git a/api/tests/unit/environments/identities/test_identities_models.py b/api/tests/unit/environments/identities/test_identities_models.py deleted file mode 100644 index d6bd5da87299..000000000000 --- a/api/tests/unit/environments/identities/test_identities_models.py +++ /dev/null @@ -1,89 +0,0 @@ -import typing - -from django.utils import timezone - -from environments.identities.models import Identity -from features.models import Feature, FeatureState - -if typing.TYPE_CHECKING: - from environments.models import Environment - - -def test_identity_get_all_feature_states_gets_latest_committed_version(environment): - # Given - identity = Identity.objects.create(identifier="identity", environment=environment) - - feature = Feature.objects.create( - name="versioned_feature", - project=environment.project, - default_enabled=False, - initial_value="v1", - ) - - now = timezone.now() - - # creating the feature above will have created a feature state with version=1, - # now we create 2 more versions... - # one of which is live... - feature_state_v2 = FeatureState.objects.create( - feature=feature, - version=2, - live_from=now, - enabled=True, - environment=environment, - ) - feature_state_v2.feature_state_value.string_value = "v2" - feature_state_v2.feature_state_value.save() - - # and one which isn't - not_live_feature_state = FeatureState.objects.create( - feature=feature, - version=None, - live_from=None, - enabled=False, - environment=environment, - ) - not_live_feature_state.feature_state_value.string_value = "v3" - not_live_feature_state.feature_state_value.save() - - # When - identity_feature_states = identity.get_all_feature_states() - - # Then - identity_feature_state = next( - filter(lambda fs: fs.feature == feature, identity_feature_states) - ) - assert identity_feature_state.get_feature_state_value() == "v2" - - -def test_get_hash_key_with_use_identity_composite_key_for_hashing_enabled( - identity: Identity, -): - assert ( - identity.get_hash_key(use_identity_composite_key_for_hashing=True) - == f"{identity.environment.api_key}_{identity.identifier}" - ) - - -def test_get_hash_key_with_use_identity_composite_key_for_hashing_disabled( - identity: Identity, -): - assert identity.get_hash_key(use_identity_composite_key_for_hashing=False) == str( - identity.id - ) - - -def test_identity_get_all_feature_states__returns_identity_override__when_v2_feature_versioning_enabled( - identity: Identity, environment_v2_versioning: "Environment", feature: Feature -): - # Given - identity_override = FeatureState.objects.create( - environment=environment_v2_versioning, identity=identity, feature=feature - ) - - # When - all_feature_states = identity.get_all_feature_states() - - # Then - assert len(all_feature_states) == 1 - assert all_feature_states[0] == identity_override diff --git a/api/tests/unit/environments/identities/test_identities_views.py b/api/tests/unit/environments/identities/test_identities_views.py deleted file mode 100644 index 3040af294297..000000000000 --- a/api/tests/unit/environments/identities/test_identities_views.py +++ /dev/null @@ -1,87 +0,0 @@ -from django.urls import reverse -from rest_framework import status -from rest_framework.permissions import IsAuthenticated - -from environments.identities.views import IdentityViewSet -from environments.permissions.constants import ( - MANAGE_IDENTITIES, - VIEW_IDENTITIES, -) -from environments.permissions.permissions import NestedEnvironmentPermissions - - -def test_user_with_view_identities_permission_can_retrieve_identity( - environment, - identity, - test_user_client, - view_environment_permission, - view_identities_permission, - view_project_permission, - user_environment_permission, - user_project_permission, -): - # Given - - user_environment_permission.permissions.add( - view_environment_permission, view_identities_permission - ) - user_project_permission.permissions.add(view_project_permission) - - url = reverse( - "api-v1:environments:environment-identities-detail", - args=(environment.api_key, identity.id), - ) - - # When - response = test_user_client.get(url) - - # Then - assert response.status_code == status.HTTP_200_OK - - -def test_user_with_view_environment_permission_can_not_list_identities( - environment, - identity, - test_user_client, - view_environment_permission, - manage_identities_permission, - view_project_permission, - user_environment_permission, - user_project_permission, -): - # Given - - user_environment_permission.permissions.add(view_environment_permission) - user_project_permission.permissions.add(view_project_permission) - - url = reverse( - "api-v1:environments:environment-identities-list", - args=(environment.api_key,), - ) - - # When - response = test_user_client.get(url) - - # Then - assert response.status_code == status.HTTP_403_FORBIDDEN - - -def test_identity_view_set_get_permissions(): - # Given - view_set = IdentityViewSet() - - # When - permissions = view_set.get_permissions() - - # Then - assert isinstance(permissions[0], IsAuthenticated) - assert isinstance(permissions[1], NestedEnvironmentPermissions) - - assert permissions[1].action_permission_map == { - "list": VIEW_IDENTITIES, - "retrieve": VIEW_IDENTITIES, - "create": MANAGE_IDENTITIES, - "update": MANAGE_IDENTITIES, - "partial_update": MANAGE_IDENTITIES, - "destroy": MANAGE_IDENTITIES, - } diff --git a/api/tests/unit/environments/identities/test_identities_feature_states_views.py b/api/tests/unit/environments/identities/test_unit_identities_feature_states_views.py similarity index 100% rename from api/tests/unit/environments/identities/test_identities_feature_states_views.py rename to api/tests/unit/environments/identities/test_unit_identities_feature_states_views.py diff --git a/api/environments/identities/tests/test_helpers.py b/api/tests/unit/environments/identities/test_unit_identities_helpers.py similarity index 100% rename from api/environments/identities/tests/test_helpers.py rename to api/tests/unit/environments/identities/test_unit_identities_helpers.py diff --git a/api/environments/identities/tests/test_models.py b/api/tests/unit/environments/identities/test_unit_identities_models.py similarity index 93% rename from api/environments/identities/tests/test_models.py rename to api/tests/unit/environments/identities/test_unit_identities_models.py index 9e0691b22f12..dab2f7bf38f0 100644 --- a/api/environments/identities/tests/test_models.py +++ b/api/tests/unit/environments/identities/test_unit_identities_models.py @@ -929,3 +929,83 @@ def test_get_all_feature_hide_disabled_flags( # Then assert bool(identity_flags) == disabled_flag_returned + + +def test_identity_get_all_feature_states_gets_latest_committed_version(environment): + # Given + identity = Identity.objects.create(identifier="identity", environment=environment) + + feature = Feature.objects.create( + name="versioned_feature", + project=environment.project, + default_enabled=False, + initial_value="v1", + ) + + now = timezone.now() + + # creating the feature above will have created a feature state with version=1, + # now we create 2 more versions... + # one of which is live... + feature_state_v2 = FeatureState.objects.create( + feature=feature, + version=2, + live_from=now, + enabled=True, + environment=environment, + ) + feature_state_v2.feature_state_value.string_value = "v2" + feature_state_v2.feature_state_value.save() + + # and one which isn't + not_live_feature_state = FeatureState.objects.create( + feature=feature, + version=None, + live_from=None, + enabled=False, + environment=environment, + ) + not_live_feature_state.feature_state_value.string_value = "v3" + not_live_feature_state.feature_state_value.save() + + # When + identity_feature_states = identity.get_all_feature_states() + + # Then + identity_feature_state = next( + filter(lambda fs: fs.feature == feature, identity_feature_states) + ) + assert identity_feature_state.get_feature_state_value() == "v2" + + +def test_get_hash_key_with_use_identity_composite_key_for_hashing_enabled( + identity: Identity, +): + assert ( + identity.get_hash_key(use_identity_composite_key_for_hashing=True) + == f"{identity.environment.api_key}_{identity.identifier}" + ) + + +def test_get_hash_key_with_use_identity_composite_key_for_hashing_disabled( + identity: Identity, +): + assert identity.get_hash_key(use_identity_composite_key_for_hashing=False) == str( + identity.id + ) + + +def test_identity_get_all_feature_states__returns_identity_override__when_v2_feature_versioning_enabled( + identity: Identity, environment_v2_versioning: "Environment", feature: Feature +): + # Given + identity_override = FeatureState.objects.create( + environment=environment_v2_versioning, identity=identity, feature=feature + ) + + # When + all_feature_states = identity.get_all_feature_states() + + # Then + assert len(all_feature_states) == 1 + assert all_feature_states[0] == identity_override diff --git a/api/environments/identities/tests/test_views.py b/api/tests/unit/environments/identities/test_unit_identities_views.py similarity index 93% rename from api/environments/identities/tests/test_views.py rename to api/tests/unit/environments/identities/test_unit_identities_views.py index bbec39ed38b4..75e39c4c8706 100644 --- a/api/environments/identities/tests/test_views.py +++ b/api/tests/unit/environments/identities/test_unit_identities_views.py @@ -10,6 +10,7 @@ from django.utils import timezone from flag_engine.segments.constants import PERCENTAGE_SPLIT from rest_framework import status +from rest_framework.permissions import IsAuthenticated from rest_framework.test import APIClient, APITestCase from environments.identities.helpers import ( @@ -17,7 +18,13 @@ ) from environments.identities.models import Identity from environments.identities.traits.models import Trait +from environments.identities.views import IdentityViewSet from environments.models import Environment, EnvironmentAPIKey +from environments.permissions.constants import ( + MANAGE_IDENTITIES, + VIEW_IDENTITIES, +) +from environments.permissions.permissions import NestedEnvironmentPermissions from features.models import Feature, FeatureSegment, FeatureState from integrations.amplitude.models import AmplitudeConfiguration from organisations.models import Organisation, OrganisationRole @@ -1037,3 +1044,80 @@ def test_post_identities__server_key_only_feature__server_key_auth__return_expec # Then assert response.status_code == status.HTTP_200_OK assert response.json()["flags"] + + +def test_user_with_view_identities_permission_can_retrieve_identity( + environment, + identity, + test_user_client, + view_environment_permission, + view_identities_permission, + view_project_permission, + user_environment_permission, + user_project_permission, +): + # Given + + user_environment_permission.permissions.add( + view_environment_permission, view_identities_permission + ) + user_project_permission.permissions.add(view_project_permission) + + url = reverse( + "api-v1:environments:environment-identities-detail", + args=(environment.api_key, identity.id), + ) + + # When + response = test_user_client.get(url) + + # Then + assert response.status_code == status.HTTP_200_OK + + +def test_user_with_view_environment_permission_can_not_list_identities( + environment, + identity, + test_user_client, + view_environment_permission, + manage_identities_permission, + view_project_permission, + user_environment_permission, + user_project_permission, +): + # Given + + user_environment_permission.permissions.add(view_environment_permission) + user_project_permission.permissions.add(view_project_permission) + + url = reverse( + "api-v1:environments:environment-identities-list", + args=(environment.api_key,), + ) + + # When + response = test_user_client.get(url) + + # Then + assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_identity_view_set_get_permissions(): + # Given + view_set = IdentityViewSet() + + # When + permissions = view_set.get_permissions() + + # Then + assert isinstance(permissions[0], IsAuthenticated) + assert isinstance(permissions[1], NestedEnvironmentPermissions) + + assert permissions[1].action_permission_map == { + "list": VIEW_IDENTITIES, + "retrieve": VIEW_IDENTITIES, + "create": MANAGE_IDENTITIES, + "update": MANAGE_IDENTITIES, + "partial_update": MANAGE_IDENTITIES, + "destroy": MANAGE_IDENTITIES, + } diff --git a/api/tests/unit/environments/identities/traits/test_traits_views.py b/api/tests/unit/environments/identities/traits/test_traits_views.py index 208b78d7b33f..0d2d6b51876f 100644 --- a/api/tests/unit/environments/identities/traits/test_traits_views.py +++ b/api/tests/unit/environments/identities/traits/test_traits_views.py @@ -1,9 +1,22 @@ +import json +from unittest import mock +from unittest.case import TestCase + +import pytest +from core.constants import INTEGER, STRING +from django.test import override_settings from django.urls import reverse from rest_framework import status from rest_framework.permissions import IsAuthenticated +from rest_framework.test import APIClient, APITestCase +from environments.identities.models import Identity +from environments.identities.traits.constants import ( + TRAIT_STRING_VALUE_MAX_LENGTH, +) from environments.identities.traits.models import Trait from environments.identities.traits.views import TraitViewSet +from environments.models import Environment, EnvironmentAPIKey from environments.permissions.constants import ( MANAGE_IDENTITIES, VIEW_ENVIRONMENT, @@ -11,9 +24,793 @@ ) from environments.permissions.models import UserEnvironmentPermission from environments.permissions.permissions import NestedEnvironmentPermissions +from organisations.models import Organisation, OrganisationRole from permissions.models import PermissionModel -from projects.models import UserProjectPermission +from projects.models import Project, UserProjectPermission from projects.permissions import VIEW_PROJECT +from util.tests import Helper + + +class SDKTraitsTest(APITestCase): + JSON = "application/json" + + def setUp(self) -> None: + self.organisation = Organisation.objects.create(name="Test organisation") + project = Project.objects.create( + name="Test project", organisation=self.organisation, enable_dynamo_db=True + ) + self.environment = Environment.objects.create( + name="Test environment", project=project + ) + self.identity = Identity.objects.create( + identifier="test-user", environment=self.environment + ) + self.client.credentials(HTTP_X_ENVIRONMENT_KEY=self.environment.api_key) + self.trait_key = "trait_key" + self.trait_value = "trait_value" + + def test_can_set_trait_for_an_identity(self): + # Given + url = reverse("api-v1:sdk-traits-list") + + # When + res = self.client.post( + url, data=self._generate_json_trait_data(), content_type=self.JSON + ) + + # Then + assert res.status_code == status.HTTP_200_OK + + # and + assert Trait.objects.filter( + identity=self.identity, trait_key=self.trait_key + ).exists() + + def test_cannot_set_trait_for_an_identity_for_organisations_without_persistence( + self, + ): + # Given + url = reverse("api-v1:sdk-traits-list") + + # an organisation that is configured to not store traits + self.organisation.persist_trait_data = False + self.organisation.save() + + # When + response = self.client.post( + url, data=self._generate_json_trait_data(), content_type=self.JSON + ) + + # Then + # the request fails + assert response.status_code == status.HTTP_403_FORBIDDEN + response_json = response.json() + assert response_json["detail"] == ( + "Organisation is not authorised to store traits." + ) + + # and no traits are stored + assert Trait.objects.count() == 0 + + def test_can_set_trait_with_boolean_value_for_an_identity(self): + # Given + url = reverse("api-v1:sdk-traits-list") + trait_value = True + + # When + res = self.client.post( + url, + data=self._generate_json_trait_data(trait_value=trait_value), + content_type=self.JSON, + ) + + # Then + assert res.status_code == status.HTTP_200_OK + + # and + assert ( + Trait.objects.get( + identity=self.identity, trait_key=self.trait_key + ).get_trait_value() + == trait_value + ) + + def test_can_set_trait_with_identity_value_for_an_identity(self): + # Given + url = reverse("api-v1:sdk-traits-list") + trait_value = 12 + + # When + res = self.client.post( + url, + data=self._generate_json_trait_data(trait_value=trait_value), + content_type=self.JSON, + ) + + # Then + assert res.status_code == status.HTTP_200_OK + + # and + assert ( + Trait.objects.get( + identity=self.identity, trait_key=self.trait_key + ).get_trait_value() + == trait_value + ) + + def test_can_set_trait_with_float_value_for_an_identity(self): + # Given + url = reverse("api-v1:sdk-traits-list") + float_trait_value = 10.5 + + # When + res = self.client.post( + url, + data=self._generate_json_trait_data(trait_value=float_trait_value), + content_type=self.JSON, + ) + + # Then + assert res.status_code == status.HTTP_200_OK + + # and + assert ( + Trait.objects.get( + identity=self.identity, trait_key=self.trait_key + ).get_trait_value() + == float_trait_value + ) + + def test_add_trait_creates_identity_if_it_doesnt_exist(self): + # Given + url = reverse("api-v1:sdk-traits-list") + identifier = "new-identity" + + # When + res = self.client.post( + url, + data=self._generate_json_trait_data(identifier=identifier), + content_type=self.JSON, + ) + + # Then + assert res.status_code == status.HTTP_200_OK + + # and + assert Identity.objects.filter( + identifier=identifier, environment=self.environment + ).exists() + + # and + assert Trait.objects.filter( + identity__identifier=identifier, trait_key=self.trait_key + ).exists() + + def test_trait_is_updated_if_already_exists(self): + # Given + url = reverse("api-v1:sdk-traits-list") + trait = Trait.objects.create( + trait_key=self.trait_key, + value_type=STRING, + string_value=self.trait_value, + identity=self.identity, + ) + new_value = "Some new value" + + # When + self.client.post( + url, + data=self._generate_json_trait_data(trait_value=new_value), + content_type=self.JSON, + ) + + # Then + trait.refresh_from_db() + assert trait.get_trait_value() == new_value + + def test_increment_value_increments_trait_value_if_value_positive_integer(self): + # Given + initial_value = 2 + increment_by = 2 + + url = reverse("api-v1:sdk-traits-increment-value") + trait = Trait.objects.create( + identity=self.identity, + trait_key=self.trait_key, + value_type=INTEGER, + integer_value=initial_value, + ) + data = { + "trait_key": self.trait_key, + "identifier": self.identity.identifier, + "increment_by": increment_by, + } + + # When + self.client.post(url, data=data) + + # Then + trait.refresh_from_db() + assert trait.get_trait_value() == initial_value + increment_by + + def test_increment_value_decrements_trait_value_if_value_negative_integer(self): + # Given + initial_value = 2 + increment_by = -2 + + url = reverse("api-v1:sdk-traits-increment-value") + trait = Trait.objects.create( + identity=self.identity, + trait_key=self.trait_key, + value_type=INTEGER, + integer_value=initial_value, + ) + data = { + "trait_key": self.trait_key, + "identifier": self.identity.identifier, + "increment_by": increment_by, + } + + # When + self.client.post(url, data=data) + + # Then + trait.refresh_from_db() + assert trait.get_trait_value() == initial_value + increment_by + + def test_increment_value_initialises_trait_with_a_value_of_zero_if_it_doesnt_exist( + self, + ): + # Given + increment_by = 1 + + url = reverse("api-v1:sdk-traits-increment-value") + data = { + "trait_key": self.trait_key, + "identifier": self.identity.identifier, + "increment_by": increment_by, + } + + # When + self.client.post(url, data=data) + + # Then + trait = Trait.objects.get(trait_key=self.trait_key, identity=self.identity) + assert trait.get_trait_value() == increment_by + + def test_increment_value_returns_400_if_trait_value_not_integer(self): + # Given + url = reverse("api-v1:sdk-traits-increment-value") + Trait.objects.create( + identity=self.identity, + trait_key=self.trait_key, + value_type=STRING, + string_value="str", + ) + data = { + "trait_key": self.trait_key, + "identifier": self.identity.identifier, + "increment_by": 2, + } + + # When + res = self.client.post(url, data=data) + + # Then + assert res.status_code == status.HTTP_400_BAD_REQUEST + + def test_set_trait_with_too_long_string_value_returns_400(self): + # Given + url = reverse("api-v1:sdk-traits-list") + trait_value = "t" * (TRAIT_STRING_VALUE_MAX_LENGTH + 1) + + # When + res = self.client.post( + url, + data=self._generate_json_trait_data(trait_value=trait_value), + content_type=self.JSON, + ) + + # Then + assert res.status_code == status.HTTP_400_BAD_REQUEST + assert ( + f"Value string is too long. Must be less than {TRAIT_STRING_VALUE_MAX_LENGTH} character" + == res.json()["trait_value"][0] + ) + + def test_can_set_trait_with_bad_value_for_an_identity(self): + # Given + url = reverse("api-v1:sdk-traits-list") + bad_trait_value = {"foo": "bar"} + + # When + res = self.client.post( + url, + data=self._generate_json_trait_data(trait_value=bad_trait_value), + content_type=self.JSON, + ) + + # Then + assert res.status_code == status.HTTP_200_OK + + # and + assert Trait.objects.get( + identity=self.identity, trait_key=self.trait_key + ).get_trait_value() == str(bad_trait_value) + + def test_bulk_create_traits(self): + # Given + num_traits = 20 + url = reverse("api-v1:sdk-traits-bulk-create") + traits = [ + self._generate_trait_data(trait_key=f"trait_{i}", identifier="user_{i}") + for i in range(num_traits) + ] + identifiers = [trait["identity"]["identifier"] for trait in traits] + + # When + response = self.client.put( + url, data=json.dumps(traits), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_200_OK + assert ( + Trait.objects.filter(identity__identifier__in=identifiers).count() + == num_traits + ) + + def test_bulk_create_traits_when_bad_trait_value_sent_then_trait_value_stringified( + self, + ): + # Given + num_traits = 5 + url = reverse("api-v1:sdk-traits-bulk-create") + traits = [ + self._generate_trait_data(trait_key=f"trait_{i}") for i in range(num_traits) + ] + + # add some bad data to test + bad_trait_key = "trait_999" + bad_trait_value = {"foo": "bar"} + traits.append( + { + "trait_value": bad_trait_value, + "trait_key": bad_trait_key, + "identity": {"identifier": self.identity.identifier}, + } + ) + + # When + response = self.client.put( + url, data=json.dumps(traits), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_200_OK + assert Trait.objects.filter(identity=self.identity).count() == num_traits + 1 + + # and + assert Trait.objects.get( + identity=self.identity, trait_key=bad_trait_key + ).get_trait_value() == str(bad_trait_value) + + def test_sending_null_value_in_bulk_create_deletes_trait_for_identity(self): + # Given + url = reverse("api-v1:sdk-traits-bulk-create") + trait_to_delete = Trait.objects.create( + trait_key=self.trait_key, + value_type=STRING, + string_value=self.trait_value, + identity=self.identity, + ) + trait_key_to_keep = "another_trait_key" + trait_to_keep = Trait.objects.create( + trait_key=trait_key_to_keep, + value_type=STRING, + string_value="value is irrelevant", + identity=self.identity, + ) + data = [ + { + "identity": {"identifier": self.identity.identifier}, + "trait_key": self.trait_key, + "trait_value": None, + } + ] + + # When + response = self.client.put( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + # the request is successful + assert response.status_code == status.HTTP_200_OK + + # and the trait is deleted + assert not Trait.objects.filter(id=trait_to_delete.id).exists() + + # but the trait missing from the request is left untouched + assert Trait.objects.filter(id=trait_to_keep.id).exists() + + def test_bulk_create_traits_when_float_value_sent_then_trait_value_correct(self): + # Given + url = reverse("api-v1:sdk-traits-bulk-create") + traits = [] + + # add float value trait + float_trait_key = "float_key_999" + float_trait_value = 45.88 + traits.append( + { + "trait_value": float_trait_value, + "trait_key": float_trait_key, + "identity": {"identifier": self.identity.identifier}, + } + ) + + # When + response = self.client.put( + url, data=json.dumps(traits), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_200_OK + assert Trait.objects.filter(identity=self.identity).count() == 1 + + # and + assert ( + Trait.objects.get( + identity=self.identity, trait_key=float_trait_key + ).get_trait_value() + == float_trait_value + ) + + @override_settings(EDGE_API_URL="http://localhost") + @mock.patch("environments.identities.traits.views.forward_trait_request") + def test_post_trait_calls_forward_trait_request_with_correct_arguments( + self, mocked_forward_trait_request + ): + # Given + url = reverse("api-v1:sdk-traits-list") + data = self._generate_json_trait_data() + + # When + self.client.post(url, data=data, content_type=self.JSON) + + # Then + args, kwargs = mocked_forward_trait_request.delay.call_args_list[0] + assert args == () + assert kwargs["args"][0] == "POST" + assert kwargs["args"][1].get("X-Environment-Key") == self.environment.api_key + assert kwargs["args"][2] == self.environment.project.id + assert kwargs["args"][3] == json.loads(data) + + @override_settings(EDGE_API_URL="http://localhost") + @mock.patch("environments.identities.traits.views.forward_trait_request") + def test_increment_value_calls_forward_trait_request_with_correct_arguments( + self, mocked_forward_trait_request + ): + # Given + url = reverse("api-v1:sdk-traits-increment-value") + data = { + "trait_key": self.trait_key, + "identifier": self.identity.identifier, + "increment_by": 1, + } + + # When + self.client.post(url, data=data) + + # Then + args, kwargs = mocked_forward_trait_request.delay.call_args_list[0] + assert args == () + assert kwargs["args"][0] == "POST" + assert kwargs["args"][1].get("X-Environment-Key") == self.environment.api_key + assert kwargs["args"][2] == self.environment.project.id + + # and the structure of payload was correct + assert kwargs["args"][3]["identity"]["identifier"] == data["identifier"] + assert kwargs["args"][3]["trait_key"] == data["trait_key"] + assert kwargs["args"][3]["trait_value"] + + @override_settings(EDGE_API_URL="http://localhost") + @mock.patch("environments.identities.traits.views.forward_trait_requests") + def test_bulk_create_traits_calls_forward_trait_request_with_correct_arguments( + self, mocked_forward_trait_requests + ): + # Given + url = reverse("api-v1:sdk-traits-bulk-create") + request_data = [ + { + "identity": {"identifier": "test_user_123"}, + "trait_key": "key", + "trait_value": "value", + }, + { + "identity": {"identifier": "test_user_123"}, + "trait_key": "key1", + "trait_value": "value1", + }, + ] + + # When + self.client.put( + url, data=json.dumps(request_data), content_type="application/json" + ) + + # Then + + # Then + args, kwargs = mocked_forward_trait_requests.delay.call_args_list[0] + assert args == () + assert kwargs["args"][0] == "PUT" + assert kwargs["args"][1].get("X-Environment-Key") == self.environment.api_key + assert kwargs["args"][2] == self.environment.project.id + assert kwargs["args"][3] == request_data + + def test_create_trait_returns_403_if_client_cannot_set_traits(self): + # Given + url = reverse("api-v1:sdk-traits-list") + data = { + "identity": {"identifier": self.identity.identifier}, + "trait_key": "foo", + "trait_value": "bar", + } + + self.environment.allow_client_traits = False + self.environment.save() + + # When + response = self.client.post( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_server_key_can_create_trait_if_not_allow_client_traits(self): + # Given + url = reverse("api-v1:sdk-traits-list") + data = { + "identity": {"identifier": self.identity.identifier}, + "trait_key": "foo", + "trait_value": "bar", + } + + server_api_key = EnvironmentAPIKey.objects.create(environment=self.environment) + self.client.credentials(HTTP_X_ENVIRONMENT_KEY=server_api_key.key) + + self.environment.allow_client_traits = False + self.environment.save() + + # When + response = self.client.post( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_200_OK + + def test_bulk_create_traits_returns_403_if_client_cannot_set_traits(self): + # Given + url = reverse("api-v1:sdk-traits-bulk-create") + data = [ + { + "identity": {"identifier": self.identity.identifier}, + "trait_key": "foo", + "trait_value": "bar", + } + ] + + self.environment.allow_client_traits = False + self.environment.save() + + # When + response = self.client.put( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_server_key_can_bulk_create_traits_if_not_allow_client_traits(self): + # Given + url = reverse("api-v1:sdk-traits-bulk-create") + data = [ + { + "identity": {"identifier": self.identity.identifier}, + "trait_key": "foo", + "trait_value": "bar", + } + ] + + server_api_key = EnvironmentAPIKey.objects.create(environment=self.environment) + self.client.credentials(HTTP_X_ENVIRONMENT_KEY=server_api_key.key) + + self.environment.allow_client_traits = False + self.environment.save() + + # When + response = self.client.put( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_200_OK + + def _generate_trait_data(self, identifier=None, trait_key=None, trait_value=None): + identifier = identifier or self.identity.identifier + trait_key = trait_key or self.trait_key + trait_value = trait_value or self.trait_value + + return { + "identity": {"identifier": identifier}, + "trait_key": trait_key, + "trait_value": trait_value, + } + + def _generate_json_trait_data( + self, identifier=None, trait_key=None, trait_value=None + ): + return json.dumps(self._generate_trait_data(identifier, trait_key, trait_value)) + + +@pytest.mark.django_db +class TraitViewSetTestCase(TestCase): + def setUp(self) -> None: + self.client = APIClient() + user = Helper.create_ffadminuser() + self.client.force_authenticate(user=user) + + organisation = Organisation.objects.create(name="Test org") + user.add_organisation(organisation, OrganisationRole.ADMIN) + + self.project = Project.objects.create( + name="Test project", organisation=organisation + ) + self.environment = Environment.objects.create( + name="Test environment", project=self.project + ) + self.identity = Identity.objects.create( + identifier="test-user", environment=self.environment + ) + + def test_delete_trait_only_deletes_single_trait_if_query_param_not_provided(self): + # Given + trait_key = "trait_key" + trait_value = "trait_value" + identity_2 = Identity.objects.create( + identifier="test-user-2", environment=self.environment + ) + + trait = Trait.objects.create( + identity=self.identity, + trait_key=trait_key, + value_type=STRING, + string_value=trait_value, + ) + trait_2 = Trait.objects.create( + identity=identity_2, + trait_key=trait_key, + value_type=STRING, + string_value=trait_value, + ) + + url = reverse( + "api-v1:environments:identities-traits-detail", + args=[self.environment.api_key, self.identity.id, trait.id], + ) + + # When + self.client.delete(url) + + # Then + assert not Trait.objects.filter(pk=trait.id).exists() + + # and + assert Trait.objects.filter(pk=trait_2.id).exists() + + def test_delete_trait_deletes_all_traits_if_query_param_provided(self): + # Given + trait_key = "trait_key" + trait_value = "trait_value" + identity_2 = Identity.objects.create( + identifier="test-user-2", environment=self.environment + ) + + trait = Trait.objects.create( + identity=self.identity, + trait_key=trait_key, + value_type=STRING, + string_value=trait_value, + ) + trait_2 = Trait.objects.create( + identity=identity_2, + trait_key=trait_key, + value_type=STRING, + string_value=trait_value, + ) + + base_url = reverse( + "api-v1:environments:identities-traits-detail", + args=[self.environment.api_key, self.identity.id, trait.id], + ) + url = base_url + "?deleteAllMatchingTraits=true" + + # When + self.client.delete(url) + + # Then + assert not Trait.objects.filter(pk=trait.id).exists() + + # and + assert not Trait.objects.filter(pk=trait_2.id).exists() + + def test_delete_trait_only_deletes_traits_in_current_environment(self): + # Given + environment_2 = Environment.objects.create( + name="Test environment", project=self.project + ) + trait_key = "trait_key" + trait_value = "trait_value" + identity_2 = Identity.objects.create( + identifier="test-user-2", environment=environment_2 + ) + + trait = Trait.objects.create( + identity=self.identity, + trait_key=trait_key, + value_type=STRING, + string_value=trait_value, + ) + trait_2 = Trait.objects.create( + identity=identity_2, + trait_key=trait_key, + value_type=STRING, + string_value=trait_value, + ) + + base_url = reverse( + "api-v1:environments:identities-traits-detail", + args=[self.environment.api_key, self.identity.id, trait.id], + ) + url = base_url + "?deleteAllMatchingTraits=true" + + # When + self.client.delete(url) + + # Then + assert not Trait.objects.filter(pk=trait.id).exists() + + # and + assert Trait.objects.filter(pk=trait_2.id).exists() + + +def test_set_trait_for_an_identity_is_not_throttled_by_user_throttle( + settings, identity, environment, api_client +): + # Given + settings.REST_FRAMEWORK = {"DEFAULT_THROTTLE_RATES": {"user": "1/minute"}} + + api_client.credentials(HTTP_X_ENVIRONMENT_KEY=environment.api_key) + + url = reverse("api-v1:sdk-traits-list") + data = { + "identity": {"identifier": identity.identifier}, + "trait_key": "key", + "trait_value": "value", + } + + # When + for _ in range(10): + res = api_client.post( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert res.status_code == status.HTTP_200_OK def test_user_with_manage_identities_permission_can_add_trait_for_identity( diff --git a/api/environments/identities/traits/tests/test_models.py b/api/tests/unit/environments/identities/traits/test_unit_traits_models.py similarity index 100% rename from api/environments/identities/traits/tests/test_models.py rename to api/tests/unit/environments/identities/traits/test_unit_traits_models.py diff --git a/api/environments/management/commands/tests/test_migrate_to_edge.py b/api/tests/unit/environments/management/commands/test_unit_environments_management_commands_migrate_to_edge.py similarity index 100% rename from api/environments/management/commands/tests/test_migrate_to_edge.py rename to api/tests/unit/environments/management/commands/test_unit_environments_management_commands_migrate_to_edge.py diff --git a/api/environments/permissions/tests/test_permissions.py b/api/tests/unit/environments/permissions/test_unit_environments_permissions.py similarity index 100% rename from api/environments/permissions/tests/test_permissions.py rename to api/tests/unit/environments/permissions/test_unit_environments_permissions.py diff --git a/api/environments/permissions/tests/test_views.py b/api/tests/unit/environments/permissions/test_unit_environments_views.py similarity index 100% rename from api/environments/permissions/tests/test_views.py rename to api/tests/unit/environments/permissions/test_unit_environments_views.py diff --git a/api/environments/tests/test_authentication.py b/api/tests/unit/environments/test_unit_environments_authentication.py similarity index 100% rename from api/environments/tests/test_authentication.py rename to api/tests/unit/environments/test_unit_environments_authentication.py diff --git a/api/tests/unit/environments/test_environments_feature_states_views.py b/api/tests/unit/environments/test_unit_environments_feature_states_views.py similarity index 100% rename from api/tests/unit/environments/test_environments_feature_states_views.py rename to api/tests/unit/environments/test_unit_environments_feature_states_views.py diff --git a/api/tests/unit/environments/test_migrations.py b/api/tests/unit/environments/test_unit_environments_migrations.py similarity index 100% rename from api/tests/unit/environments/test_migrations.py rename to api/tests/unit/environments/test_unit_environments_migrations.py diff --git a/api/tests/unit/environments/test_unit_environments_models.py b/api/tests/unit/environments/test_unit_environments_models.py index 2e0ced78934e..6a00da4c6896 100644 --- a/api/tests/unit/environments/test_unit_environments_models.py +++ b/api/tests/unit/environments/test_unit_environments_models.py @@ -1,14 +1,31 @@ import typing +from copy import copy +from datetime import timedelta +from unittest import mock from unittest.mock import MagicMock import pytest +from core.constants import STRING from core.request_origin import RequestOrigin +from django.test import TestCase, override_settings +from django.utils import timezone from pytest_django.asserts import assertQuerysetEqual as assert_queryset_equal -from environments.models import Environment, EnvironmentAPIKey, Webhook +from audit.models import AuditLog +from audit.related_object_type import RelatedObjectType +from environments.identities.models import Identity +from environments.models import ( + Environment, + EnvironmentAPIKey, + Webhook, + environment_cache, +) +from features.feature_types import MULTIVARIATE from features.models import Feature, FeatureState +from features.multivariate.models import MultivariateFeatureOption from features.versioning.models import EnvironmentFeatureVersion -from organisations.models import OrganisationRole +from organisations.models import Organisation, OrganisationRole +from projects.models import Project from segments.models import Segment from util.mappers import map_environment_to_environment_document @@ -16,8 +33,367 @@ from django.db.models import Model from features.workflows.core.models import ChangeRequest - from organisations.models import Organisation - from projects.models import Project + + +@pytest.mark.django_db +class EnvironmentTestCase(TestCase): + def setUp(self): + self.organisation = Organisation.objects.create(name="Test Org") + self.project = Project.objects.create( + name="Test Project", organisation=self.organisation + ) + self.feature = Feature.objects.create(name="Test Feature", project=self.project) + # The environment is initialised in a non-saved state as we want to test the save + # functionality. + self.environment = Environment(name="Test Environment", project=self.project) + + def test_environment_should_be_created_with_feature_states(self): + # Given - set up data + + # When + self.environment.save() + + # Then + feature_states = FeatureState.objects.filter(environment=self.environment) + assert hasattr(self.environment, "api_key") + assert feature_states.count() == 1 + + def test_on_creation_save_feature_states_get_created(self): + # These should be no feature states before saving + self.assertEqual(FeatureState.objects.count(), 0) + + self.environment.save() + + # On the first save a new feature state should be created + self.assertEqual(FeatureState.objects.count(), 1) + + def test_on_update_save_feature_states_get_updated_not_created(self): + self.environment.save() + + self.feature.default_enabled = True + self.feature.save() + self.environment.save() + + self.assertEqual(FeatureState.objects.count(), 1) + + def test_on_creation_save_feature_is_created_with_the_correct_default(self): + self.environment.save() + self.assertFalse(FeatureState.objects.get().enabled) + + def test_clone_does_not_modify_the_original_instance(self): + # Given + self.environment.save() + + # When + clone = self.environment.clone(name="Cloned env") + + # Then + self.assertNotEqual(clone.name, self.environment.name) + self.assertNotEqual(clone.api_key, self.environment.api_key) + + def test_clone_save_creates_feature_states(self): + # Given + self.environment.save() + + # When + clone = self.environment.clone(name="Cloned env") + + # Then + feature_states = FeatureState.objects.filter(environment=clone) + assert feature_states.count() == 1 + + def test_clone_does_not_modify_source_feature_state(self): + # Given + self.environment.save() + source_feature_state_before_clone = FeatureState.objects.filter( + environment=self.environment + ).first() + + # When + self.environment.clone(name="Cloned env") + source_feature_state_after_clone = FeatureState.objects.filter( + environment=self.environment + ).first() + + # Then + assert source_feature_state_before_clone == source_feature_state_after_clone + + def test_clone_does_not_create_identity(self): + # Given + self.environment.save() + Identity.objects.create( + environment=self.environment, identifier="test_identity" + ) + # When + clone = self.environment.clone(name="Cloned env") + + # Then + assert clone.identities.count() == 0 + + def test_clone_clones_the_feature_states(self): + # Given + self.environment.save() + + # Enable the feature in the source environment + self.environment.feature_states.update(enabled=True) + + # When + clone = self.environment.clone(name="Cloned env") + + # Then + assert clone.feature_states.first().enabled is True + + def test_clone_clones_multivariate_feature_state_values(self): + # Given + self.environment.save() + + mv_feature = Feature.objects.create( + type=MULTIVARIATE, + name="mv_feature", + initial_value="foo", + project=self.project, + ) + variant_1 = MultivariateFeatureOption.objects.create( + feature=mv_feature, + default_percentage_allocation=10, + type=STRING, + string_value="bar", + ) + + # When + clone = self.environment.clone(name="Cloned env") + + # Then + cloned_mv_feature_state = clone.feature_states.get(feature=mv_feature) + assert cloned_mv_feature_state.multivariate_feature_state_values.count() == 1 + + original_mv_fs_value = FeatureState.objects.get( + environment=self.environment, feature=mv_feature + ).multivariate_feature_state_values.first() + cloned_mv_fs_value = ( + cloned_mv_feature_state.multivariate_feature_state_values.first() + ) + + assert original_mv_fs_value != cloned_mv_fs_value + assert ( + original_mv_fs_value.multivariate_feature_option + == cloned_mv_fs_value.multivariate_feature_option + == variant_1 + ) + assert ( + original_mv_fs_value.percentage_allocation + == cloned_mv_fs_value.percentage_allocation + == 10 + ) + + @mock.patch("environments.models.environment_cache") + def test_get_from_cache_stores_environment_in_cache_on_success(self, mock_cache): + # Given + self.environment.save() + mock_cache.get.return_value = None + + # When + environment = Environment.get_from_cache(self.environment.api_key) + + # Then + assert environment == self.environment + mock_cache.set.assert_called_with( + self.environment.api_key, self.environment, timeout=60 + ) + + def test_get_from_cache_returns_None_if_no_matching_environment(self): + # Given + api_key = "no-matching-env" + + # When + env = Environment.get_from_cache(api_key) + + # Then + assert env is None + + def test_get_from_cache_accepts_environment_api_key_model_key(self): + # Given + self.environment.save() + api_key = EnvironmentAPIKey.objects.create( + name="Some key", environment=self.environment + ) + + # When + environment_from_cache = Environment.get_from_cache(api_key=api_key.key) + + # Then + assert environment_from_cache == self.environment + + def test_get_from_cache_with_null_environment_key_returns_null(self): + # Given + self.environment.save() + + # When + environment = Environment.get_from_cache(None) + + # Then + assert environment is None + + @override_settings( + CACHE_BAD_ENVIRONMENTS_SECONDS=60, CACHE_BAD_ENVIRONMENTS_AFTER_FAILURES=1 + ) + def test_get_from_cache_does_not_hit_database_if_api_key_in_bad_env_cache(self): + # Given + api_key = "bad-key" + + # When + with self.assertNumQueries(1): + [Environment.get_from_cache(api_key) for _ in range(10)] + + +def test_environment_api_key_model_is_valid_is_true_for_non_expired_active_key( + environment, +): + assert ( + EnvironmentAPIKey.objects.create( + environment=environment, + key="ser.random_key", + name="test_key", + ).is_valid + is True + ) + + +def test_environment_api_key_model_is_valid_is_true_for_non_expired_active_key_with_expired_date_in_future( + environment, +): + assert ( + EnvironmentAPIKey.objects.create( + environment=environment, + key="ser.random_key", + name="test_key", + expires_at=timezone.now() + timedelta(days=5), + ).is_valid + is True + ) + + +def test_environment_api_key_model_is_valid_is_false_for_expired_active_key( + environment, +): + assert ( + EnvironmentAPIKey.objects.create( + environment=environment, + key="ser.random_key", + name="test_key", + expires_at=timezone.now() - timedelta(seconds=1), + ).is_valid + is False + ) + + +def test_environment_api_key_model_is_valid_is_false_for_non_expired_inactive_key( + environment, +): + assert ( + EnvironmentAPIKey.objects.create( + environment=environment, key="ser.random_key", name="test_key", active=False + ).is_valid + is False + ) + + +def test_existence_of_multiple_environment_api_keys_does_not_break_get_from_cache( + environment, +): + # Given + environment_api_keys = [ + EnvironmentAPIKey.objects.create(environment=environment, name=f"test_key_{i}") + for i in range(2) + ] + + # When + retrieved_environments = [ + Environment.get_from_cache(environment.api_key), + *[ + Environment.get_from_cache(environment_api_key.key) + for environment_api_key in environment_api_keys + ], + ] + + # Then + assert all( + retrieved_environment == environment + for retrieved_environment in retrieved_environments + ) + + +def test_get_from_cache_sets_the_cache_correctly_with_environment_api_key( + environment, environment_api_key, mocker +): + # When + returned_environment = Environment.get_from_cache(environment_api_key.key) + + # Then + assert returned_environment == environment + + # and + assert environment == environment_cache.get(environment_api_key.key) + + +def test_updated_at_gets_updated_when_environment_audit_log_created(environment): + # When + audit_log = AuditLog.objects.create( + environment=environment, project=environment.project, log="random_audit_log" + ) + + # Then + environment.refresh_from_db() + assert environment.updated_at == audit_log.created_date + + +def test_updated_at_gets_updated_when_project_audit_log_created(environment): + # When + audit_log = AuditLog.objects.create( + project=environment.project, log="random_audit_log" + ) + environment.refresh_from_db() + # Then + assert environment.updated_at == audit_log.created_date + + +def test_change_request_audit_logs_does_not_update_updated_at(environment): + # Given + updated_at_before_audit_log = environment.updated_at + + # When + audit_log = AuditLog.objects.create( + environment=environment, + log="random_test", + related_object_type=RelatedObjectType.CHANGE_REQUEST.name, + ) + + # Then + assert environment.updated_at == updated_at_before_audit_log + assert environment.updated_at != audit_log.created_date + + +def test_save_environment_clears_environment_cache(mocker, project): + # Given + mock_environment_cache = mocker.patch("environments.models.environment_cache") + environment = Environment.objects.create(name="test environment", project=project) + + # perform an update of the name to verify basic functionality + environment.name = "updated" + environment.save() + + # and update the api key to verify that the original api key is used to clear cache + old_key = copy(environment.api_key) + new_key = "some-new-key" + environment.api_key = new_key + + # When + environment.save() + + # Then + mock_calls = mock_environment_cache.delete.mock_calls + assert len(mock_calls) == 2 + assert mock_calls[0][1][0] == mock_calls[1][1][0] == old_key @pytest.mark.parametrize( diff --git a/api/tests/unit/environments/test_environments_permissions.py b/api/tests/unit/environments/test_unit_environments_permissions.py similarity index 100% rename from api/tests/unit/environments/test_environments_permissions.py rename to api/tests/unit/environments/test_unit_environments_permissions.py diff --git a/api/tests/unit/environments/test_unit_environment_tasks.py b/api/tests/unit/environments/test_unit_environments_tasks.py similarity index 100% rename from api/tests/unit/environments/test_unit_environment_tasks.py rename to api/tests/unit/environments/test_unit_environments_tasks.py diff --git a/api/tests/unit/environments/test_environments_views_sdk_environment.py b/api/tests/unit/environments/test_unit_environments_views_sdk_environment.py similarity index 100% rename from api/tests/unit/environments/test_environments_views_sdk_environment.py rename to api/tests/unit/environments/test_unit_environments_views_sdk_environment.py diff --git a/api/features/feature_segments/tests/test_models.py b/api/tests/unit/features/feature_segments/test_unit_feature_segments_models.py similarity index 100% rename from api/features/feature_segments/tests/test_models.py rename to api/tests/unit/features/feature_segments/test_unit_feature_segments_models.py diff --git a/api/features/feature_segments/tests/test_views.py b/api/tests/unit/features/feature_segments/test_unit_feature_segments_views.py similarity index 100% rename from api/features/feature_segments/tests/test_views.py rename to api/tests/unit/features/feature_segments/test_unit_feature_segments_views.py