Skip to content

Commit

Permalink
Add ability to restrict capacity on identity iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
khvn26 committed May 3, 2024
1 parent 92dd0fa commit 740a702
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
exclude: migrations

- repo: https://github.com/pycqa/flake8
rev: 6.0.0
rev: 7.0.0
hooks:
- id: flake8
name: flake8
Expand Down
11 changes: 11 additions & 0 deletions api/environments/dynamodb/wrappers/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from decimal import Decimal


class CapacityBudgetExceeded(Exception):
def __init__(
self,
capacity_budget: Decimal,
capacity_spent: Decimal,
) -> None:
self.capacity_budget = capacity_budget
self.capacity_spent = capacity_spent
22 changes: 18 additions & 4 deletions api/environments/dynamodb/wrappers/identity_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import typing
from contextlib import suppress
from decimal import Decimal
from typing import Iterable

from boto3.dynamodb.conditions import Key
Expand All @@ -12,6 +13,7 @@
from rest_framework.exceptions import NotFound

from environments.dynamodb.constants import IDENTITIES_PAGINATION_LIMIT
from environments.dynamodb.wrappers.exceptions import CapacityBudgetExceeded
from util.mappers import map_identity_to_identity_document

from .base import BaseDynamoWrapper
Expand Down Expand Up @@ -89,37 +91,49 @@ def get_all_items(
environment_api_key: str,
limit: int,
start_key: dict[str, "TableAttributeValueTypeDef"] | None = None,
projection_expression: str = None,
projection_expression: str | None = None,
return_consumed_capacity: bool = False,
) -> "QueryOutputTableTypeDef":
filter_expression = Key("environment_api_key").eq(environment_api_key)
query_kwargs: "QueryInputRequestTypeDef" = {
"IndexName": "environment_api_key-identifier-index",
"KeyConditionExpression": filter_expression,
"Limit": limit,
}
if projection_expression:
query_kwargs["ProjectionExpression"] = projection_expression

if start_key:
query_kwargs["ExclusiveStartKey"] = start_key
if projection_expression:
query_kwargs["ProjectionExpression"] = projection_expression
if return_consumed_capacity:
query_kwargs["ReturnConsumedCapacity"] = "TOTAL"
return self.query_items(**query_kwargs)

def iter_all_items_paginated(
self,
environment_api_key: str,
limit: int = IDENTITIES_PAGINATION_LIMIT,
projection_expression: str = None,
capacity_budget: Decimal = Decimal("Inf"),
) -> typing.Generator[dict, None, None]:
last_evaluated_key = "initial"
get_all_items_kwargs = {
"environment_api_key": environment_api_key,
"limit": limit,
"projection_expression": projection_expression,
"return_consumed_capacity": capacity_budget != Decimal("Inf"),
}
capacity_spent = 0
while last_evaluated_key:
if capacity_spent >= capacity_budget:
raise CapacityBudgetExceeded(
capacity_budget=capacity_budget,
capacity_spent=capacity_spent,
)
query_response = self.get_all_items(
**get_all_items_kwargs,
)
with suppress(KeyError):
capacity_spent += query_response["ConsumedCapacity"]["CapacityUnits"]
for item in query_response["Items"]:
yield item
if last_evaluated_key := query_response.get("LastEvaluatedKey"):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typing
from decimal import Decimal

import pytest
from boto3.dynamodb.conditions import Key
Expand All @@ -11,6 +12,7 @@
from rest_framework.exceptions import NotFound

from environments.dynamodb import DynamoIdentityWrapper
from environments.dynamodb.wrappers.exceptions import CapacityBudgetExceeded
from environments.identities.models import Identity
from environments.identities.traits.models import Trait
from segments.models import Condition, Segment, SegmentRule
Expand Down Expand Up @@ -153,6 +155,31 @@ def test_get_all_items_with_start_key_calls_query_with_correct_arguments(mocker)
)


def test_get_all_items__return_consumed_capacity_true__calls_expected(
mocker: MockerFixture,
) -> None:
# Given
dynamo_identity_wrapper = DynamoIdentityWrapper()

