diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d086fc3c2671d..bce63d905111d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: exclude: migrations - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 + rev: 7.0.0 hooks: - id: flake8 name: flake8 diff --git a/api/environments/dynamodb/wrappers/exceptions.py b/api/environments/dynamodb/wrappers/exceptions.py new file mode 100644 index 0000000000000..e9e70f03c01be --- /dev/null +++ b/api/environments/dynamodb/wrappers/exceptions.py @@ -0,0 +1,11 @@ +from decimal import Decimal + + +class CapacityBudgetExceeded(Exception): + def __init__( + self, + capacity_budget: Decimal, + capacity_spent: Decimal, + ) -> None: + self.capacity_budget = capacity_budget + self.capacity_spent = capacity_spent diff --git a/api/environments/dynamodb/wrappers/identity_wrapper.py b/api/environments/dynamodb/wrappers/identity_wrapper.py index 9bf91c6d9516b..7e60ea1714868 100644 --- a/api/environments/dynamodb/wrappers/identity_wrapper.py +++ b/api/environments/dynamodb/wrappers/identity_wrapper.py @@ -1,6 +1,7 @@ import logging import typing from contextlib import suppress +from decimal import Decimal from typing import Iterable from boto3.dynamodb.conditions import Key @@ -12,6 +13,7 @@ from rest_framework.exceptions import NotFound from environments.dynamodb.constants import IDENTITIES_PAGINATION_LIMIT +from environments.dynamodb.wrappers.exceptions import CapacityBudgetExceeded from util.mappers import map_identity_to_identity_document from .base import BaseDynamoWrapper @@ -89,7 +91,8 @@ def get_all_items( environment_api_key: str, limit: int, start_key: dict[str, "TableAttributeValueTypeDef"] | None = None, - projection_expression: str = None, + projection_expression: str | None = None, + return_consumed_capacity: bool = False, ) -> "QueryOutputTableTypeDef": filter_expression = Key("environment_api_key").eq(environment_api_key) query_kwargs: "QueryInputRequestTypeDef" = { @@ -97,11 +100,12 @@ def get_all_items( "KeyConditionExpression": filter_expression, "Limit": limit, } - if projection_expression: - query_kwargs["ProjectionExpression"] = projection_expression - if start_key: query_kwargs["ExclusiveStartKey"] = start_key + if projection_expression: + query_kwargs["ProjectionExpression"] = projection_expression + if return_consumed_capacity: + query_kwargs["ReturnConsumedCapacity"] = "TOTAL" return self.query_items(**query_kwargs) def iter_all_items_paginated( @@ -109,17 +113,27 @@ def iter_all_items_paginated( environment_api_key: str, limit: int = IDENTITIES_PAGINATION_LIMIT, projection_expression: str = None, + capacity_budget: Decimal = Decimal("Inf"), ) -> typing.Generator[dict, None, None]: last_evaluated_key = "initial" get_all_items_kwargs = { "environment_api_key": environment_api_key, "limit": limit, "projection_expression": projection_expression, + "return_consumed_capacity": capacity_budget != Decimal("Inf"), } + capacity_spent = 0 while last_evaluated_key: + if capacity_spent >= capacity_budget: + raise CapacityBudgetExceeded( + capacity_budget=capacity_budget, + capacity_spent=capacity_spent, + ) query_response = self.get_all_items( **get_all_items_kwargs, ) + with suppress(KeyError): + capacity_spent += query_response["ConsumedCapacity"]["CapacityUnits"] for item in query_response["Items"]: yield item if last_evaluated_key := query_response.get("LastEvaluatedKey"): diff --git a/api/tests/unit/environments/dynamodb/wrappers/test_unit_dynamodb_identity_wrapper.py b/api/tests/unit/environments/dynamodb/wrappers/test_unit_dynamodb_identity_wrapper.py index 96ed023760dc5..3beee0c1c671f 100644 --- a/api/tests/unit/environments/dynamodb/wrappers/test_unit_dynamodb_identity_wrapper.py +++ b/api/tests/unit/environments/dynamodb/wrappers/test_unit_dynamodb_identity_wrapper.py @@ -1,4 +1,5 @@ import typing +from decimal import Decimal import pytest from boto3.dynamodb.conditions import Key @@ -11,6 +12,7 @@ from rest_framework.exceptions import NotFound from environments.dynamodb import DynamoIdentityWrapper +from environments.dynamodb.wrappers.exceptions import CapacityBudgetExceeded from environments.identities.models import Identity from environments.identities.traits.models import Trait from segments.models import Condition, Segment, SegmentRule @@ -153,6 +155,31 @@ def test_get_all_items_with_start_key_calls_query_with_correct_arguments(mocker) ) +def test_get_all_items__return_consumed_capacity_true__calls_expected( + mocker: MockerFixture, +) -> None: + # Given + dynamo_identity_wrapper = DynamoIdentityWrapper() + + environment_key = "environment_key" + mocked_dynamo_table = mocker.patch.object(dynamo_identity_wrapper, "_table") + + # When + dynamo_identity_wrapper.get_all_items( + environment_api_key=environment_key, + limit=999, + return_consumed_capacity=True, + ) + + # Then + mocked_dynamo_table.query.assert_called_with( + IndexName="environment_api_key-identifier-index", + Limit=999, + KeyConditionExpression=Key("environment_api_key").eq(environment_key), + ReturnConsumedCapacity="TOTAL", + ) + + def test_search_items_with_identifier_calls_query_with_correct_arguments(mocker): dynamo_identity_wrapper = DynamoIdentityWrapper() environment_key = "environment_key" @@ -398,7 +425,6 @@ def test_get_segment_ids_with_identity_model(identity, environment, mocker): def test_identity_wrapper__iter_all_items_paginated__returns_expected( - environment: "Environment", identity: "Identity", mocker: "MockerFixture", ) -> None: @@ -410,13 +436,6 @@ def test_identity_wrapper__iter_all_items_paginated__returns_expected( expected_next_page_key = "next_page_key" - environment_document = map_environment_to_environment_document(environment) - mocked_environment_wrapper = mocker.patch( - "environments.dynamodb.wrappers.environment_wrapper.DynamoEnvironmentWrapper", - autospec=True, - ) - mocked_environment_wrapper.return_value.get_item.return_value = environment_document - mocked_get_all_items = mocker.patch.object( dynamo_identity_wrapper, "get_all_items", @@ -445,13 +464,89 @@ def test_identity_wrapper__iter_all_items_paginated__returns_expected( [ mocker.call( environment_api_key=environment_api_key, + limit=limit, + projection_expression=None, + return_consumed_capacity=False, + ), + mocker.call( + environment_api_key=environment_api_key, + limit=limit, projection_expression=None, + return_consumed_capacity=False, + start_key=expected_next_page_key, + ), + ] + ) + + +@pytest.mark.parametrize("capacity_budget", [Decimal("2.0"), Decimal("2.2")]) +def test_identity_wrapper__iter_all_items_paginated__capacity_budget_set__raises_expected( + identity: "Identity", + mocker: "MockerFixture", + capacity_budget: Decimal, +) -> None: + # Given + dynamo_identity_wrapper = DynamoIdentityWrapper() + identity_document = map_identity_to_identity_document(identity) + environment_api_key = "test_api_key" + limit = 1 + + expected_next_page_key = "next_page_key" + + mocked_get_all_items = mocker.patch.object( + dynamo_identity_wrapper, + "get_all_items", + autospec=True, + ) + mocked_get_all_items.side_effect = [ + { + "Items": [identity_document], + "LastEvaluatedKey": "next_page_key", + "ConsumedCapacity": {"CapacityUnits": Decimal("1.1")}, + }, + { + "Items": [identity_document], + "LastEvaluatedKey": "next_after_next_page_key", + "ConsumedCapacity": {"CapacityUnits": Decimal("1.1")}, + }, + { + "Items": [identity_document], + "LastEvaluatedKey": None, + "ConsumedCapacity": {"CapacityUnits": Decimal("1.1")}, + }, + ] + + # When + iterator = dynamo_identity_wrapper.iter_all_items_paginated( + environment_api_key=environment_api_key, + limit=limit, + capacity_budget=capacity_budget, + ) + result_1 = next(iterator) + result_2 = next(iterator) + + # Then + with pytest.raises(CapacityBudgetExceeded) as exc_info: + next(iterator) + + assert result_1 == identity_document + assert result_2 == identity_document + assert exc_info.value.capacity_budget == capacity_budget + assert exc_info.value.capacity_spent == Decimal("2.2") + + mocked_get_all_items.assert_has_calls( + [ + mocker.call( + environment_api_key=environment_api_key, limit=limit, + projection_expression=None, + return_consumed_capacity=True, ), mocker.call( environment_api_key=environment_api_key, limit=limit, projection_expression=None, + return_consumed_capacity=True, start_key=expected_next_page_key, ), ]