diff --git a/api/projects/serializers.py b/api/projects/serializers.py index 447b56879fea..54e13fc8656f 100644 --- a/api/projects/serializers.py +++ b/api/projects/serializers.py @@ -49,11 +49,6 @@ class Meta: "edge_v2_migration_status", ) - def update(self, instance: Project, validated_data: dict) -> Project: - # Prevent updates to `organisation` field - validated_data.pop("organisation", None) - return super().update(instance, validated_data) - def get_migration_status(self, obj: Project) -> str: if not settings.PROJECT_METADATA_TABLE_NAME_DYNAMO: migration_status = ProjectIdentityMigrationStatus.NOT_APPLICABLE.value @@ -74,9 +69,7 @@ def get_use_edge_identities(self, obj: Project) -> bool: ) -class ProjectUpdateOrCreateSerializer( - ReadOnlyIfNotValidPlanMixin, ProjectListSerializer -): +class ProjectCreateSerializer(ReadOnlyIfNotValidPlanMixin, ProjectListSerializer): invalid_plans_regex = r"^(free|startup.*|scale-up.*)$" field_names = ("stale_flags_limit_days", "enable_realtime_updates") @@ -98,6 +91,13 @@ def get_subscription(self) -> typing.Optional[Subscription]: return None +class ProjectUpdateSerializer(ProjectCreateSerializer): + class Meta(ProjectCreateSerializer.Meta): + read_only_fields = ProjectCreateSerializer.Meta.read_only_fields + ( + "organisation", + ) + + class ProjectRetrieveSerializer(ProjectListSerializer): total_features = serializers.SerializerMethodField() total_segments = serializers.SerializerMethodField() diff --git a/api/projects/views.py b/api/projects/views.py index 42929321015e..27c720f6a138 100644 --- a/api/projects/views.py +++ b/api/projects/views.py @@ -44,9 +44,10 @@ CreateUpdateUserProjectPermissionSerializer, ListUserPermissionGroupProjectPermissionSerializer, ListUserProjectPermissionSerializer, + ProjectCreateSerializer, ProjectListSerializer, ProjectRetrieveSerializer, - ProjectUpdateOrCreateSerializer, + ProjectUpdateSerializer, ) @@ -75,11 +76,13 @@ class ProjectViewSet(viewsets.ModelViewSet): permission_classes = [ProjectPermissions] def get_serializer_class(self): - if self.action == "retrieve": - return ProjectRetrieveSerializer - elif self.action in ("create", "update", "partial_update"): - return ProjectUpdateOrCreateSerializer - return ProjectListSerializer + serializers = { + "retrieve": ProjectRetrieveSerializer, + "create": ProjectCreateSerializer, + "update": ProjectUpdateSerializer, + "partial_update": ProjectUpdateSerializer, + } + return serializers.get(self.action, ProjectListSerializer) pagination_class = None diff --git a/api/tests/integration/conftest.py b/api/tests/integration/conftest.py index 7afc966ccdb8..0a0cf6495c04 100644 --- a/api/tests/integration/conftest.py +++ b/api/tests/integration/conftest.py @@ -56,11 +56,13 @@ def organisation_with_persist_trait_data_disabled(organisation): @pytest.fixture() -def dynamo_enabled_project(admin_client, organisation): +def dynamo_enabled_project( + admin_client: APIClient, organisation: Organisation, settings: SettingsWrapper +): + settings.EDGE_ENABLED = True project_data = { "name": "Test Project", "organisation": organisation, - "enable_dynamo_db": True, } url = reverse("api-v1:projects:project-list") response = admin_client.post(url, data=project_data)