Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DynamoDB: scan() now supports parallelization using the Segment/TotalSegments parameters #8303

Merged
merged 1 commit into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions moto/dynamodb/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def scan(
index_name: str,
consistent_read: bool,
projection_expression: Optional[List[List[str]]],
segments: Union[Tuple[None, None], Tuple[int, int]],
) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]:
table = self.get_table(table_name)

Expand All @@ -421,6 +422,7 @@ def scan(
index_name,
consistent_read,
projection_expression,
segments=segments,
)

def update_item(
Expand Down
31 changes: 30 additions & 1 deletion moto/dynamodb/models/dynamo_type.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import copy
from decimal import Decimal
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from botocore.utils import merge_dicts
Expand All @@ -12,6 +12,7 @@
IncorrectDataType,
ItemSizeTooLarge,
)
from moto.utilities.utils import md5_hash

from .utilities import bytesize, find_nested_key

Expand Down Expand Up @@ -455,3 +456,31 @@ def project(self, projection_expressions: List[List[str]]) -> "Item":
# We need to convert that into DynamoDB dictionary ({'M': {'key': {'S': 'value'}}})
attrs=serializer.serialize(result)["M"],
)

def is_within_segment(
    self, segments: Union[Tuple[None, None], Tuple[int, int]]
) -> bool:
    """
    Report whether this item belongs to the requested (parallel) scan segment.

    Segments can be either (x, y) or (None, None)
    None, None => the user requested the entire table, so the item always falls within that
    x, y => the user requested segment x out of y

    Segment membership is computed based on the value of the hash key

    :param segments: ``(Segment, TotalSegments)`` as supplied by the caller, or
        ``(None, None)`` when no segmentation was requested. ``TotalSegments``
        must be non-zero, as it is used as the modulus below.
    :return: True if the item falls within the requested segment (or if no
        segment was requested), False otherwise.
    """
    if segments == (None, None):
        return True

    segment, total_segments = segments
    # Creates a reproducible hash number for this item (between 0 and 255, inclusive)
    # Note that we can't use the builtin hash() method, as that is not deterministic between executions
    #
    # Using a hash based on the hash key ensures parity with how AWS seems to behave:
    #  - Items are not divided equally between segment
    #  - Items always fall in the same segment, regardless of how often you call `scan()`
    #  - Items with the same hash key but different range keys always fall in the same segment
    #  - Items with different hash keys may be part of different segments
    #
    # Only the first byte of the digest is used, so at most 256 distinct segments
    # can ever contain items; any segment >= 256 is simply always empty.
    hash_key_value = self.hash_key.value
    if not isinstance(hash_key_value, bytes):
        # NOTE(review): presumably only binary hash keys can already hold raw bytes;
        # every other value is encoded exactly as before - confirm against DynamoType
        hash_key_value = hash_key_value.encode("utf8")
    item_hash = md5_hash(hash_key_value).digest()[0]
    # Modulo ensures that we always get a number between 0 and (total_segments - 1)
    item_segment = item_hash % total_segments
    return segment == item_segment
