diff --git a/api/conftest.py b/api/conftest.py index c468b9c2cd04..543998df2656 100644 --- a/api/conftest.py +++ b/api/conftest.py @@ -327,8 +327,13 @@ def project(organisation): @pytest.fixture() -def segment(project): - return Segment.objects.create(name="segment", project=project) +def segment(project: Project): + _segment = Segment.objects.create(name="segment", project=project) + # Deep clone the segment to ensure that any bugs around + # versioning get bubbled up through the test suite. + _segment.deep_clone() + + return _segment @pytest.fixture() diff --git a/api/core/signals.py b/api/core/signals.py index 0b12a2afaa52..d88cb466ecbf 100644 --- a/api/core/signals.py +++ b/api/core/signals.py @@ -1,5 +1,8 @@ +import logging + from core.models import AbstractBaseAuditableModel from django.conf import settings +from django.core.exceptions import ObjectDoesNotExist from django.utils import timezone from simple_history.models import HistoricalRecords @@ -7,6 +10,8 @@ from task_processor.task_run_method import TaskRunMethod from users.models import FFAdminUser +logger = logging.getLogger(__name__) + def create_audit_log_from_historical_record( instance: AbstractBaseAuditableModel, @@ -30,7 +35,18 @@ def create_audit_log_from_historical_record( else None ) - environment, project = instance.get_environment_and_project() + try: + environment, project = instance.get_environment_and_project() + except ObjectDoesNotExist: + logger.warning( + "Unable to create audit log for %s %s. " + "Parent model does not exist - this likely means it was hard deleted.", + instance.related_object_type, + getattr(instance, "id", "uuid"), + exc_info=True, + ) + return + if project != history_instance.instance and ( (project and project.deleted_at) or (environment and environment.project.deleted_at) diff --git a/api/segments/helpers.py b/api/segments/helpers.py new file mode 100644 index 000000000000..9d5797f512dd --- /dev/null +++ b/api/segments/helpers.py @@ -0,0 +1,16 @@ +class SegmentAuditLogHelper: + def __init__(self) -> None: + self.skip_audit_log = {} + + def should_skip_audit_log(self, segment_id: int) -> None | bool: + return self.skip_audit_log.get(segment_id) + + def set_skip_audit_log(self, segment_id: int) -> None: + self.skip_audit_log[segment_id] = True + + def unset_skip_audit_log(self, segment_id: int) -> None: + if segment_id in self.skip_audit_log: + del self.skip_audit_log[segment_id] + + +segment_audit_log_helper = SegmentAuditLogHelper() diff --git a/api/segments/managers.py b/api/segments/managers.py new file mode 100644 index 000000000000..d1f977ca06ed --- /dev/null +++ b/api/segments/managers.py @@ -0,0 +1,11 @@ +from core.models import SoftDeleteExportableManager +from django.db.models import F + + +class SegmentManager(SoftDeleteExportableManager): + def get_queryset(self): + """ + Returns only the canonical segments, which will always be + the highest version. + """ + return super().get_queryset().filter(id=F("version_of")) diff --git a/api/segments/migrations/0023_add_versioning_to_segments.py b/api/segments/migrations/0023_add_versioning_to_segments.py new file mode 100644 index 000000000000..0e7c8f0d5a03 --- /dev/null +++ b/api/segments/migrations/0023_add_versioning_to_segments.py @@ -0,0 +1,58 @@ +# Generated by Django 3.2.25 on 2024-06-10 15:31 +import os + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("segments", "0022_add_soft_delete_to_segment_rules_and_conditions"), + ] + + operations = [ + migrations.AddField( + model_name="historicalsegment", + name="version", + field=models.IntegerField(null=True), + ), + migrations.AddField( + model_name="historicalsegment", + name="version_of", + field=models.ForeignKey( + blank=True, + db_constraint=False, + null=True, + on_delete=django.db.models.deletion.DO_NOTHING, + related_name="+", + to="segments.segment", + ), + ), + migrations.AddField( + model_name="segment", + name="version", + field=models.IntegerField(null=True), + ), + migrations.AddField( + model_name="segment", + name="version_of", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="versioned_segments", + to="segments.segment", + ), + ), + migrations.RunSQL( + sql=open( + os.path.join( + os.path.dirname(__file__), + "sql", + "0023_add_versioning_to_segments.sql", + ) + ).read(), + reverse_sql=migrations.RunSQL.noop, + ), + ] diff --git a/api/segments/migrations/sql/0023_add_versioning_to_segments.sql b/api/segments/migrations/sql/0023_add_versioning_to_segments.sql new file mode 100644 index 000000000000..3e91d2ad844c --- /dev/null +++ b/api/segments/migrations/sql/0023_add_versioning_to_segments.sql @@ -0,0 +1 @@ +UPDATE segments_segment SET version_of_id = id; diff --git a/api/segments/models.py b/api/segments/models.py index dea6f980a7ea..992f37ea7023 100644 --- a/api/segments/models.py +++ b/api/segments/models.py @@ -1,8 +1,10 @@ import logging import typing +import uuid from copy import deepcopy from core.models import ( + SoftDeleteExportableManager, SoftDeleteExportableModel, abstract_base_auditable_model_factory, ) @@ -10,6 +12,12 @@ from django.contrib.contenttypes.fields import GenericRelation from django.core.exceptions import ValidationError from django.db import models +from django_lifecycle import ( + AFTER_CREATE, + BEFORE_CREATE, + LifecycleModelMixin, + hook, +) from flag_engine.segments import constants from audit.constants import ( @@ -22,10 +30,14 @@ from metadata.models import Metadata from projects.models import Project +from .helpers import segment_audit_log_helper +from .managers import SegmentManager + logger = logging.getLogger(__name__) class Segment( + LifecycleModelMixin, SoftDeleteExportableModel, abstract_base_auditable_model_factory(["uuid"]), ): @@ -45,14 +57,44 @@ class Segment( Feature, on_delete=models.CASCADE, related_name="segments", null=True ) + # This defaults to 1 for newly created segments. + version = models.IntegerField(null=True) + + # The related_name is not useful without specifying all_objects as a manager. + version_of = models.ForeignKey( + "self", + on_delete=models.CASCADE, + related_name="versioned_segments", + null=True, + blank=True, + ) metadata = GenericRelation(Metadata) + # Only serves segments that are the canonical version. + objects = SegmentManager() + + # Includes versioned segments. + all_objects = SoftDeleteExportableManager() + class Meta: ordering = ("id",) # explicit ordering to prevent pagination warnings def __str__(self): return "Segment - %s" % self.name + def get_skip_create_audit_log(self) -> bool: + skip = segment_audit_log_helper.should_skip_audit_log(self.id) + if skip is not None: + return skip + + try: + if self.version_of_id and self.version_of_id != self.id: + return True + except Segment.DoesNotExist: + return True + + return False + @staticmethod def id_exists_in_rules_data(rules_data: typing.List[dict]) -> bool: """ @@ -84,6 +126,49 @@ def id_exists_in_rules_data(rules_data: typing.List[dict]) -> bool: return False + @hook(BEFORE_CREATE, when="version_of", is_now=None) + def set_default_version_to_one_if_new_segment(self): + if self.version is None: + self.version = 1 + + @hook(AFTER_CREATE, when="version_of", is_now=None) + def set_version_of_to_self_if_none(self): + """ + This allows the segment model to reference all versions of + itself including itself. + """ + segment_audit_log_helper.set_skip_audit_log(self.id) + self.version_of = self + self.save() + segment_audit_log_helper.unset_skip_audit_log(self.id) + + def deep_clone(self) -> "Segment": + cloned_segment = deepcopy(self) + cloned_segment.id = None + cloned_segment.uuid = uuid.uuid4() + cloned_segment.version_of = self + cloned_segment.save() + + segment_audit_log_helper.set_skip_audit_log(self.id) + self.version += 1 + self.save() + segment_audit_log_helper.unset_skip_audit_log(self.id) + + cloned_rules = [] + for rule in self.rules.all(): + cloned_rule = rule.deep_clone(cloned_segment) + cloned_rules.append(cloned_rule) + + cloned_segment.refresh_from_db() + + assert ( + len(self.rules.all()) + == len(cloned_rules) + == len(cloned_segment.rules.all()) + ), "Mismatch during rules creation" + + return cloned_segment + def get_create_log_message(self, history_instance) -> typing.Optional[str]: return SEGMENT_CREATED_MESSAGE % self.name @@ -128,6 +213,10 @@ def __str__(self): str(self.segment) if self.segment else str(self.rule), ) + def get_skip_create_audit_log(self) -> bool: + segment = self.get_segment() + return segment.version_of_id != segment.id + def get_segment(self): """ rules can be a child of a parent rule instead of a segment, this method iterates back up the tree to find the @@ -136,10 +225,46 @@ def get_segment(self): TODO: denormalise the segment information so that we don't have to make multiple queries here in complex cases """ rule = self - while not rule.segment: + while not rule.segment_id: rule = rule.rule return rule.segment + def deep_clone(self, cloned_segment: Segment) -> "SegmentRule": + if self.rule: + # Since we're expecting a rule that is only belonging to a + # segment, since a rule either belongs to a segment xor belongs + # to a rule, we don't expect there also to be a rule associated. + assert False, "Unexpected rule, expecting segment set not rule" + cloned_rule = deepcopy(self) + cloned_rule.segment = cloned_segment + cloned_rule.uuid = uuid.uuid4() + cloned_rule.id = None + cloned_rule.save() + + # Conditions are only part of the sub-rules. + assert self.conditions.exists() is False + + for sub_rule in self.rules.all(): + if sub_rule.rules.exists(): + assert False, "Expected two layers of rules, not more" + + cloned_sub_rule = deepcopy(sub_rule) + cloned_sub_rule.rule = cloned_rule + cloned_sub_rule.uuid = uuid.uuid4() + cloned_sub_rule.id = None + cloned_sub_rule.save() + + cloned_conditions = [] + for condition in sub_rule.conditions.all(): + cloned_condition = deepcopy(condition) + cloned_condition.rule = cloned_sub_rule + cloned_condition.uuid = uuid.uuid4() + cloned_condition.id = None + cloned_conditions.append(cloned_condition) + Condition.objects.bulk_create(cloned_conditions) + + return cloned_rule + class Condition( SoftDeleteExportableModel, abstract_base_auditable_model_factory(["uuid"]) @@ -188,6 +313,10 @@ def __str__(self): self.value, ) + def get_skip_create_audit_log(self) -> bool: + segment = self.rule.get_segment() + return segment.version_of_id != segment.id + def get_update_log_message(self, history_instance) -> typing.Optional[str]: return f"Condition updated on segment '{self._get_segment().name}'." diff --git a/api/segments/serializers.py b/api/segments/serializers.py index fccbeecf45aa..322357038eef 100644 --- a/api/segments/serializers.py +++ b/api/segments/serializers.py @@ -81,16 +81,31 @@ def create(self, validated_data): self._update_or_create_metadata(metadata_data, segment=segment) return segment - def update(self, instance, validated_data): + def update(self, instance: Segment, validated_data: dict[str, typing.Any]) -> None: # use the initial data since we need the ids included to determine which to update & which to create rules_data = self.initial_data.pop("rules", []) metadata_data = validated_data.pop("metadata", []) self.validate_segment_rules_conditions_limit(rules_data) - self._update_segment_rules(rules_data, segment=instance) - self._update_or_create_metadata(metadata_data, segment=instance) - # remove rules from validated data to prevent error trying to create segment with nested rules - del validated_data["rules"] - return super().update(instance, validated_data) + + # Create a version of the segment now that we're updating. + cloned_segment = instance.deep_clone() + + try: + self._update_segment_rules(rules_data, segment=instance) + self._update_or_create_metadata(metadata_data, segment=instance) + + # remove rules from validated data to prevent error trying to create segment with nested rules + del validated_data["rules"] + response = super().update(instance, validated_data) + except Exception: + # Since there was a problem during the update we now delete the cloned segment, + # since we no longer need a versioned segment. + instance.refresh_from_db() + instance.version = cloned_segment.version + instance.save() + cloned_segment.hard_delete() + raise + return response def validate_project_segment_limit(self, project: Project) -> None: if project.segments.count() >= project.max_segments_allowed: diff --git a/api/tests/unit/segments/test_unit_segments_models.py b/api/tests/unit/segments/test_unit_segments_models.py index de955abe95b5..351a2fc7cf90 100644 --- a/api/tests/unit/segments/test_unit_segments_models.py +++ b/api/tests/unit/segments/test_unit_segments_models.py @@ -1,6 +1,11 @@ +from unittest.mock import PropertyMock + import pytest from flag_engine.segments.constants import EQUAL, PERCENTAGE_SPLIT +from pytest_mock import MockerFixture +from features.models import Feature +from projects.models import Project from segments.models import Condition, Segment, SegmentRule @@ -167,3 +172,282 @@ def test_condition_get_update_log_message(segment, segment_rule, mocker): ) def test_segment_id_exists_in_rules_data(rules_data, expected_result): assert Segment.id_exists_in_rules_data(rules_data) == expected_result + + +def test_deep_clone_of_segment( + project: Project, + feature: Feature, +) -> None: + # Given + segment = Segment.objects.create( + name="SpecialSegment", + description="A lovely, special segment.", + project=project, + feature=feature, + ) + + # Check that the versioning is correct, since we'll be testing + # against it later in the test. + assert segment.version == 1 + assert segment.version_of == segment + + parent_rule = SegmentRule.objects.create(segment=segment, type=SegmentRule.ALL_RULE) + + child_rule1 = SegmentRule.objects.create( + rule=parent_rule, type=SegmentRule.ANY_RULE + ) + child_rule2 = SegmentRule.objects.create( + rule=parent_rule, type=SegmentRule.NONE_RULE + ) + child_rule3 = SegmentRule.objects.create( + rule=parent_rule, type=SegmentRule.NONE_RULE + ) + child_rule4 = SegmentRule.objects.create( + rule=parent_rule, type=SegmentRule.ANY_RULE + ) + + condition1 = Condition.objects.create( + rule=child_rule1, + property="child_rule1", + operator=EQUAL, + value="condition3", + created_with_segment=True, + ) + condition2 = Condition.objects.create( + rule=child_rule2, + property="child_rule2", + operator=PERCENTAGE_SPLIT, + value="0.2", + created_with_segment=False, + ) + condition3 = Condition.objects.create( + rule=child_rule2, + property="child_rule2", + operator=EQUAL, + value="condition5", + created_with_segment=False, + ) + + condition4 = Condition.objects.create( + rule=child_rule3, + property="child_rule3", + operator=EQUAL, + value="condition6", + created_with_segment=False, + ) + + condition5 = Condition.objects.create( + rule=child_rule4, + property="child_rule4", + operator=EQUAL, + value="condition7", + created_with_segment=True, + ) + + # When + cloned_segment = segment.deep_clone() + + # Then + assert cloned_segment.name == segment.name + assert cloned_segment.description == segment.description + assert cloned_segment.project == project + assert cloned_segment.feature == feature + assert cloned_segment.version == 1 + assert cloned_segment.version_of == segment + + assert segment.version == 2 + + assert len(cloned_segment.rules.all()) == len(segment.rules.all()) == 1 + new_parent_rule = cloned_segment.rules.first() + + assert new_parent_rule.segment == cloned_segment + assert new_parent_rule.type == parent_rule.type + + assert len(new_parent_rule.rules.all()) == len(parent_rule.rules.all()) == 4 + new_child_rule1, new_child_rule2, new_child_rule3, new_child_rule4 = list( + new_parent_rule.rules.all() + ) + + assert new_child_rule1.type == child_rule1.type + assert new_child_rule2.type == child_rule2.type + assert new_child_rule3.type == child_rule3.type + assert new_child_rule4.type == child_rule4.type + + assert ( + len(new_parent_rule.conditions.all()) == len(parent_rule.conditions.all()) == 0 + ) + + assert ( + len(new_child_rule1.conditions.all()) == len(child_rule1.conditions.all()) == 1 + ) + new_condition1 = new_child_rule1.conditions.first() + + assert new_condition1.property == condition1.property + assert new_condition1.operator == condition1.operator + assert new_condition1.value == condition1.value + assert new_condition1.created_with_segment is condition1.created_with_segment + + assert ( + len(new_child_rule2.conditions.all()) == len(child_rule2.conditions.all()) == 2 + ) + new_condition2, new_condition3 = list(new_child_rule2.conditions.all()) + + assert new_condition2.property == condition2.property + assert new_condition2.operator == condition2.operator + assert new_condition2.value == condition2.value + assert new_condition2.created_with_segment is condition2.created_with_segment + + assert new_condition3.property == condition3.property + assert new_condition3.operator == condition3.operator + assert new_condition3.value == condition3.value + assert new_condition3.created_with_segment is condition3.created_with_segment + + assert ( + len(new_child_rule3.conditions.all()) == len(child_rule3.conditions.all()) == 1 + ) + new_condition4 = new_child_rule3.conditions.first() + + assert new_condition4.property == condition4.property + assert new_condition4.operator == condition4.operator + assert new_condition4.value == condition4.value + assert new_condition4.created_with_segment is condition4.created_with_segment + + assert ( + len(new_child_rule4.conditions.all()) == len(child_rule4.conditions.all()) == 1 + ) + new_condition5 = new_child_rule4.conditions.first() + + assert new_condition5.property == condition5.property + assert new_condition5.operator == condition5.operator + assert new_condition5.value == condition5.value + assert new_condition5.created_with_segment is condition5.created_with_segment + + +def test_manager_returns_only_highest_version_of_segments( + segment: Segment, +) -> None: + # Given + # The built-in segment fixture is pre-versioned already. + assert segment.version == 2 + assert segment.version_of == segment + + cloned_segment = segment.deep_clone() + assert cloned_segment.version == 2 + assert segment.version == 3 + + # When + queryset1 = Segment.objects.filter(id=cloned_segment.id) + queryset2 = Segment.all_objects.filter(id=cloned_segment.id) + queryset3 = Segment.objects.filter(id=segment.id) + queryset4 = Segment.all_objects.filter(id=segment.id) + + # Then + assert not queryset1.exists() + assert queryset2.first() == cloned_segment + assert queryset3.first() == segment + assert queryset4.first() == segment + + +def test_deep_clone_of_segment_with_improper_sub_rule( + project: Project, + feature: Feature, +) -> None: + # Given + segment = Segment.objects.create( + name="SpecialSegment", + description="A lovely, special segment.", + project=project, + feature=feature, + ) + + rule = SegmentRule.objects.create( + type=SegmentRule.ALL_RULE, + segment=segment, + ) + + # Rule with invalid relation to both segment and rule. + SegmentRule.objects.create(segment=segment, type=SegmentRule.ALL_RULE, rule=rule) + + with pytest.raises(AssertionError) as exception: + segment.deep_clone() + + assert ( + "AssertionError: Unexpected rule, expecting segment set not rule" + == exception.exconly() + ) + + +def test_deep_clone_of_segment_with_grandchild_rule( + project: Project, + feature: Feature, +) -> None: + # Given + segment = Segment.objects.create( + name="SpecialSegment", + description="A lovely, special segment.", + project=project, + feature=feature, + ) + + parent_rule = SegmentRule.objects.create(segment=segment, type=SegmentRule.ALL_RULE) + + child_rule = SegmentRule.objects.create(rule=parent_rule, type=SegmentRule.ANY_RULE) + + # Grandchild rule, which is invalid + SegmentRule.objects.create(rule=child_rule, type=SegmentRule.ANY_RULE) + + with pytest.raises(AssertionError) as exception: + segment.deep_clone() + + assert ( + "AssertionError: Expected two layers of rules, not more" == exception.exconly() + ) + + +def test_segment_rule_get_skip_create_audit_log_when_doesnt_skip( + segment: Segment, +) -> None: + # Given + assert segment == segment.version_of + segment_rule = SegmentRule.objects.create( + segment=segment, type=SegmentRule.ALL_RULE + ) + + # When + result = segment_rule.get_skip_create_audit_log() + + # Then + assert result is False + + +def test_segment_rule_get_skip_create_audit_log_when_skips(segment: Segment) -> None: + # Given + cloned_segment = segment.deep_clone() + assert cloned_segment != cloned_segment.version_of + + segment_rule = SegmentRule.objects.create( + segment=cloned_segment, type=SegmentRule.ALL_RULE + ) + + # When + result = segment_rule.get_skip_create_audit_log() + + # Then + assert result is True + + +def test_segment_get_skip_create_audit_log_when_exception( + mocker: MockerFixture, + segment: Segment, +) -> None: + # Given + patched_segment = mocker.patch.object( + Segment, "version_of_id", new_callable=PropertyMock + ) + patched_segment.side_effect = Segment.DoesNotExist("Segment missing") + + # When + result = segment.get_skip_create_audit_log() + + # Then + assert result is True diff --git a/api/tests/unit/segments/test_unit_segments_views.py b/api/tests/unit/segments/test_unit_segments_views.py index 95b53c012a2f..a5b67952ecc7 100644 --- a/api/tests/unit/segments/test_unit_segments_views.py +++ b/api/tests/unit/segments/test_unit_segments_views.py @@ -9,6 +9,7 @@ from pytest_django import DjangoAssertNumQueries from pytest_django.fixtures import SettingsWrapper from pytest_lazyfixture import lazy_fixture +from pytest_mock import MockerFixture from rest_framework import status from rest_framework.test import APIClient @@ -157,7 +158,7 @@ def test_create_segments_reaching_max_limit(project, client, settings): "client", [lazy_fixture("admin_master_api_key_client"), lazy_fixture("admin_client")], ) -def test_audit_log_created_when_segment_updated(project, segment, client): +def test_audit_log_created_when_segment_updated(project, client): # Given segment = Segment.objects.create(name="Test segment", project=project) url = reverse( @@ -171,10 +172,11 @@ def test_audit_log_created_when_segment_updated(project, segment, client): } # When - res = client.put(url, data=json.dumps(data), content_type="application/json") + response = client.put(url, data=json.dumps(data), content_type="application/json") # Then - assert res.status_code == status.HTTP_200_OK + assert response.status_code == status.HTTP_200_OK + assert ( AuditLog.objects.filter( related_object_type=RelatedObjectType.SEGMENT.name @@ -250,6 +252,7 @@ def test_audit_log_created_when_segment_created(project, client): # Then assert res.status_code == status.HTTP_201_CREATED + assert ( AuditLog.objects.filter( related_object_type=RelatedObjectType.SEGMENT.name @@ -482,11 +485,12 @@ def test_create_segments_with_description_condition(project, client): assert segment_condition_description_value == "test-description" -@pytest.mark.parametrize( - "client", - [lazy_fixture("admin_master_api_key_client"), lazy_fixture("admin_client")], -) -def test_update_segment_add_new_condition(project, client, segment, segment_rule): +def test_update_segment_add_new_condition( + project: Project, + admin_client_new: APIClient, + segment: Segment, + segment_rule: SegmentRule, +) -> None: # Given url = reverse( "api-v1:projects:project-segments-detail", args=[project.id, segment.id] @@ -535,7 +539,9 @@ def test_update_segment_add_new_condition(project, client, segment, segment_rule } # When - response = client.put(url, data=json.dumps(data), content_type="application/json") + response = admin_client_new.put( + url, data=json.dumps(data), content_type="application/json" + ) # Then assert response.status_code == status.HTTP_200_OK @@ -548,6 +554,167 @@ def test_update_segment_add_new_condition(project, client, segment, segment_rule assert nested_rule.conditions.order_by("-id").first().value == new_condition_value +def test_update_segment_versioned_segment( + project: Project, + admin_client_new: APIClient, + segment: Segment, + segment_rule: SegmentRule, +) -> None: + # Given + url = reverse( + "api-v1:projects:project-segments-detail", args=[project.id, segment.id] + ) + nested_rule = SegmentRule.objects.create( + rule=segment_rule, type=SegmentRule.ANY_RULE + ) + existing_condition = Condition.objects.create( + rule=nested_rule, property="foo", operator=EQUAL, value="bar" + ) + + # Before updating the segment confirm pre-existing version count which is + # automatically set by the fixture. + assert Segment.all_objects.filter(version_of=segment).count() == 2 + + new_condition_property = "foo2" + new_condition_value = "bar" + data = { + "name": segment.name, + "project": project.id, + "rules": [ + { + "id": segment_rule.id, + "type": segment_rule.type, + "rules": [ + { + "id": nested_rule.id, + "type": nested_rule.type, + "rules": [], + "conditions": [ + # existing condition + { + "id": existing_condition.id, + "property": existing_condition.property, + "operator": existing_condition.operator, + "value": existing_condition.value, + }, + # new condition + { + "property": new_condition_property, + "operator": EQUAL, + "value": new_condition_value, + }, + ], + } + ], + "conditions": [], + } + ], + } + + # When + response = admin_client_new.put( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_200_OK + + # Now verify that a new versioned segment has been set. + assert Segment.all_objects.filter(version_of=segment).count() == 3 + + # Now check the previously versioned segment to match former count of conditions. + + versioned_segment = Segment.all_objects.filter( + version_of=segment, version=2 + ).first() + assert versioned_segment != segment + assert versioned_segment.rules.count() == 1 + versioned_rule = versioned_segment.rules.first() + assert versioned_rule.rules.count() == 1 + + nested_versioned_rule = versioned_rule.rules.first() + assert nested_versioned_rule.conditions.count() == 1 + versioned_condition = nested_versioned_rule.conditions.first() + assert versioned_condition != existing_condition + assert versioned_condition.property == existing_condition.property + + +def test_update_segment_versioned_segment_with_thrown_exception( + project: Project, + admin_client_new: APIClient, + segment: Segment, + segment_rule: SegmentRule, + mocker: MockerFixture, +) -> None: + # Given + url = reverse( + "api-v1:projects:project-segments-detail", args=[project.id, segment.id] + ) + nested_rule = SegmentRule.objects.create( + rule=segment_rule, type=SegmentRule.ANY_RULE + ) + existing_condition = Condition.objects.create( + rule=nested_rule, property="foo", operator=EQUAL, value="bar" + ) + + assert ( + segment.version == 2 == Segment.all_objects.filter(version_of=segment).count() + ) + + new_condition_property = "foo2" + new_condition_value = "bar" + data = { + "name": segment.name, + "project": project.id, + "rules": [ + { + "id": segment_rule.id, + "type": segment_rule.type, + "rules": [ + { + "id": nested_rule.id, + "type": nested_rule.type, + "rules": [], + "conditions": [ + { + "id": existing_condition.id, + "property": existing_condition.property, + "operator": existing_condition.operator, + "value": existing_condition.value, + }, + { + "property": new_condition_property, + "operator": EQUAL, + "value": new_condition_value, + }, + ], + } + ], + "conditions": [], + } + ], + } + + update_super_patch = mocker.patch( + "rest_framework.serializers.ModelSerializer.update" + ) + update_super_patch.side_effect = Exception("Mocked exception") + + # When + with pytest.raises(Exception): + admin_client_new.put( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + segment.refresh_from_db() + + # Now verify that the version of the segment has not been changed. + assert ( + segment.version == 2 == Segment.all_objects.filter(version_of=segment).count() + ) + + @pytest.mark.parametrize( "client", [lazy_fixture("admin_master_api_key_client"), lazy_fixture("admin_client")],