environment_key = "environment_key"
mocked_dynamo_table = mocker.patch.object(dynamo_identity_wrapper, "_table")

# When
dynamo_identity_wrapper.get_all_items(
environment_api_key=environment_key,
limit=999,
return_consumed_capacity=True,
)

# Then
mocked_dynamo_table.query.assert_called_with(
IndexName="environment_api_key-identifier-index",
Limit=999,
KeyConditionExpression=Key("environment_api_key").eq(environment_key),
ReturnConsumedCapacity="TOTAL",
)


def test_search_items_with_identifier_calls_query_with_correct_arguments(mocker):
dynamo_identity_wrapper = DynamoIdentityWrapper()
environment_key = "environment_key"
Expand Down Expand Up @@ -398,7 +425,6 @@ def test_get_segment_ids_with_identity_model(identity, environment, mocker):


def test_identity_wrapper__iter_all_items_paginated__returns_expected(
environment: "Environment",
identity: "Identity",
mocker: "MockerFixture",
) -> None:
Expand All @@ -410,13 +436,6 @@ def test_identity_wrapper__iter_all_items_paginated__returns_expected(

expected_next_page_key = "next_page_key"

environment_document = map_environment_to_environment_document(environment)
mocked_environment_wrapper = mocker.patch(
"environments.dynamodb.wrappers.environment_wrapper.DynamoEnvironmentWrapper",
autospec=True,
)
mocked_environment_wrapper.return_value.get_item.return_value = environment_document

mocked_get_all_items = mocker.patch.object(
dynamo_identity_wrapper,
"get_all_items",
Expand Down Expand Up @@ -445,13 +464,89 @@ def test_identity_wrapper__iter_all_items_paginated__returns_expected(
[
mocker.call(
environment_api_key=environment_api_key,
limit=limit,
projection_expression=None,
return_consumed_capacity=False,
),
mocker.call(
environment_api_key=environment_api_key,
limit=limit,
projection_expression=None,
return_consumed_capacity=False,
start_key=expected_next_page_key,
),
]
)


@pytest.mark.parametrize("capacity_budget", [Decimal("2.0"), Decimal("2.2")])
def test_identity_wrapper__iter_all_items_paginated__capacity_budget_set__raises_expected(
identity: "Identity",
mocker: "MockerFixture",
capacity_budget: Decimal,
) -> None:
# Given
dynamo_identity_wrapper = DynamoIdentityWrapper()
identity_document = map_identity_to_identity_document(identity)
environment_api_key = "test_api_key"
limit = 1

expected_next_page_key = "next_page_key"

mocked_get_all_items = mocker.patch.object(
dynamo_identity_wrapper,
"get_all_items",
autospec=True,
)
mocked_get_all_items.side_effect = [
{
"Items": [identity_document],
"LastEvaluatedKey": "next_page_key",
"ConsumedCapacity": {"CapacityUnits": Decimal("1.1")},
},
{
"Items": [identity_document],
"LastEvaluatedKey": "next_after_next_page_key",
"ConsumedCapacity": {"CapacityUnits": Decimal("1.1")},
},
{
"Items": [identity_document],
"LastEvaluatedKey": None,
"ConsumedCapacity": {"CapacityUnits": Decimal("1.1")},
},
]

# When
iterator = dynamo_identity_wrapper.iter_all_items_paginated(
environment_api_key=environment_api_key,
limit=limit,
capacity_budget=capacity_budget,
)
result_1 = next(iterator)
result_2 = next(iterator)

# Then
with pytest.raises(CapacityBudgetExceeded) as exc_info:
next(iterator)

assert result_1 == identity_document
assert result_2 == identity_document
assert exc_info.value.capacity_budget == capacity_budget
assert exc_info.value.capacity_spent == Decimal("2.2")

mocked_get_all_items.assert_has_calls(
[
mocker.call(
environment_api_key=environment_api_key,
limit=limit,
projection_expression=None,
return_consumed_capacity=True,
),
mocker.call(
environment_api_key=environment_api_key,
limit=limit,
projection_expression=None,
return_consumed_capacity=True,
start_key=expected_next_page_key,
),
]
Expand Down

0 comments on commit 740a702

Please sign in to comment.