6 changes: 5 additions & 1 deletion moto/dynamodb/models/table.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from moto.core.common_models import BaseModel, CloudFormationModel
from moto.core.utils import unix_time, unix_time_millis, utcnow
Expand Down Expand Up @@ -897,6 +897,7 @@ def scan(
index_name: Optional[str] = None,
consistent_read: bool = False,
projection_expression: Optional[List[List[str]]] = None,
segments: Union[Tuple[None, None], Tuple[int, int]] = (None, None),
) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]:
results: List[Item] = []
result_size = 0
Expand Down Expand Up @@ -942,6 +943,9 @@ def scan(
last_evaluated_key = None
processing_previous_page = exclusive_start_key is not None
for item in items:
if not item.is_within_segment(segments):
continue

# Cycle through the previous page of results
# When we encounter our start key, we know we've reached the end of the previous page
if processing_previous_page:
Expand Down
31 changes: 25 additions & 6 deletions moto/dynamodb/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,24 @@ def scan(self) -> str:
limit = self.body.get("Limit")
index_name = self.body.get("IndexName")
consistent_read = self.body.get("ConsistentRead", False)
segment = self.body.get("Segment")
total_segments = self.body.get("TotalSegments")
if segment is not None and total_segments is None:
raise MockValidationException(
"The TotalSegments parameter is required but was not present in the request when Segment parameter is present"
)
if total_segments is not None and segment is None:
raise MockValidationException(
"The Segment parameter is required but was not present in the request when parameter TotalSegments is present"
)
if (
segment is not None
and total_segments is not None
and segment >= total_segments
):
raise MockValidationException(
f"The Segment parameter is zero-based and must be less than parameter TotalSegments: Segment: {segment} is not less than TotalSegments: {total_segments}"
)

projection_expressions = self._adjust_projection_expression(
projection_expression, expression_attribute_names
Expand All @@ -840,12 +858,13 @@ def scan(self) -> str:
filters,
limit,
exclusive_start_key,
filter_expression,
expression_attribute_names,
expression_attribute_values,
index_name,
consistent_read,
projection_expressions,
filter_expression=filter_expression,
expr_names=expression_attribute_names,
expr_values=expression_attribute_values,
index_name=index_name,
consistent_read=consistent_read,
projection_expression=projection_expressions,
segments=(segment, total_segments),
)
except ValueError as err:
raise MockValidationException(f"Bad Filter Expression: {err}")
Expand Down
30 changes: 24 additions & 6 deletions tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,40 @@

class BaseTest:
@classmethod
def setup_class(cls):
def setup_class(cls, add_range=False):
if not allow_aws_request():
cls.mock = mock_aws()
cls.mock.start()
cls.client = boto3.client("dynamodb", region_name="us-east-1")
cls.table_name = "T" + str(uuid4())[0:6]
cls.has_range_key = add_range

dynamodb = boto3.resource("dynamodb", region_name="us-east-1")

# Create the DynamoDB table.
schema = [{"AttributeName": "pk", "KeyType": "HASH"}]
defs = [{"AttributeName": "pk", "AttributeType": "S"}]
if add_range:
schema.append({"AttributeName": "rk", "KeyType": "RANGE"})
defs.append({"AttributeName": "rk", "AttributeType": "S"})
dynamodb.create_table(
TableName=cls.table_name,
KeySchema=[{"AttributeName": "pk", "KeyType": "HASH"}],
AttributeDefinitions=[{"AttributeName": "pk", "AttributeType": "S"}],
KeySchema=schema,
AttributeDefinitions=defs,
ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
)
waiter = cls.client.get_waiter("table_exists")
waiter.wait(TableName=cls.table_name)
cls.table = dynamodb.Table(cls.table_name)
cls.table.put_item(
Item={"pk": "the-key", "subject": "123", "body": "some test msg"}
)

def setup_method(self):
    """Empty the shared table so that every test starts without leftover items."""
    # The delete key needs the range key as well when the table was created with one
    key_names = ("pk", "rk") if self.has_range_key else ("pk",)
    for row in self.table.scan()["Items"]:
        self.table.delete_item(Key={name: row[name] for name in key_names})

@classmethod
def teardown_class(cls):
Expand Down Expand Up @@ -1296,6 +1308,12 @@ def test_query_with_missing_expression_attribute():

@pytest.mark.aws_verified
class TestReturnValuesOnConditionCheckFailure(BaseTest):
def setup_method(self):
    """Run the parent per-test setup, then store the item these tests operate on."""
    super().setup_method()
    existing_item = {"pk": "the-key", "subject": "123", "body": "some test msg"}
    self.table.put_item(Item=existing_item)

def test_put_item_does_not_return_old_item(self):
with pytest.raises(ClientError) as exc:
self.table.put_item(
Expand Down
125 changes: 125 additions & 0 deletions tests/test_dynamodb/test_dynamodb_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from botocore.exceptions import ClientError

from moto import mock_aws
from tests.test_dynamodb.exceptions.test_dynamodb_exceptions import BaseTest

from . import dynamodb_aws_verified

Expand Down Expand Up @@ -729,3 +730,127 @@ def test_scan_with_scanfilter(self):
"Items"
]
assert items == [{"partitionKey": "pk-1"}]


@pytest.mark.aws_verified
class TestParallelScan(BaseTest):
    """
    Tests for `scan()` using the Segment/TotalSegments parameters.

    The shared table is created with a hash key (`pk`) and a range key (`rk`).
    """

    @staticmethod
    def setup_class(cls):  # pylint: disable=arguments-renamed
        super().setup_class(add_range=True)

    def _scan_all_segments(self, total_segments=3, **scan_kwargs):
        """Scan every segment once and return the items of all segments combined."""
        items = []
        for segment in range(total_segments):
            resp = self.table.scan(
                Segment=segment, TotalSegments=total_segments, **scan_kwargs
            )
            items.extend(resp["Items"])
        return items

    @staticmethod
    def _keys(items):
        """Return the (pk, rk) pair of every item, preserving duplicates."""
        return [(item["pk"], item["rk"]) for item in items]

    def test_segment_only(self):
        with pytest.raises(ClientError) as exc:
            self.table.scan(Segment=1)
        err = exc.value.response["Error"]
        assert err["Code"] == "ValidationException"
        assert (
            err["Message"]
            == "The TotalSegments parameter is required but was not present in the request when Segment parameter is present"
        )

    def test_total_segments_only(self):
        with pytest.raises(ClientError) as exc:
            self.table.scan(TotalSegments=1)
        err = exc.value.response["Error"]
        assert err["Code"] == "ValidationException"
        assert (
            err["Message"]
            == "The Segment parameter is required but was not present in the request when parameter TotalSegments is present"
        )

    def test_parallelize_all_different_hash_keys(self):
        for i in range(10):
            self.table.put_item(Item={"pk": f"item{i}", "rk": "sth"})

        keys = self._keys(self._scan_all_segments())

        # Every item must be returned by exactly one segment
        assert len(keys) == 10
        assert set(keys) == {(f"item{i}", "sth") for i in range(10)}

    def test_parallelize_different_hash_key_per_segment(self):
        for i in range(3):
            for j in range(4):
                self.table.put_item(Item={"pk": f"item{i}", "rk": f"rk{j}"})

        keys = self._keys(self._scan_all_segments())

        assert len(keys) == 12
        assert set(keys) == {
            (f"item{i}", f"rk{j}") for i in range(3) for j in range(4)
        }

    def test_scan_using_filter_expression(self):
        # AWS seems to return all data in Segment 1
        for i in range(10):
            self.table.put_item(Item={"pk": "item", "rk": f"range{i}"})
        for i in range(10):
            self.table.put_item(Item={"pk": "n/a", "rk": f"range{i}"})
        for i in range(20, 10, -1):
            self.table.put_item(Item={"pk": "item", "rk": f"range{i}"})

        items = self._scan_all_segments(FilterExpression=Attr("pk").eq("item"))
        keys = self._keys(items)

        assert len(keys) == 20
        # No duplicates, and the filter was applied in every segment
        assert len(set(keys)) == 20
        assert all(pk == "item" for pk, _ in keys)

    def test_scan_single_hash_key(self):
        # AWS seems to return all data in Segment 1
        for i in range(10):
            self.table.put_item(Item={"pk": "item", "rk": f"range{i}"})
        for i in range(20, 10, -1):
            self.table.put_item(Item={"pk": "item", "rk": f"range{i}"})

        keys = self._keys(self._scan_all_segments())

        assert len(keys) == 20
        assert len(set(keys)) == 20

    def test_pagination(self):
        for i in range(50):
            self.table.put_item(Item={"pk": "item", "rk": f"range{i}"})

        first_pass = 0
        later_passes = 0
        for segment in range(3):
            resp = self.table.scan(Segment=segment, TotalSegments=3, Limit=10)
            assert len(resp["Items"]) <= 10
            first_pass += len(resp["Items"])

            # Keep following the pagination token until this segment is exhausted
            while "LastEvaluatedKey" in resp:
                resp = self.table.scan(
                    Segment=segment,
                    TotalSegments=3,
                    ExclusiveStartKey=resp["LastEvaluatedKey"],
                )
                later_passes += len(resp["Items"])

        assert first_pass <= 30
        assert first_pass + later_passes == 50

    def test_segment_larger_than_total_segments(self):
        with pytest.raises(ClientError) as exc:
            self.table.scan(Segment=3, TotalSegments=3)
        err = exc.value.response["Error"]
        assert err["Code"] == "ValidationException"
        assert (
            err["Message"]
            == "The Segment parameter is zero-based and must be less than parameter TotalSegments: Segment: 3 is not less than TotalSegments: 3"
        )
Loading