Skip to content

Commit

Permalink
Sagemaker: Fix pagination for ModelPackages(Groups) (#6972)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblommers authored Oct 31, 2023
1 parent aa37700 commit d390aa6
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 30 deletions.
4 changes: 2 additions & 2 deletions moto/sagemaker/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@
"input_token": "next_token",
"limit_key": "max_results",
"limit_default": 100,
"unique_attribute": "ModelPackageGroupArn",
"unique_attribute": "model_package_group_arn",
},
"list_model_packages": {
"input_token": "next_token",
"limit_key": "max_results",
"limit_default": 100,
"unique_attribute": "ModelPackageArn",
"unique_attribute": "model_package_arn",
},
"list_notebook_instances": {
"input_token": "next_token",
Expand Down
46 changes: 23 additions & 23 deletions tests/test_sagemaker/test_sagemaker_model_package_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,34 @@ def test_create_model_package_group():
@mock_sagemaker
def test_list_model_package_groups():
client = boto3.client("sagemaker", region_name="eu-west-1")
group1 = "test-model-package-group-1"
desc1 = "test-model-package-group-description-1"
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-1",
ModelPackageGroupDescription="test-model-package-group-description-1",
ModelPackageGroupName=group1, ModelPackageGroupDescription=desc1
)

group2 = "test-model-package-group-2"
desc2 = "test-model-package-group-description-2"
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-2",
ModelPackageGroupDescription="test-model-package-group-description-2",
ModelPackageGroupName=group2,
ModelPackageGroupDescription=desc2,
)
resp = client.list_model_package_groups()

assert (
resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupName"]
== "test-model-package-group-1"
)
assert "ModelPackageGroupDescription" in resp["ModelPackageGroupSummaryList"][0]
assert (
resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupDescription"]
== "test-model-package-group-description-1"
)
assert (
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"]
== "test-model-package-group-2"
)
assert "ModelPackageGroupDescription" in resp["ModelPackageGroupSummaryList"][1]
assert (
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupDescription"]
== "test-model-package-group-description-2"
)
summary = client.list_model_package_groups()["ModelPackageGroupSummaryList"]

assert summary[0]["ModelPackageGroupName"] == group1
assert summary[0]["ModelPackageGroupDescription"] == desc1

assert summary[1]["ModelPackageGroupName"] == group2
assert summary[1]["ModelPackageGroupDescription"] == desc2

# Pagination
resp = client.list_model_package_groups(MaxResults=1)
assert len(resp["ModelPackageGroupSummaryList"]) == 1

resp = client.list_model_package_groups(MaxResults=1, NextToken=resp["NextToken"])
assert len(resp["ModelPackageGroupSummaryList"]) == 1
assert "NextToken" not in resp


@mock_sagemaker
Expand Down
21 changes: 16 additions & 5 deletions tests/test_sagemaker/test_sagemaker_model_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,29 +133,40 @@ def test_list_model_packages_approval_status():
@mock_sagemaker
def test_list_model_packages_model_package_group_name():
client = boto3.client("sagemaker", region_name="eu-west-1")
group1 = "test-model-package-group"
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
ModelPackageGroupName="test-model-package-group",
ModelPackageGroupName=group1,
)
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description-2",
ModelPackageGroupName="test-model-package-group",
ModelPackageGroupName=group1,
)
client.create_model_package(
ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-3",
ModelPackageGroupName="test-model-package-group",
ModelPackageGroupName=group1,
)
client.create_model_package(
ModelPackageName="test-model-package-without-group",
ModelPackageDescription="test-model-package-description-without-group",
ModelPackageDescription="diff_group",
)
resp = client.list_model_packages(ModelPackageGroupName="test-model-package-group")
resp = client.list_model_packages(ModelPackageGroupName=group1)

assert len(resp["ModelPackageSummaryList"]) == 3

# Pagination
resp = client.list_model_packages(ModelPackageGroupName=group1, MaxResults=2)
assert len(resp["ModelPackageSummaryList"]) == 2

resp = client.list_model_packages(
ModelPackageGroupName=group1, MaxResults=2, NextToken=resp["NextToken"]
)
assert len(resp["ModelPackageSummaryList"]) == 1
assert "NextToken" not in resp


@mock_sagemaker
def test_list_model_packages_model_package_type():
Expand Down

0 comments on commit d390aa6

Please sign in to comment.