Skip to content

Commit

Permalink
feat: Create versioning for segments (#4138)
Browse files Browse the repository at this point in the history
Co-authored-by: Matthew Elwell <matthew.elwell@flagsmith.com>
  • Loading branch information
zachaysan and matthewelwell authored Jun 27, 2024
1 parent 5162687 commit bc9b340
Show file tree
Hide file tree
Showing 10 changed files with 721 additions and 19 deletions.
9 changes: 7 additions & 2 deletions api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,13 @@ def project(organisation):


@pytest.fixture()
def segment(project):
return Segment.objects.create(name="segment", project=project)
def segment(project: Project):
_segment = Segment.objects.create(name="segment", project=project)
# Deep clone the segment to ensure that any bugs around
# versioning get bubbled up through the test suite.
_segment.deep_clone()

return _segment


@pytest.fixture()
Expand Down
18 changes: 17 additions & 1 deletion api/core/signals.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import logging

from core.models import AbstractBaseAuditableModel
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from django.utils import timezone
from simple_history.models import HistoricalRecords

from audit import tasks
from task_processor.task_run_method import TaskRunMethod
from users.models import FFAdminUser

logger = logging.getLogger(__name__)


def create_audit_log_from_historical_record(
instance: AbstractBaseAuditableModel,
Expand All @@ -30,7 +35,18 @@ def create_audit_log_from_historical_record(
else None
)

environment, project = instance.get_environment_and_project()
try:
environment, project = instance.get_environment_and_project()
except ObjectDoesNotExist:
logger.warning(
"Unable to create audit log for %s %s. "
"Parent model does not exist - this likely means it was hard deleted.",
instance.related_object_type,
getattr(instance, "id", "uuid"),
exc_info=True,
)
return

if project != history_instance.instance and (
(project and project.deleted_at)
or (environment and environment.project.deleted_at)
Expand Down
16 changes: 16 additions & 0 deletions api/segments/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
class SegmentAuditLogHelper:
def __init__(self) -> None:
self.skip_audit_log = {}

def should_skip_audit_log(self, segment_id: int) -> None | bool:
return self.skip_audit_log.get(segment_id)

def set_skip_audit_log(self, segment_id: int) -> None:
self.skip_audit_log[segment_id] = True

def unset_skip_audit_log(self, segment_id: int) -> None:
if segment_id in self.skip_audit_log:
del self.skip_audit_log[segment_id]


segment_audit_log_helper = SegmentAuditLogHelper()
11 changes: 11 additions & 0 deletions api/segments/managers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from core.models import SoftDeleteExportableManager
from django.db.models import F


class SegmentManager(SoftDeleteExportableManager):
def get_queryset(self):
"""
Returns only the canonical segments, which will always be
the highest version.
"""
return super().get_queryset().filter(id=F("version_of"))
58 changes: 58 additions & 0 deletions api/segments/migrations/0023_add_versioning_to_segments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Generated by Django 3.2.25 on 2024-06-10 15:31
import os

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("segments", "0022_add_soft_delete_to_segment_rules_and_conditions"),
]

operations = [
migrations.AddField(
model_name="historicalsegment",
name="version",
field=models.IntegerField(null=True),
),
migrations.AddField(
model_name="historicalsegment",
name="version_of",
field=models.ForeignKey(
blank=True,
db_constraint=False,
null=True,
on_delete=django.db.models.deletion.DO_NOTHING,
related_name="+",
to="segments.segment",
),
),
migrations.AddField(
model_name="segment",
name="version",
field=models.IntegerField(null=True),
),
migrations.AddField(
model_name="segment",
name="version_of",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="versioned_segments",
to="segments.segment",
),
),
migrations.RunSQL(
sql=open(
os.path.join(
os.path.dirname(__file__),
"sql",
"0023_add_versioning_to_segments.sql",
)
).read(),
reverse_sql=migrations.RunSQL.noop,
),
]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
UPDATE segments_segment SET version_of_id = id;
131 changes: 130 additions & 1 deletion api/segments/models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
import logging
import typing
import uuid
from copy import deepcopy

from core.models import (
SoftDeleteExportableManager,
SoftDeleteExportableModel,
abstract_base_auditable_model_factory,
)
from django.conf import settings
from django.contrib.contenttypes.fields import GenericRelation
from django.core.exceptions import ValidationError
from django.db import models
from django_lifecycle import (
AFTER_CREATE,
BEFORE_CREATE,
LifecycleModelMixin,
hook,
)
from flag_engine.segments import constants

from audit.constants import (
Expand All @@ -22,10 +30,14 @@
from metadata.models import Metadata
from projects.models import Project

from .helpers import segment_audit_log_helper
from .managers import SegmentManager

logger = logging.getLogger(__name__)


class Segment(
LifecycleModelMixin,
SoftDeleteExportableModel,
abstract_base_auditable_model_factory(["uuid"]),
):
Expand All @@ -45,14 +57,44 @@ class Segment(
Feature, on_delete=models.CASCADE, related_name="segments", null=True
)

# This defaults to 1 for newly created segments.
version = models.IntegerField(null=True)

# The related_name is not useful without specifying all_objects as a manager.
version_of = models.ForeignKey(
"self",
on_delete=models.CASCADE,
related_name="versioned_segments",
null=True,
blank=True,
)
metadata = GenericRelation(Metadata)

# Only serves segments that are the canonical version.
objects = SegmentManager()

# Includes versioned segments.
all_objects = SoftDeleteExportableManager()

class Meta:
ordering = ("id",) # explicit ordering to prevent pagination warnings

def __str__(self):
return "Segment - %s" % self.name

def get_skip_create_audit_log(self) -> bool:
skip = segment_audit_log_helper.should_skip_audit_log(self.id)
if skip is not None:
return skip

try:
if self.version_of_id and self.version_of_id != self.id:
return True
except Segment.DoesNotExist:
return True

return False

@staticmethod
def id_exists_in_rules_data(rules_data: typing.List[dict]) -> bool:
"""
Expand Down Expand Up @@ -84,6 +126,49 @@ def id_exists_in_rules_data(rules_data: typing.List[dict]) -> bool:

return False

@hook(BEFORE_CREATE, when="version_of", is_now=None)
def set_default_version_to_one_if_new_segment(self):
if self.version is None:
self.version = 1

@hook(AFTER_CREATE, when="version_of", is_now=None)
def set_version_of_to_self_if_none(self):
"""
This allows the segment model to reference all versions of
itself including itself.
"""
segment_audit_log_helper.set_skip_audit_log(self.id)
self.version_of = self
self.save()
segment_audit_log_helper.unset_skip_audit_log(self.id)

def deep_clone(self) -> "Segment":
cloned_segment = deepcopy(self)
cloned_segment.id = None
cloned_segment.uuid = uuid.uuid4()
cloned_segment.version_of = self
cloned_segment.save()

segment_audit_log_helper.set_skip_audit_log(self.id)
self.version += 1
self.save()
segment_audit_log_helper.unset_skip_audit_log(self.id)

cloned_rules = []
for rule in self.rules.all():
cloned_rule = rule.deep_clone(cloned_segment)
cloned_rules.append(cloned_rule)

cloned_segment.refresh_from_db()

assert (
len(self.rules.all())
== len(cloned_rules)
== len(cloned_segment.rules.all())
), "Mismatch during rules creation"

return cloned_segment

def get_create_log_message(self, history_instance) -> typing.Optional[str]:
return SEGMENT_CREATED_MESSAGE % self.name

Expand Down Expand Up @@ -128,6 +213,10 @@ def __str__(self):
str(self.segment) if self.segment else str(self.rule),
)

def get_skip_create_audit_log(self) -> bool:
segment = self.get_segment()
return segment.version_of_id != segment.id

def get_segment(self):
"""
rules can be a child of a parent rule instead of a segment, this method iterates back up the tree to find the
Expand All @@ -136,10 +225,46 @@ def get_segment(self):
TODO: denormalise the segment information so that we don't have to make multiple queries here in complex cases
"""
rule = self
while not rule.segment:
while not rule.segment_id:
rule = rule.rule
return rule.segment

def deep_clone(self, cloned_segment: Segment) -> "SegmentRule":
if self.rule:
# Since we're expecting a rule that is only belonging to a
# segment, since a rule either belongs to a segment xor belongs
# to a rule, we don't expect there also to be a rule associated.
assert False, "Unexpected rule, expecting segment set not rule"
cloned_rule = deepcopy(self)
cloned_rule.segment = cloned_segment
cloned_rule.uuid = uuid.uuid4()
cloned_rule.id = None
cloned_rule.save()

# Conditions are only part of the sub-rules.
assert self.conditions.exists() is False

for sub_rule in self.rules.all():
if sub_rule.rules.exists():
assert False, "Expected two layers of rules, not more"

cloned_sub_rule = deepcopy(sub_rule)
cloned_sub_rule.rule = cloned_rule
cloned_sub_rule.uuid = uuid.uuid4()
cloned_sub_rule.id = None
cloned_sub_rule.save()

cloned_conditions = []
for condition in sub_rule.conditions.all():
cloned_condition = deepcopy(condition)
cloned_condition.rule = cloned_sub_rule
cloned_condition.uuid = uuid.uuid4()
cloned_condition.id = None
cloned_conditions.append(cloned_condition)
Condition.objects.bulk_create(cloned_conditions)

return cloned_rule


class Condition(
SoftDeleteExportableModel, abstract_base_auditable_model_factory(["uuid"])
Expand Down Expand Up @@ -188,6 +313,10 @@ def __str__(self):
self.value,
)

def get_skip_create_audit_log(self) -> bool:
segment = self.rule.get_segment()
return segment.version_of_id != segment.id

def get_update_log_message(self, history_instance) -> typing.Optional[str]:
return f"Condition updated on segment '{self._get_segment().name}'."

Expand Down
27 changes: 21 additions & 6 deletions api/segments/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,31 @@ def create(self, validated_data):
self._update_or_create_metadata(metadata_data, segment=segment)
return segment

def update(self, instance, validated_data):
def update(self, instance: Segment, validated_data: dict[str, typing.Any]) -> None:
# use the initial data since we need the ids included to determine which to update & which to create
rules_data = self.initial_data.pop("rules", [])
metadata_data = validated_data.pop("metadata", [])
self.validate_segment_rules_conditions_limit(rules_data)
self._update_segment_rules(rules_data, segment=instance)
self._update_or_create_metadata(metadata_data, segment=instance)
# remove rules from validated data to prevent error trying to create segment with nested rules
del validated_data["rules"]
return super().update(instance, validated_data)

# Create a version of the segment now that we're updating.
cloned_segment = instance.deep_clone()

try:
self._update_segment_rules(rules_data, segment=instance)
self._update_or_create_metadata(metadata_data, segment=instance)

# remove rules from validated data to prevent error trying to create segment with nested rules
del validated_data["rules"]
response = super().update(instance, validated_data)
except Exception:
# Since there was a problem during the update we now delete the cloned segment,
# since we no longer need a versioned segment.
instance.refresh_from_db()
instance.version = cloned_segment.version
instance.save()
cloned_segment.hard_delete()
raise
return response

def validate_project_segment_limit(self, project: Project) -> None:
if project.segments.count() >= project.max_segments_allowed:
Expand Down
Loading

0 comments on commit bc9b340

Please sign in to comment.