diff --git a/api/conftest.py b/api/conftest.py index 334794d88339..d5249fad1554 100644 --- a/api/conftest.py +++ b/api/conftest.py @@ -59,7 +59,13 @@ CREATE_PROJECT, MANAGE_USER_GROUPS, ) -from organisations.subscriptions.constants import CHARGEBEE, XERO +from organisations.subscriptions.constants import ( + CHARGEBEE, + FREE_PLAN_ID, + SCALE_UP, + STARTUP, + XERO, +) from permissions.models import PermissionModel from projects.models import ( Project, @@ -324,6 +330,33 @@ def enterprise_subscription(organisation: Organisation) -> Subscription: return organisation.subscription +@pytest.fixture() +def startup_subscription(organisation: Organisation) -> Subscription: + Subscription.objects.filter(organisation=organisation).update( + plan=STARTUP, subscription_id="subscription-id" + ) + organisation.refresh_from_db() + return organisation.subscription + + +@pytest.fixture() +def scale_up_subscription(organisation: Organisation) -> Subscription: + Subscription.objects.filter(organisation=organisation).update( + plan=SCALE_UP, subscription_id="subscription-id" + ) + organisation.refresh_from_db() + return organisation.subscription + + +@pytest.fixture() +def free_subscription(organisation: Organisation) -> Subscription: + Subscription.objects.filter(organisation=organisation).update( + plan=FREE_PLAN_ID, subscription_id="subscription-id" + ) + organisation.refresh_from_db() + return organisation.subscription + + @pytest.fixture() def project(organisation): return Project.objects.create(name="Test Project", organisation=organisation) diff --git a/api/projects/serializers.py b/api/projects/serializers.py index a331a2fa73e5..14f8af3dfa84 100644 --- a/api/projects/serializers.py +++ b/api/projects/serializers.py @@ -69,7 +69,7 @@ class ProjectUpdateOrCreateSerializer( ReadOnlyIfNotValidPlanMixin, ProjectListSerializer ): invalid_plans_regex = r"^(free|startup.*|scale-up.*)$" - field_names = ("stale_flags_limit_days",) + field_names = ("stale_flags_limit_days", "enable_realtime_updates") def get_subscription(self) -> typing.Optional[Subscription]: view = self.context["view"] diff --git a/api/tests/unit/projects/test_unit_projects_views.py b/api/tests/unit/projects/test_unit_projects_views.py index 49a18985201f..7f949f3c6fc5 100644 --- a/api/tests/unit/projects/test_unit_projects_views.py +++ b/api/tests/unit/projects/test_unit_projects_views.py @@ -687,11 +687,20 @@ def test_get_project_by_uuid(client, project, mocker, settings, organisation): @pytest.mark.parametrize( - "client", - [(lazy_fixture("admin_master_api_key_client")), (lazy_fixture("admin_client"))], + "subscription, can_update_realtime", + [ + (lazy_fixture("free_subscription"), False), + (lazy_fixture("startup_subscription"), False), + (lazy_fixture("scale_up_subscription"), False), + (lazy_fixture("enterprise_subscription"), True), + ], ) -def test_can_enable_realtime_updates_for_project( - client, project, mocker, settings, organisation +def test_can_enable_realtime_updates_for_enterprise( + admin_client: APIClient, + project: Project, + organisation: Organisation, + subscription: Subscription, + can_update_realtime: bool, ): # Given url = reverse("api-v1:projects:project-detail", args=[project.id]) @@ -703,12 +712,12 @@ def test_can_enable_realtime_updates_for_project( } # When - response = client.put(url, data=data) + response = admin_client.put(url, data=data) # Then assert response.status_code == status.HTTP_200_OK assert response.json()["uuid"] == str(project.uuid) - assert response.json()["enable_realtime_updates"] is True + assert response.json()["enable_realtime_updates"] is can_update_realtime @pytest.mark.parametrize(