From f58a575d821387848a27ce6e40edef676e574b6b Mon Sep 17 00:00:00 2001 From: Josh Levy Date: Sat, 27 Apr 2024 15:52:04 -0400 Subject: [PATCH] Feature: Bedrock and Bedrock Agents (#7597) --- moto/backend_index.py | 2 + moto/backends.py | 10 + moto/bedrock/__init__.py | 1 + moto/bedrock/exceptions.py | 33 + moto/bedrock/models.py | 534 +++++++ moto/bedrock/responses.py | 186 +++ moto/bedrock/urls.py | 29 + moto/bedrockagent/__init__.py | 1 + moto/bedrockagent/exceptions.py | 26 + moto/bedrockagent/models.py | 322 ++++ moto/bedrockagent/responses.py | 153 ++ moto/bedrockagent/urls.py | 11 + tests/test_bedrock/__init__.py | 0 tests/test_bedrock/test_bedrock.py | 1401 ++++++++++++++++++ tests/test_bedrockagent/__init__.py | 0 tests/test_bedrockagent/test_bedrockagent.py | 726 +++++++++ 16 files changed, 3435 insertions(+) create mode 100644 moto/bedrock/__init__.py create mode 100644 moto/bedrock/exceptions.py create mode 100644 moto/bedrock/models.py create mode 100644 moto/bedrock/responses.py create mode 100644 moto/bedrock/urls.py create mode 100644 moto/bedrockagent/__init__.py create mode 100644 moto/bedrockagent/exceptions.py create mode 100644 moto/bedrockagent/models.py create mode 100644 moto/bedrockagent/responses.py create mode 100644 moto/bedrockagent/urls.py create mode 100644 tests/test_bedrock/__init__.py create mode 100644 tests/test_bedrock/test_bedrock.py create mode 100644 tests/test_bedrockagent/__init__.py create mode 100644 tests/test_bedrockagent/test_bedrockagent.py diff --git a/moto/backend_index.py b/moto/backend_index.py index 23bbae341563..a1952b1f349f 100644 --- a/moto/backend_index.py +++ b/moto/backend_index.py @@ -22,6 +22,8 @@ ("backup", re.compile("https?://backup\\.(.+)\\.amazonaws\\.com")), ("batch", re.compile("https?://batch\\.(.+)\\.amazonaws.com")), ("budgets", re.compile("https?://budgets\\.amazonaws\\.com")), + ("bedrock", re.compile("https?://bedrock\\.(.+)\\.amazonaws\\.com")), + ("bedrockagent", re.compile("https?://bedrock-agent\\.(.+)\\.amazonaws\\.com")), ("ce", re.compile("https?://ce\\.(.+)\\.amazonaws\\.com")), ("cloudformation", re.compile("https?://cloudformation\\.(.+)\\.amazonaws\\.com")), ("cloudfront", re.compile("https?://cloudfront\\.amazonaws\\.com")), diff --git a/moto/backends.py b/moto/backends.py index 7a5d381c0bf6..3e5a5a10df50 100644 --- a/moto/backends.py +++ b/moto/backends.py @@ -20,6 +20,8 @@ from moto.autoscaling.models import AutoScalingBackend from moto.awslambda.models import LambdaBackend from moto.batch.models import BatchBackend + from moto.bedrock.models import BedrockBackend + from moto.bedrockagent.models import AgentsforBedrockBackend from moto.budgets.models import BudgetsBackend from moto.ce.models import CostExplorerBackend from moto.cloudformation.models import CloudFormationBackend @@ -182,6 +184,8 @@ def get_service_from_url(url: str) -> Optional[str]: "Literal['athena']", "Literal['autoscaling']", "Literal['batch']", + "Literal['bedrock']", + "Literal['bedrock-agent']", "Literal['budgets']", "Literal['ce']", "Literal['cloudformation']", @@ -344,6 +348,12 @@ def get_backend( @overload def get_backend(name: "Literal['batch']") -> "BackendDict[BatchBackend]": ... @overload +def get_backend(name: "Literal['bedrock']") -> "BackendDict[BedrockBackend]": ... +@overload +def get_backend( + name: "Literal['bedrock-agent']", +) -> "BackendDict[AgentsforBedrockBackend]": ... +@overload def get_backend(name: "Literal['budgets']") -> "BackendDict[BudgetsBackend]": ... @overload def get_backend(name: "Literal['ce']") -> "BackendDict[CostExplorerBackend]": ... diff --git a/moto/bedrock/__init__.py b/moto/bedrock/__init__.py new file mode 100644 index 000000000000..f203809b1f11 --- /dev/null +++ b/moto/bedrock/__init__.py @@ -0,0 +1 @@ +from .models import bedrock_backends # noqa: F401 diff --git a/moto/bedrock/exceptions.py b/moto/bedrock/exceptions.py new file mode 100644 index 000000000000..9745d59b0976 --- /dev/null +++ b/moto/bedrock/exceptions.py @@ -0,0 +1,33 @@ +"""Exceptions raised by the bedrock service.""" + +from moto.core.exceptions import JsonRESTError + +# Bedrock.Client.exceptions.ResourceNotFoundException + + +class BedrockClientError(JsonRESTError): + code = 400 + + +class ResourceNotFoundException(BedrockClientError): + def __init__(self, msg: str): + super().__init__("ResourceNotFoundException", f"{msg}") + + +class ResourceInUseException(BedrockClientError): + def __init__(self, msg: str): + super().__init__("ResourceInUseException", f"{msg}") + + +class ValidationException(BedrockClientError): + def __init__(self, msg: str): + super().__init__( + "ValidationException", + "Input validation failed. Check your request parameters and retry the request.", + f"{msg}", + ) + + +class TooManyTagsException(BedrockClientError): + def __init__(self, msg: str): + super().__init__("TooManyTagsException", f"{msg}") diff --git a/moto/bedrock/models.py b/moto/bedrock/models.py new file mode 100644 index 000000000000..6023b1b5528f --- /dev/null +++ b/moto/bedrock/models.py @@ -0,0 +1,534 @@ +"""BedrockBackend class with methods for supported APIs.""" + +import re +from datetime import datetime +from typing import Any, Dict, List, Optional + +from moto.bedrock.exceptions import ( + ResourceInUseException, + ResourceNotFoundException, + TooManyTagsException, + ValidationException, +) +from moto.core.base_backend import BackendDict, BaseBackend +from moto.core.common_models import BaseModel +from moto.utilities.paginator import paginate +from moto.utilities.tagging_service import TaggingService + + +class ModelCustomizationJob(BaseModel): + def __init__( + self, + job_name: str, + custom_model_name: str, + role_arn: str, + base_model_identifier: str, + training_data_config: Dict[str, str], + output_data_config: Dict[str, str], + hyper_parameters: Dict[str, str], + region_name: str, + account_id: str, + client_request_token: Optional[str], + customization_type: Optional[str], + custom_model_kms_key_id: Optional[str], + job_tags: Optional[List[Dict[str, str]]], + custom_model_tags: Optional[List[Dict[str, str]]], + validation_data_config: Optional[Dict[str, Any]], + vpc_config: Optional[Dict[str, Any]], + ): + self.job_name = job_name + self.custom_model_name = custom_model_name + self.role_arn = role_arn + self.client_request_token = client_request_token + self.base_model_identifier = base_model_identifier + self.customization_type = customization_type + self.custom_model_kms_key_id = custom_model_kms_key_id + self.job_tags = job_tags + self.custom_model_tags = custom_model_tags + if "s3Uri" not in training_data_config or not re.match( + r"s3://.*", training_data_config["s3Uri"] + ): + raise ValidationException( + "Validation error detected: " + f"Value '{training_data_config}' at 'training_data_config' failed to satisfy constraint: " + "Member must satisfy regular expression pattern: " + "s3://.*" + ) + self.training_data_config = training_data_config + if validation_data_config: + if "validators" in validation_data_config: + for validator in validation_data_config["validators"]: + if not re.match(r"s3://.*", validator["s3Uri"]): + raise ValidationException( + "Validation error detected: " + f"Value '{validator}' at 'validation_data_config' failed to satisfy constraint: " + "Member must satisfy regular expression pattern: " + "s3://.*" + ) + self.validation_data_config = validation_data_config + if "s3Uri" not in output_data_config or not re.match( + r"s3://.*", output_data_config["s3Uri"] + ): + raise ValidationException( + "Validation error detected: " + f"Value '{output_data_config}' at 'output_data_config' failed to satisfy constraint: " + "Member must satisfy regular expression pattern: " + "s3://.*" + ) + self.output_data_config = output_data_config + self.hyper_parameters = hyper_parameters + self.vpc_config = vpc_config + self.region_name = region_name + self.account_id = account_id + self.job_arn = f"arn:aws:bedrock:{self.region_name}:{self.account_id}:model-customization-job/{self.job_name}" + self.output_model_name = f"{self.custom_model_name}-{self.job_name}" + self.output_model_arn = f"arn:aws:bedrock:{self.region_name}:{self.account_id}:custom-model/{self.output_model_name}" + self.status = "InProgress" + self.failure_message = "Failure Message" + self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self.last_modified_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self.base_model_arn = f"arn:aws:bedrock:{self.region_name}::foundation-model/{self.base_model_identifier}" + self.output_model_kms_key_arn = f"arn:aws:kms:{self.region_name}:{self.account_id}:key/{self.output_model_name}-kms-key" + self.training_metrics = {"trainingLoss": 0.0} # hard coded + self.validation_metrics = [{"validationLoss": 0.0}] # hard coded + + def to_dict(self) -> Dict[str, Any]: + dct = { + "baseModelArn": self.base_model_arn, + "clientRequestToken": self.client_request_token, + "creationTime": self.creation_time, + "customizationType": self.customization_type, + "endTime": self.end_time, + "failureMessage": self.failure_message, + "hyperParameters": self.hyper_parameters, + "jobArn": self.job_arn, + "jobName": self.job_name, + "lastModifiedTime": self.last_modified_time, + "outputDataConfig": self.output_data_config, + "outputModelArn": self.output_model_arn, + "outputModelKmsKeyArn": self.output_model_kms_key_arn, + "outputModelName": self.output_model_name, + "roleArn": self.role_arn, + "status": self.status, + "trainingDataConfig": self.training_data_config, + "trainingMetrics": self.training_metrics, + "validationDataConfig": self.validation_data_config, + "validationMetrics": self.validation_metrics, + "vpcConfig": self.vpc_config, + } + return {k: v for k, v in dct.items() if v} + + +class CustomModel(BaseModel): + def __init__( + self, + model_name: str, + job_name: str, + job_arn: str, + base_model_arn: str, + hyper_parameters: Dict[str, str], + output_data_config: Dict[str, str], + training_data_config: Dict[str, str], + training_metrics: Dict[str, float], + base_model_name: str, + region_name: str, + account_id: str, + customization_type: Optional[str], + model_kms_key_arn: Optional[str], + validation_data_config: Optional[Dict[str, Any]], + validation_metrics: Optional[List[Dict[str, float]]], + ): + self.model_name = model_name + self.job_name = job_name + self.job_arn = job_arn + self.base_model_arn = base_model_arn + self.customization_type = customization_type + self.model_kms_key_arn = model_kms_key_arn + self.hyper_parameters = hyper_parameters + self.training_data_config = training_data_config + self.validation_data_config = validation_data_config + self.output_data_config = output_data_config + self.training_metrics = training_metrics + self.validation_metrics = validation_metrics + self.region_name = region_name + self.account_id = account_id + self.model_arn = f"arn:aws:bedrock:{self.region_name}:{self.account_id}:custom-model/{self.model_name}" + self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self.base_model_name = base_model_name + + def to_dict(self) -> Dict[str, Any]: + dct = { + "baseModelArn": self.base_model_arn, + "creationTime": self.creation_time, + "customizationType": self.customization_type, + "hyperParameters": self.hyper_parameters, + "jobArn": self.job_arn, + "jobName": self.job_name, + "modelArn": self.model_arn, + "modelKmsKeyArn": self.model_kms_key_arn, + "modelName": self.model_name, + "outputDataConfig": self.output_data_config, + "trainingDataConfig": self.training_data_config, + "trainingMetrics": self.training_metrics, + "validationDataConfig": self.validation_data_config, + "validationMetrics": self.validation_metrics, + } + return {k: v for k, v in dct.items() if v} + + +class model_invocation_logging_configuration(BaseModel): + def __init__(self, logging_config: Dict[str, Any]) -> None: + self.logging_config = logging_config + + +class BedrockBackend(BaseBackend): + """Implementation of Bedrock APIs.""" + + PAGINATION_MODEL = { + "list_model_customization_jobs": { + "input_token": "next_token", + "limit_key": "max_results", + "limit_default": 100, + "unique_attribute": "jobArn", + }, + "list_custom_models": { + "input_token": "next_token", + "limit_key": "max_results", + "limit_default": 100, + "unique_attribute": "modelArn", + }, + } + + def __init__(self, region_name: str, account_id: str) -> None: + super().__init__(region_name, account_id) + self.model_customization_jobs: Dict[str, ModelCustomizationJob] = {} + self.custom_models: Dict[str, CustomModel] = {} + self.model_invocation_logging_configuration: Optional[ + model_invocation_logging_configuration + ] = None + self.tagger = TaggingService() + + def _list_arns(self) -> List[str]: + return [job.job_arn for job in self.model_customization_jobs.values()] + [ + model.model_arn for model in self.custom_models.values() + ] + + def create_model_customization_job( + self, + job_name: str, + custom_model_name: str, + role_arn: str, + base_model_identifier: str, + training_data_config: Dict[str, Any], + output_data_config: Dict[str, str], + hyper_parameters: Dict[str, str], + client_request_token: Optional[str], + customization_type: Optional[str], + custom_model_kms_key_id: Optional[str], + job_tags: Optional[List[Dict[str, str]]], + custom_model_tags: Optional[List[Dict[str, str]]], + validation_data_config: Optional[Dict[str, Any]], + vpc_config: Optional[Dict[str, Any]], + ) -> str: + if job_name in self.model_customization_jobs.keys(): + raise ResourceInUseException( + f"Model customization job {job_name} already exists" + ) + if custom_model_name in self.custom_models.keys(): + raise ResourceInUseException( + f"Custom model {custom_model_name} already exists" + ) + model_customization_job = ModelCustomizationJob( + job_name, + custom_model_name, + role_arn, + base_model_identifier, + training_data_config, + output_data_config, + hyper_parameters, + self.region_name, + self.account_id, + client_request_token, + customization_type, + custom_model_kms_key_id, + job_tags, + custom_model_tags, + validation_data_config, + vpc_config, + ) + self.model_customization_jobs[job_name] = model_customization_job + if job_tags: + self.tag_resource(model_customization_job.job_arn, job_tags) + # Create associated custom model + custom_model = CustomModel( + custom_model_name, + job_name, + model_customization_job.job_arn, + model_customization_job.base_model_arn, + model_customization_job.hyper_parameters, + model_customization_job.output_data_config, + model_customization_job.training_data_config, + model_customization_job.training_metrics, + model_customization_job.base_model_identifier, + self.region_name, + self.account_id, + model_customization_job.customization_type, + model_customization_job.output_model_kms_key_arn, + model_customization_job.validation_data_config, + model_customization_job.validation_metrics, + ) + self.custom_models[custom_model_name] = custom_model + if custom_model_tags: + self.tag_resource(custom_model.model_arn, custom_model_tags) + return model_customization_job.job_arn + + def get_model_customization_job(self, job_identifier: str) -> ModelCustomizationJob: + if job_identifier not in self.model_customization_jobs: + raise ResourceNotFoundException( + f"Model customization job {job_identifier} not found" + ) + else: + return self.model_customization_jobs[job_identifier] + + def stop_model_customization_job(self, job_identifier: str) -> None: + if job_identifier in self.model_customization_jobs: + self.model_customization_jobs[job_identifier].status = "Stopped" + else: + raise ResourceNotFoundException( + f"Model customization job {job_identifier} not found" + ) + return + + @paginate(pagination_model=PAGINATION_MODEL) # type: ignore + def list_model_customization_jobs( + self, + creation_time_after: Optional[datetime], + creation_time_before: Optional[datetime], + status_equals: Optional[str], + name_contains: Optional[str], + max_results: Optional[int], + next_token: Optional[str], + sort_by: Optional[str], + sort_order: Optional[str], + ) -> List[Any]: + customization_jobs_fetched = list(self.model_customization_jobs.values()) + + if name_contains is not None: + customization_jobs_fetched = list( + filter( + lambda x: name_contains in x.job_name, + customization_jobs_fetched, + ) + ) + + if creation_time_after is not None: + customization_jobs_fetched = list( + filter( + lambda x: x.creation_time > str(creation_time_after), + customization_jobs_fetched, + ) + ) + + if creation_time_before is not None: + customization_jobs_fetched = list( + filter( + lambda x: x.creation_time < str(creation_time_before), + customization_jobs_fetched, + ) + ) + if status_equals is not None: + customization_jobs_fetched = list( + filter( + lambda x: x.status == status_equals, + customization_jobs_fetched, + ) + ) + + if sort_by is not None: + if sort_by == "CreationTime": + if sort_order is not None and sort_order == "Ascending": + customization_jobs_fetched = sorted( + customization_jobs_fetched, key=lambda x: x.creation_time + ) + elif sort_order is not None and sort_order == "Descending": + customization_jobs_fetched = sorted( + customization_jobs_fetched, + key=lambda x: x.creation_time, + reverse=True, + ) + else: + raise ValidationException(f"Invalid sort order: {sort_order}") + else: + raise ValidationException(f"Invalid sort by field: {sort_by}") + + model_customization_job_summaries = [] + for model in customization_jobs_fetched: + model_customization_job_summaries.append( + { + "jobArn": model.job_arn, + "baseModelArn": model.base_model_arn, + "jobName": model.job_name, + "status": model.status, + "lastModifiedTime": model.last_modified_time, + "creationTime": model.creation_time, + "endTime": model.end_time, + "customModelArn": model.output_model_arn, + "customModelName": model.custom_model_name, + "customizationType": model.customization_type, + } + ) + return model_customization_job_summaries + + def get_model_invocation_logging_configuration(self) -> Optional[Dict[str, Any]]: + if self.model_invocation_logging_configuration: + return self.model_invocation_logging_configuration.logging_config + else: + return {} + + def put_model_invocation_logging_configuration( + self, logging_config: Dict[str, Any] + ) -> None: + invocation_logging = model_invocation_logging_configuration(logging_config) + self.model_invocation_logging_configuration = invocation_logging + return + + def get_custom_model(self, model_identifier: str) -> CustomModel: + if model_identifier[:3] == "arn": + for model in self.custom_models.values(): + if model.model_arn == model_identifier: + return model + raise ResourceNotFoundException( + f"Custom model {model_identifier} not found" + ) + elif model_identifier in self.custom_models: + return self.custom_models[model_identifier] + else: + raise ResourceNotFoundException( + f"Custom model {model_identifier} not found" + ) + + def delete_custom_model(self, model_identifier: str) -> None: + if model_identifier in self.custom_models: + del self.custom_models[model_identifier] + else: + raise ResourceNotFoundException( + f"Custom model {model_identifier} not found" + ) + return + + @paginate(pagination_model=PAGINATION_MODEL) # type: ignore + def list_custom_models( + self, + creation_time_before: Optional[datetime], + creation_time_after: Optional[datetime], + name_contains: Optional[str], + base_model_arn_equals: Optional[str], + foundation_model_arn_equals: Optional[str], + max_results: Optional[int], + next_token: Optional[str], + sort_by: Optional[str], + sort_order: Optional[str], + ) -> List[Any]: + """ + The foundation_model_arn_equals-argument is not yet supported + """ + custom_models_fetched = list(self.custom_models.values()) + + if name_contains is not None: + custom_models_fetched = list( + filter( + lambda x: name_contains in x.job_name, + custom_models_fetched, + ) + ) + + if creation_time_after is not None: + custom_models_fetched = list( + filter( + lambda x: x.creation_time > str(creation_time_after), + custom_models_fetched, + ) + ) + + if creation_time_before is not None: + custom_models_fetched = list( + filter( + lambda x: x.creation_time < str(creation_time_before), + custom_models_fetched, + ) + ) + if base_model_arn_equals is not None: + custom_models_fetched = list( + filter( + lambda x: x.base_model_arn == base_model_arn_equals, + custom_models_fetched, + ) + ) + + if sort_by is not None: + if sort_by == "CreationTime": + if sort_order is not None and sort_order == "Ascending": + custom_models_fetched = sorted( + custom_models_fetched, key=lambda x: x.creation_time + ) + elif sort_order is not None and sort_order == "Descending": + custom_models_fetched = sorted( + custom_models_fetched, + key=lambda x: x.creation_time, + reverse=True, + ) + else: + raise ValidationException(f"Invalid sort order: {sort_order}") + else: + raise ValidationException(f"Invalid sort by field: {sort_by}") + model_summaries = [] + for model in custom_models_fetched: + model_summaries.append( + { + "modelArn": model.model_arn, + "modelName": model.model_name, + "creationTime": model.creation_time, + "baseModelArn": model.base_model_arn, + "baseModelName": model.base_model_name, + "jobArn": model.job_arn, + "customizationType": model.customization_type, + } + ) + return model_summaries + + def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None: + if resource_arn not in self._list_arns(): + raise ResourceNotFoundException(f"Resource {resource_arn} not found") + fixed_tags = [] + if len(tags) + len(self.tagger.list_tags_for_resource(resource_arn)) > 50: + raise TooManyTagsException( + "Member must have length less than or equal to 50" + ) + for tag_dict in tags: + fixed_tags.append({"Key": tag_dict["key"], "Value": tag_dict["value"]}) + self.tagger.tag_resource(resource_arn, fixed_tags) + return + + def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None: + if resource_arn not in self._list_arns(): + raise ResourceNotFoundException(f"Resource {resource_arn} not found") + self.tagger.untag_resource_using_names(resource_arn, tag_keys) + return + + def list_tags_for_resource(self, resource_arn: str) -> List[Dict[str, str]]: + if resource_arn not in self._list_arns(): + raise ResourceNotFoundException(f"Resource {resource_arn} not found") + tags = self.tagger.list_tags_for_resource(resource_arn) + fixed_tags = [] + for tag_dict in tags["Tags"]: + fixed_tags.append({"key": tag_dict["Key"], "value": tag_dict["Value"]}) + return fixed_tags + + def delete_model_invocation_logging_configuration(self) -> None: + if self.model_invocation_logging_configuration: + self.model_invocation_logging_configuration.logging_config = {} + return + + +bedrock_backends = BackendDict(BedrockBackend, "bedrock") diff --git a/moto/bedrock/responses.py b/moto/bedrock/responses.py new file mode 100644 index 000000000000..c6adda6356b6 --- /dev/null +++ b/moto/bedrock/responses.py @@ -0,0 +1,186 @@ +"""Handles incoming bedrock requests, invokes methods, returns responses.""" + +import json +from urllib.parse import unquote + +from moto.core.responses import BaseResponse + +from .models import BedrockBackend, bedrock_backends + + +class BedrockResponse(BaseResponse): + """Handler for Bedrock requests and responses.""" + + def __init__(self) -> None: + super().__init__(service_name="bedrock") + + @property + def bedrock_backend(self) -> BedrockBackend: + """Return backend instance specific for this region.""" + return bedrock_backends[self.current_account][self.region] + + def create_model_customization_job(self) -> str: + params = json.loads(self.body) + job_name = params.get("jobName") + custom_model_name = params.get("customModelName") + role_arn = params.get("roleArn") + client_request_token = params.get("clientRequestToken") + base_model_identifier = params.get("baseModelIdentifier") + customization_type = params.get("customizationType") + custom_model_kms_key_id = params.get("customModelKmsKeyId") + job_tags = params.get("jobTags") + custom_model_tags = params.get("customModelTags") + training_data_config = params.get("trainingDataConfig") + validation_data_config = params.get("validationDataConfig") + output_data_config = params.get("outputDataConfig") + hyper_parameters = params.get("hyperParameters") + vpc_config = params.get("vpcConfig") + job_arn = self.bedrock_backend.create_model_customization_job( + job_name=job_name, + custom_model_name=custom_model_name, + role_arn=role_arn, + client_request_token=client_request_token, + base_model_identifier=base_model_identifier, + customization_type=customization_type, + custom_model_kms_key_id=custom_model_kms_key_id, + job_tags=job_tags, + custom_model_tags=custom_model_tags, + training_data_config=training_data_config, + validation_data_config=validation_data_config, + output_data_config=output_data_config, + hyper_parameters=hyper_parameters, + vpc_config=vpc_config, + ) + return json.dumps(dict(jobArn=job_arn)) + + def get_model_customization_job(self) -> str: + job_identifier = self.path.split("/")[-1] + model_customization_job = self.bedrock_backend.get_model_customization_job( + job_identifier=job_identifier + ) + return json.dumps(dict(model_customization_job.to_dict())) + + def get_model_invocation_logging_configuration(self) -> str: + logging_config = ( + self.bedrock_backend.get_model_invocation_logging_configuration() + ) + return json.dumps(dict(loggingConfig=logging_config)) + + def put_model_invocation_logging_configuration(self) -> None: + params = json.loads(self.body) + logging_config = params.get("loggingConfig") + self.bedrock_backend.put_model_invocation_logging_configuration( + logging_config=logging_config + ) + return + + def tag_resource(self) -> None: + params = json.loads(self.body) + resource_arn = params.get("resourceARN") + tags = params.get("tags") + self.bedrock_backend.tag_resource( + resource_arn=resource_arn, + tags=tags, + ) + return + + def untag_resource(self) -> str: + params = json.loads(self.body) + resource_arn = params.get("resourceARN") + tag_keys = params.get("tagKeys") + self.bedrock_backend.untag_resource( + resource_arn=resource_arn, + tag_keys=tag_keys, + ) + return json.dumps(dict()) + + def list_tags_for_resource(self) -> str: + params = json.loads(self.body) + resource_arn = params.get("resourceARN") + tags = self.bedrock_backend.list_tags_for_resource( + resource_arn=resource_arn, + ) + return json.dumps(dict(tags=tags)) + + def get_custom_model(self) -> str: + model_identifier = unquote(self.path.split("/")[-1]) + custom_model = self.bedrock_backend.get_custom_model( + model_identifier=model_identifier + ) + return json.dumps(dict(custom_model.to_dict())) + + def list_custom_models(self) -> str: + params = self._get_params() + creation_time_before = params.get("creationTimeBefore") + creation_time_after = params.get("creationTimeAfter") + name_contains = params.get("nameContains") + base_model_arn_equals = params.get("baseModelArnEquals") + foundation_model_arn_equals = params.get("foundationModelArnEquals") + max_results = params.get("maxResults") + next_token = params.get("nextToken") + sort_by = params.get("sortBy") + sort_order = params.get("sortOrder") + + max_results = int(max_results) if max_results else None + model_summaries, next_token = self.bedrock_backend.list_custom_models( + creation_time_before=creation_time_before, + creation_time_after=creation_time_after, + name_contains=name_contains, + base_model_arn_equals=base_model_arn_equals, + foundation_model_arn_equals=foundation_model_arn_equals, + max_results=max_results, + next_token=next_token, + sort_by=sort_by, + sort_order=sort_order, + ) + return json.dumps(dict(nextToken=next_token, modelSummaries=model_summaries)) + + def list_model_customization_jobs(self) -> str: + params = self._get_params() + creation_time_after = params.get("creationTimeAfter") + creation_time_before = params.get("creationTimeBefore") + status_equals = params.get("statusEquals") + name_contains = params.get("nameContains") + max_results = params.get("maxResults") + next_token = params.get("nextToken") + sort_by = params.get("sortBy") + sort_order = params.get("sortOrder") + + max_results = int(max_results) if max_results else None + ( + model_customization_job_summaries, + next_token, + ) = self.bedrock_backend.list_model_customization_jobs( + creation_time_after=creation_time_after, + creation_time_before=creation_time_before, + status_equals=status_equals, + name_contains=name_contains, + max_results=max_results, + next_token=next_token, + sort_by=sort_by, + sort_order=sort_order, + ) + return json.dumps( + dict( + nextToken=next_token, + modelCustomizationJobSummaries=model_customization_job_summaries, + ) + ) + + def delete_custom_model(self) -> str: + model_identifier = self.path.split("/")[-1] + self.bedrock_backend.delete_custom_model( + model_identifier=model_identifier, + ) + return json.dumps(dict()) + + def stop_model_customization_job(self) -> str: + job_identifier = self.path.split("/")[-2] + self.bedrock_backend.stop_model_customization_job( + job_identifier=job_identifier, + ) + return json.dumps(dict()) + + def delete_model_invocation_logging_configuration(self) -> str: + self.bedrock_backend.delete_model_invocation_logging_configuration() + return json.dumps(dict()) diff --git a/moto/bedrock/urls.py b/moto/bedrock/urls.py new file mode 100644 index 000000000000..079762694136 --- /dev/null +++ b/moto/bedrock/urls.py @@ -0,0 +1,29 @@ +"""bedrock base URL and path.""" + +from ..bedrockagent.responses import AgentsforBedrockResponse +from .responses import BedrockResponse + +url_bases = [ + r"https?://bedrock\.(.+)\.amazonaws\.com", +] + +url_paths = { + "{0}/.*$": BedrockResponse.dispatch, + "{0}/agents/?$": AgentsforBedrockResponse.dispatch, + "{0}/agents/(?P[^/]+)/$": AgentsforBedrockResponse.dispatch, + "{0}/custom-models$": BedrockResponse.dispatch, + "{0}/custom-models/(?P[^/]+)$": BedrockResponse.dispatch, + "{0}/custom-models/(?P[^/]+)/(?P[^/]+)$": BedrockResponse.dispatch, + "{0}/knowledgebases$": AgentsforBedrockResponse.dispatch, + "{0}/knowledgebases/(?P[^/]+)$": AgentsforBedrockResponse.dispatch, + "{0}/knowledgebases/(?P[^/]+)/$": AgentsforBedrockResponse.dispatch, + "{0}/listTagsForResource$": BedrockResponse.dispatch, + "{0}/logging/modelinvocations$": BedrockResponse.dispatch, + "{0}/model-customization-jobs$": BedrockResponse.dispatch, + "{0}/model-customization-jobs/(?P[^/]+)$": BedrockResponse.dispatch, + "{0}/model-customization-jobs/(?P[^/]+)/stop$": BedrockResponse.dispatch, + "{0}/tags/(?P[^/]+)$": AgentsforBedrockResponse.dispatch, + "{0}/tags/(?P[^/]+)/(?P[^/]+)$": AgentsforBedrockResponse.dispatch, + "{0}/tagResource$": BedrockResponse.dispatch, + "{0}/untagResource$": BedrockResponse.dispatch, +} diff --git a/moto/bedrockagent/__init__.py b/moto/bedrockagent/__init__.py new file mode 100644 index 000000000000..5a019b11bf84 --- /dev/null +++ b/moto/bedrockagent/__init__.py @@ -0,0 +1 @@ +from .models import bedrockagent_backends # noqa: F401 diff --git a/moto/bedrockagent/exceptions.py b/moto/bedrockagent/exceptions.py new file mode 100644 index 000000000000..97d94a14f8a8 --- /dev/null +++ b/moto/bedrockagent/exceptions.py @@ -0,0 +1,26 @@ +"""Exceptions raised by the bedrockagent service.""" + +from moto.core.exceptions import JsonRESTError + + +class AgentsforBedrockClientError(JsonRESTError): + code = 400 + + +class ResourceNotFoundException(AgentsforBedrockClientError): + def __init__(self, msg: str): + super().__init__("ResourceNotFoundException", f"{msg}") + + +class ConflictException(AgentsforBedrockClientError): + def __init__(self, msg: str): + super().__init__("ConflictException", f"{msg}") + + +class ValidationException(AgentsforBedrockClientError): + def __init__(self, msg: str): + super().__init__( + "ValidationException", + "Input validation failed. Check your request parameters and retry the request.", + f"{msg}", + ) diff --git a/moto/bedrockagent/models.py b/moto/bedrockagent/models.py new file mode 100644 index 000000000000..245c23c73887 --- /dev/null +++ b/moto/bedrockagent/models.py @@ -0,0 +1,322 @@ +"""AgentsforBedrockBackend class with methods for supported APIs.""" + +from typing import Any, Dict, List, Optional, Tuple + +from moto.bedrockagent.exceptions import ( + ConflictException, + ResourceNotFoundException, + ValidationException, +) +from moto.core.base_backend import BackendDict, BaseBackend +from moto.core.common_models import BaseModel +from moto.core.utils import unix_time +from moto.moto_api._internal import mock_random +from moto.utilities.paginator import paginate +from moto.utilities.tagging_service import TaggingService + + +class Agent(BaseModel): + def __init__( + self, + agent_name: str, + agent_resource_role_arn: str, + region_name: str, + account_id: str, + client_token: Optional[str], + instruction: Optional[str], + foundation_model: Optional[str], + description: Optional[str], + idle_session_ttl_in_seconds: Optional[int], + customer_encryption_key_arn: Optional[str], + prompt_override_configuration: Optional[Dict[str, Any]], + ): + self.agent_name = agent_name + self.client_token = client_token + self.instruction = instruction + self.foundation_model = foundation_model + self.description = description + self.idle_session_ttl_in_seconds = idle_session_ttl_in_seconds + self.agent_resource_role_arn = agent_resource_role_arn + self.customer_encryption_key_arn = customer_encryption_key_arn + self.prompt_override_configuration = prompt_override_configuration + self.region_name = region_name + self.account_id = account_id + self.created_at = unix_time() + self.updated_at = unix_time() + self.prepared_at = unix_time() + self.agent_status = "PREPARED" + self.agent_id = self.agent_name + str(mock_random.uuid4())[:8] + self.agent_arn = f"arn:aws:bedrock:{self.region_name}:{self.account_id}:agent/{self.agent_id}" + self.agent_version = "1.0" + self.failure_reasons: List[str] = [] + self.recommended_actions = ["action"] + + def to_dict(self) -> Dict[str, Any]: + dct = { + "agentId": self.agent_id, + "agentName": self.agent_name, + "agentArn": self.agent_arn, + "agentVersion": self.agent_version, + "clientToken": self.client_token, + "instruction": self.instruction, + "agentStatus": self.agent_status, + "foundationModel": self.foundation_model, + "description": self.description, + "idleSessionTTLInSeconds": self.idle_session_ttl_in_seconds, + "agentResourceRoleArn": self.agent_resource_role_arn, + "customerEncryptionKeyArn": self.customer_encryption_key_arn, + "createdAt": self.created_at, + "updatedAt": self.updated_at, + "preparedAt": self.prepared_at, + "failureReasons": self.failure_reasons, + "recommendedActions": self.recommended_actions, + "promptOverrideConfiguration": self.prompt_override_configuration, + } + return {k: v for k, v in dct.items() if v} + + def dict_summary(self) -> Dict[str, Any]: + dct = { + "agentId": self.agent_id, + "agentName": self.agent_name, + "agentStatus": self.agent_status, + "description": self.description, + "updatedAt": self.updated_at, + "latestAgentVersion": self.agent_version, + } + return {k: v for k, v in dct.items() if v} + + +class KnowledgeBase(BaseModel): + def __init__( + self, + name: str, + role_arn: str, + region_name: str, + account_id: str, + knowledge_base_configuration: Dict[str, Any], + storage_configuration: Dict[str, Any], + client_token: Optional[str], + description: Optional[str], + ): + self.client_token = client_token + self.name = name + self.description = description + self.role_arn = role_arn + if knowledge_base_configuration["type"] != "VECTOR": + raise ValidationException( + "Validation error detected: " + f"Value '{knowledge_base_configuration['type']}' at 'knowledgeBaseConfiguration' failed to satisfy constraint: " + "Member must contain 'type' as 'VECTOR'" + ) + self.knowledge_base_configuration = knowledge_base_configuration + if storage_configuration["type"] not in [ + "OPENSEARCH_SERVERLESS", + "PINECONE", + "REDIS_ENTERPRISE_CLOUD", + "RDS", + ]: + raise ValidationException( + "Validation error detected: " + f"Value '{storage_configuration['type']}' at 'storageConfiguration' failed to satisfy constraint: " + "Member 'type' must be one of: OPENSEARCH_SERVERLESS | PINECONE | REDIS_ENTERPRISE_CLOUD | RDS" + ) + self.storage_configuration = storage_configuration + self.region_name = region_name + self.account_id = account_id + self.knowledge_base_id = self.name + str(mock_random.uuid4())[:8] + self.knowledge_base_arn = f"arn:aws:bedrock:{self.region_name}:{self.account_id}:knowledge-base/{self.knowledge_base_id}" + self.created_at = unix_time() + self.updated_at = unix_time() + self.status = "Active" + self.failure_reasons: List[str] = [] + + def to_dict(self) -> Dict[str, Any]: + dct = { + "knowledgeBaseId": self.knowledge_base_id, + "name": self.name, + "knowledgeBaseArn": self.knowledge_base_arn, + "description": self.description, + "roleArn": self.role_arn, + "knowledgeBaseConfiguration": self.knowledge_base_configuration, + "storageConfiguration": self.storage_configuration, + "status": self.status, + "createdAt": self.created_at, + "updatedAt": self.updated_at, + "failureReasons": self.failure_reasons, + } + return {k: v for k, v in dct.items() if v} + + def dict_summary(self) -> Dict[str, Any]: + dct = { + "knowledgeBaseId": self.knowledge_base_id, + "name": self.name, + "description": self.description, + "status": self.status, + "updatedAt": self.updated_at, + } + return {k: v for k, v in dct.items() if v} + + +class AgentsforBedrockBackend(BaseBackend): + """Implementation of AgentsforBedrock APIs.""" + + PAGINATION_MODEL = { + "list_agents": { + "input_token": "next_token", + "limit_key": "max_results", + "limit_default": 100, + "unique_attribute": "agentId", + }, + "list_knowledge_bases": { + "input_token": "next_token", + "limit_key": "max_results", + "limit_default": 100, + "unique_attribute": "knowledgeBaseId", + }, + } + + def __init__(self, region_name: str, account_id: str): + super().__init__(region_name, account_id) + self.agents: Dict[str, Agent] = {} + self.knowledge_bases: Dict[str, KnowledgeBase] = {} + self.tagger = TaggingService() + + def _list_arns(self) -> List[str]: + return [agent.agent_arn for agent in self.agents.values()] + [ + knowledge_base.knowledge_base_arn + for knowledge_base in self.knowledge_bases.values() + ] + + def create_agent( + self, + agent_name: str, + agent_resource_role_arn: str, + client_token: Optional[str], + instruction: Optional[str], + foundation_model: Optional[str], + description: Optional[str], + idle_session_ttl_in_seconds: Optional[int], + customer_encryption_key_arn: Optional[str], + tags: Optional[Dict[str, str]], + prompt_override_configuration: Optional[Dict[str, Any]], + ) -> Agent: + agent = Agent( + agent_name, + agent_resource_role_arn, + self.region_name, + self.account_id, + client_token, + instruction, + foundation_model, + description, + idle_session_ttl_in_seconds, + customer_encryption_key_arn, + prompt_override_configuration, + ) + self.agents[agent.agent_id] = agent + if tags: + self.tag_resource(agent.agent_arn, tags) + return agent + + def get_agent(self, agent_id: str) -> Agent: + if agent_id not in self.agents: + raise ResourceNotFoundException(f"Agent {agent_id} not found") + return self.agents[agent_id] + + @paginate(pagination_model=PAGINATION_MODEL) # type: ignore + def list_agents( + self, max_results: Optional[int], next_token: Optional[str] + ) -> List[Any]: + agent_summaries = [agent.dict_summary() for agent in self.agents.values()] + return agent_summaries + + def delete_agent( + self, agent_id: str, skip_resource_in_use_check: Optional[bool] + ) -> Tuple[str, str]: + if agent_id in self.agents: + if ( + skip_resource_in_use_check + or self.agents[agent_id].agent_status == "PREPARED" + ): + self.agents[agent_id].agent_status = "DELETING" + agent_status = self.agents[agent_id].agent_status + del self.agents[agent_id] + else: + raise ConflictException(f"Agent {agent_id} is in use") + else: + raise ResourceNotFoundException(f"Agent {agent_id} not found") + return agent_id, agent_status + + def create_knowledge_base( + self, + name: str, + role_arn: str, + knowledge_base_configuration: Dict[str, Any], + storage_configuration: Dict[str, Any], + client_token: Optional[str], + description: Optional[str], + tags: Optional[Dict[str, str]], + ) -> KnowledgeBase: + knowledge_base = KnowledgeBase( + name, + role_arn, + self.region_name, + self.account_id, + knowledge_base_configuration, + storage_configuration, + client_token, + description, + ) + self.knowledge_bases[knowledge_base.knowledge_base_id] = knowledge_base + if tags: + self.tag_resource(knowledge_base.knowledge_base_arn, tags) + return knowledge_base + + @paginate(pagination_model=PAGINATION_MODEL) # type: ignore + def list_knowledge_bases( + self, max_results: Optional[int], next_token: Optional[str] + ) -> List[Any]: + knowledge_base_summaries = [ + knowledge_base.dict_summary() + for knowledge_base in self.knowledge_bases.values() + ] + return knowledge_base_summaries + + def delete_knowledge_base(self, knowledge_base_id: str) -> Tuple[str, str]: + if knowledge_base_id in self.knowledge_bases: + self.knowledge_bases[knowledge_base_id].status = "DELETING" + knowledge_base_status = self.knowledge_bases[knowledge_base_id].status + del self.knowledge_bases[knowledge_base_id] + else: + raise ResourceNotFoundException( + f"Knowledge base {knowledge_base_id} not found" + ) + return knowledge_base_id, knowledge_base_status + + def get_knowledge_base(self, knowledge_base_id: str) -> KnowledgeBase: + if knowledge_base_id not in self.knowledge_bases: + raise ResourceNotFoundException( + f"Knowledge base {knowledge_base_id} not found" + ) + return self.knowledge_bases[knowledge_base_id] + + def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None: + if resource_arn not in self._list_arns(): + raise ResourceNotFoundException(f"Resource {resource_arn} not found") + tags_input = TaggingService.convert_dict_to_tags_input(tags or {}) + self.tagger.tag_resource(resource_arn, tags_input) + return + + def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None: + if resource_arn not in self._list_arns(): + raise ResourceNotFoundException(f"Resource {resource_arn} not found") + self.tagger.untag_resource_using_names(resource_arn, tag_keys) + return + + def list_tags_for_resource(self, resource_arn: str) -> Dict[str, str]: + if resource_arn not in self._list_arns(): + raise ResourceNotFoundException(f"Resource {resource_arn} not found") + return self.tagger.get_tag_dict_for_resource(resource_arn) + + +bedrockagent_backends = BackendDict(AgentsforBedrockBackend, "bedrock") diff --git a/moto/bedrockagent/responses.py b/moto/bedrockagent/responses.py new file mode 100644 index 000000000000..7d3bada0ea7e --- /dev/null +++ b/moto/bedrockagent/responses.py @@ -0,0 +1,153 @@ +"""Handles incoming bedrockagent requests, invokes methods, returns responses.""" + +import json +from urllib.parse import unquote + +from moto.core.responses import BaseResponse + +from .models import AgentsforBedrockBackend, bedrockagent_backends + + +class AgentsforBedrockResponse(BaseResponse): + """Handler for AgentsforBedrock requests and responses.""" + + def __init__(self) -> None: + super().__init__(service_name="bedrock-agent") + + @property + def bedrockagent_backend(self) -> AgentsforBedrockBackend: + """Return backend instance specific for this region.""" + return bedrockagent_backends[self.current_account][self.region] + + def create_agent(self) -> str: + params = json.loads(self.body) + agent_name = params.get("agentName") + client_token = params.get("clientToken") + instruction = params.get("instruction") + foundation_model = params.get("foundationModel") + description = params.get("description") + idle_session_ttl_in_seconds = params.get("idleSessionTTLInSeconds") + agent_resource_role_arn = params.get("agentResourceRoleArn") + customer_encryption_key_arn = params.get("customerEncryptionKeyArn") + tags = params.get("tags") + prompt_override_configuration = params.get("promptOverrideConfiguration") + agent = self.bedrockagent_backend.create_agent( + agent_name=agent_name, + client_token=client_token, + instruction=instruction, + foundation_model=foundation_model, + description=description, + idle_session_ttl_in_seconds=idle_session_ttl_in_seconds, + agent_resource_role_arn=agent_resource_role_arn, + customer_encryption_key_arn=customer_encryption_key_arn, + tags=tags, + prompt_override_configuration=prompt_override_configuration, + ) + return json.dumps({"agent": dict(agent.to_dict())}) + + def get_agent(self) -> str: + agent_id = self.path.split("/")[-2] + agent = self.bedrockagent_backend.get_agent(agent_id=agent_id) + return json.dumps({"agent": dict(agent.to_dict())}) + + def list_agents(self) -> str: + params = json.loads(self.body) + max_results = params.get("maxResults") + next_token = params.get("nextToken") + max_results = int(max_results) if max_results else None + agents, next_token = self.bedrockagent_backend.list_agents( + max_results=max_results, + next_token=next_token, + ) + return json.dumps( + { + "agentSummaries": agents, + "nextToken": next_token, + } + ) + + def delete_agent(self) -> str: + params = self._get_params() + skip_resource_in_use_check = params.get("skipResourceInUseCheck") + agent_id = self.path.split("/")[-2] + agent_id, agent_status = self.bedrockagent_backend.delete_agent( + agent_id=agent_id, skip_resource_in_use_check=skip_resource_in_use_check + ) + return json.dumps({"agentId": agent_id, "agentStatus": agent_status}) + + def create_knowledge_base(self) -> str: + params = json.loads(self.body) + client_token = params.get("clientToken") + name = params.get("name") + description = params.get("description") + role_arn = params.get("roleArn") + knowledge_base_configuration = params.get("knowledgeBaseConfiguration") + storage_configuration = params.get("storageConfiguration") + tags = params.get("tags") + knowledge_base = self.bedrockagent_backend.create_knowledge_base( + client_token=client_token, + name=name, + description=description, + role_arn=role_arn, + knowledge_base_configuration=knowledge_base_configuration, + storage_configuration=storage_configuration, + tags=tags, + ) + return json.dumps({"knowledgeBase": dict(knowledge_base.to_dict())}) + + def list_knowledge_bases(self) -> str: + params = json.loads(self.body) + max_results = params.get("maxResults") + next_token = params.get("nextToken") + max_results = int(max_results) if max_results else None + knowledge_bases, next_token = self.bedrockagent_backend.list_knowledge_bases( + max_results=max_results, + next_token=next_token, + ) + return json.dumps( + { + "knowledgeBaseSummaries": knowledge_bases, + "nextToken": next_token, + } + ) + + def delete_knowledge_base(self) -> str: + knowledge_base_id = self.path.split("/")[-1] + ( + knowledge_base_id, + knowledge_base_status, + ) = self.bedrockagent_backend.delete_knowledge_base( + knowledge_base_id=knowledge_base_id + ) + return json.dumps( + {"knowledgeBaseId": knowledge_base_id, "status": knowledge_base_status} + ) + + def get_knowledge_base(self) -> str: + knowledge_base_id = self.path.split("/")[-1] + knowledge_base = self.bedrockagent_backend.get_knowledge_base( + knowledge_base_id=knowledge_base_id + ) + return json.dumps({"knowledgeBase": knowledge_base.to_dict()}) + + def tag_resource(self) -> str: + params = json.loads(self.body) + resource_arn = unquote(self.path.split("/tags/")[-1]) + tags = params.get("tags") + self.bedrockagent_backend.tag_resource(resource_arn=resource_arn, tags=tags) + return json.dumps(dict()) + + def untag_resource(self) -> str: + resource_arn = unquote(self.path.split("/tags/")[-1]) + tag_keys = self.querystring.get("tagKeys", []) + self.bedrockagent_backend.untag_resource( + resource_arn=resource_arn, tag_keys=tag_keys + ) + return json.dumps(dict()) + + def list_tags_for_resource(self) -> str: + resource_arn = unquote(self.path.split("/tags/")[-1]) + tags = self.bedrockagent_backend.list_tags_for_resource( + resource_arn=resource_arn + ) + return json.dumps(dict(tags=tags)) diff --git a/moto/bedrockagent/urls.py b/moto/bedrockagent/urls.py new file mode 100644 index 000000000000..bb7ce16241a1 --- /dev/null +++ b/moto/bedrockagent/urls.py @@ -0,0 +1,11 @@ +"""bedrockagent base URL and path.""" + +from .responses import AgentsforBedrockResponse + +url_bases = [ + r"https?://bedrock-agent\.(.+)\.amazonaws\.com", +] + +url_paths = { + "{0}/.*$": AgentsforBedrockResponse.dispatch, +} diff --git a/tests/test_bedrock/__init__.py b/tests/test_bedrock/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_bedrock/test_bedrock.py b/tests/test_bedrock/test_bedrock.py new file mode 100644 index 000000000000..d6e93caf5f3a --- /dev/null +++ b/tests/test_bedrock/test_bedrock.py @@ -0,0 +1,1401 @@ +"""Unit tests for bedrock-supported APIs.""" + +from datetime import datetime +from unittest import SkipTest + +import boto3 +import pytest +from botocore.exceptions import ClientError +from freezegun import freeze_time + +from moto import mock_aws, settings + +DEFAULT_REGION = "us-east-1" + +# See our Development Tips on writing tests for hints on how to write good tests: +# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html + + +@mock_aws +def test_create_model_customization_job(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + # create a test s3 client and bucket + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + # pytest.set_trace() + resp = client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + assert ( + resp["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob" + ) + + +@mock_aws +def test_get_model_customization_job(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + resp = client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.get_model_customization_job(jobIdentifier="testjob") + + assert ( + resp["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob" + ) + assert resp["roleArn"] == "testrole" + + +@mock_aws +def test_get_model_invocation_logging_configuration(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + logging_config = { + "cloudWatchConfig": { + "logGroupName": "Test", + "roleArn": "testrole", + "largeDataDeliveryS3Config": { + "bucketName": "testbucket", + }, + }, + "s3Config": { + "bucketName": "testconfigbucket", + }, + } + client.put_model_invocation_logging_configuration(loggingConfig=logging_config) + response = client.get_model_invocation_logging_configuration() + assert response["loggingConfig"]["cloudWatchConfig"]["logGroupName"] == "Test" + assert response["loggingConfig"]["s3Config"]["bucketName"] == "testconfigbucket" + + +@mock_aws +def test_put_model_invocation_logging_configuration(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + logging_config = { + "cloudWatchConfig": { + "logGroupName": "Test", + "roleArn": "testrole", + "largeDataDeliveryS3Config": { + "bucketName": "testbucket", + }, + }, + "s3Config": { + "bucketName": "testconfigbucket", + }, + } + # pytest.set_trace() + client.put_model_invocation_logging_configuration(loggingConfig=logging_config) + response = client.get_model_invocation_logging_configuration() + assert response["loggingConfig"]["cloudWatchConfig"]["logGroupName"] == "Test" + + +@mock_aws +def test_tag_resource_model_customization_job(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + job_arn = client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.tag_resource( + resourceARN=job_arn["jobArn"], + tags=[ + {"key": "testkey", "value": "testvalue"}, + {"key": "testkey2", "value": "testvalue2"}, + ], + ) + resp = client.list_tags_for_resource(resourceARN=job_arn["jobArn"]) + assert resp["tags"][0]["key"] == "testkey" + assert resp["tags"][1]["value"] == "testvalue2" + + +@mock_aws +def test_untag_resource(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + job_arn = client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.tag_resource( + resourceARN=job_arn["jobArn"], + tags=[ + {"key": "testkey", "value": "testvalue"}, + {"key": "testkey2", "value": "testvalue2"}, + ], + ) + resp = client.untag_resource(resourceARN=job_arn["jobArn"], tagKeys=["testkey"]) + resp = client.list_tags_for_resource(resourceARN=job_arn["jobArn"]) + + assert resp["tags"][0]["key"] == "testkey2" + assert resp["tags"][0]["value"] == "testvalue2" + + +@mock_aws +def test_untag_resource_custom_model(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + job_arn = client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + + client.tag_resource( + resourceARN=job_arn["jobArn"], + tags=[ + {"key": "testkey", "value": "testvalue"}, + {"key": "testkey2", "value": "testvalue2"}, + ], + ) + resp = client.untag_resource(resourceARN=job_arn["jobArn"], tagKeys=["testkey"]) + resp = client.list_tags_for_resource(resourceARN=job_arn["jobArn"]) + + assert resp["tags"][0]["key"] == "testkey2" + assert resp["tags"][0]["value"] == "testvalue2" + + +@mock_aws +def test_list_tags_for_resource(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + job_arn = client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.tag_resource( + resourceARN=job_arn["jobArn"], + tags=[ + {"key": "testkey", "value": "testvalue"}, + {"key": "testkey2", "value": "testvalue2"}, + ], + ) + resp = client.list_tags_for_resource(resourceARN=job_arn["jobArn"]) + + assert resp["tags"][0]["key"] == "testkey" + assert resp["tags"][1]["value"] == "testvalue2" + + +@mock_aws +def test_get_custom_model(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + job_arn = client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.get_custom_model(modelIdentifier="testmodel") + assert ( + resp["modelArn"] + == "arn:aws:bedrock:us-east-1:123456789012:custom-model/testmodel" + ) + assert resp["jobArn"] == job_arn["jobArn"] + + +@mock_aws +def test_get_custom_model_arn(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + job_arn = client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + + resp = client.get_custom_model(modelIdentifier="testmodel") + model = client.get_custom_model(modelIdentifier=resp["modelArn"]) + assert model["modelName"] == "testmodel" + assert resp["jobArn"] == job_arn["jobArn"] + + +@mock_aws +def test_get_custom_model_arn_not_found(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + + resp = client.get_custom_model(modelIdentifier="testmodel") + with pytest.raises(ClientError) as ex: + client.get_custom_model(modelIdentifier=(resp["modelArn"] + "no")) + assert ex.value.response["Error"]["Code"] == "ResourceNotFoundException" + + +@mock_aws +def test_list_custom_models(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_custom_models() + assert len(resp["modelSummaries"]) == 2 + assert ( + resp["modelSummaries"][0]["modelArn"] + == "arn:aws:bedrock:us-east-1:123456789012:custom-model/testmodel1" + ) + assert ( + resp["modelSummaries"][1]["modelArn"] + == "arn:aws:bedrock:us-east-1:123456789012:custom-model/testmodel2" + ) + + +@mock_aws +def test_list_model_customization_jobs(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_model_customization_jobs() + assert len(resp["modelCustomizationJobSummaries"]) == 2 + assert ( + resp["modelCustomizationJobSummaries"][0]["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob" + ) + assert ( + resp["modelCustomizationJobSummaries"][1]["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob2" + ) + + +@mock_aws +def test_delete_custom_model(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.delete_custom_model(modelIdentifier="testmodel") + + with pytest.raises(ClientError) as ex: + client.get_custom_model(modelIdentifier="testmodel") + assert ex.value.response["Error"]["Code"] == "ResourceNotFoundException" + + +@mock_aws +def test_delete_custom_model_not_found(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + + with pytest.raises(ClientError) as ex: + client.delete_custom_model(modelIdentifier="testmodel1") + assert ex.value.response["Error"]["Code"] == "ResourceNotFoundException" + + +@mock_aws +def test_stop_model_customization_job(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.stop_model_customization_job(jobIdentifier="testjob") + resp = client.get_model_customization_job(jobIdentifier="testjob") + assert resp["status"] == "Stopped" + + +@mock_aws +def test_delete_model_invocation_logging_configuration(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + logging_config = { + "cloudWatchConfig": { + "logGroupName": "Test", + "roleArn": "testrole", + "largeDataDeliveryS3Config": { + "bucketName": "testbucket", + }, + }, + "s3Config": { + "bucketName": "testconfigbucket", + }, + } + client.put_model_invocation_logging_configuration(loggingConfig=logging_config) + client.delete_model_invocation_logging_configuration() + assert client.get_model_invocation_logging_configuration()["loggingConfig"] == {} + + +@mock_aws +def test_create_model_customization_job_bad_training_data_config(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + # create a test s3 client and bucket + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + # pytest.set_trace() + with pytest.raises(ClientError) as ex: + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "aws:s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + assert ex.value.response["Error"]["Code"] == "ValidationException" + + +@mock_aws +def test_create_model_customization_job_bad_validation_data_config(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + # create a test s3 client and bucket + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + # pytest.set_trace() + with pytest.raises(ClientError) as ex: + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + validationDataConfig={ + "validators": [{"s3Uri": "aws:s3://validation_bucket"}] + }, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + assert ex.value.response["Error"]["Code"] == "ValidationException" + + +@mock_aws +def test_create_model_customization_job_bad_output_data_config(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + # create a test s3 client and bucket + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + # pytest.set_trace() + with pytest.raises(ClientError) as ex: + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + validationDataConfig={"validators": [{"s3Uri": "s3://validation_bucket"}]}, + outputDataConfig={"s3Uri": "aws:s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + assert ex.value.response["Error"]["Code"] == "ValidationException" + + +@mock_aws +def test_create_model_customization_job_duplicate_job_name(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + # create a test s3 client and bucket + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + # pytest.set_trace() + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + validationDataConfig={"validators": [{"s3Uri": "s3://validation_bucket"}]}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with pytest.raises(ClientError) as ex: + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + validationDataConfig={"validators": [{"s3Uri": "s3://validation_bucket"}]}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + assert ex.value.response["Error"]["Code"] == "ResourceInUseException" + + +@mock_aws +def test_create_model_customization_job_duplicate_model_name(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + # create a test s3 client and bucket + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + # pytest.set_trace() + client.create_model_customization_job( + jobName="testjob1", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + validationDataConfig={"validators": [{"s3Uri": "s3://validation_bucket"}]}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with pytest.raises(ClientError) as ex: + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + validationDataConfig={"validators": [{"s3Uri": "s3://validation_bucket"}]}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + assert ex.value.response["Error"]["Code"] == "ResourceInUseException" + + +@mock_aws +def test_create_model_customization_job_tags(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + # create a test s3 client and bucket + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + # pytest.set_trace() + resp = client.create_model_customization_job( + jobName="testjob1", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + validationDataConfig={"validators": [{"s3Uri": "s3://validation_bucket"}]}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + jobTags=[{"key": "test", "value": "testvalue"}], + customModelTags=[{"key": "modeltest", "value": "modeltestvalue"}], + ) + job_tags = client.list_tags_for_resource(resourceARN=resp["jobArn"]) + model_arn = client.list_custom_models()["modelSummaries"][0]["modelArn"] + model_tags = client.list_tags_for_resource(resourceARN=model_arn) + assert job_tags["tags"] == [{"key": "test", "value": "testvalue"}] + assert model_tags["tags"] == [{"key": "modeltest", "value": "modeltestvalue"}] + + +@mock_aws +def test_get_model_customization_job_not_found(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with pytest.raises(ClientError) as ex: + client.get_model_customization_job(jobIdentifier="testjob1") + assert ex.value.response["Error"]["Code"] == "ResourceNotFoundException" + + +@mock_aws +def test_stop_model_customization_job_not_found(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with pytest.raises(ClientError) as ex: + client.stop_model_customization_job(jobIdentifier="testjob1") + assert ex.value.response["Error"]["Code"] == "ResourceNotFoundException" + + +@mock_aws +def test_list_model_customization_jobs_max_results(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_model_customization_jobs(maxResults=1) + assert len(resp["modelCustomizationJobSummaries"]) == 1 + assert ( + resp["modelCustomizationJobSummaries"][0]["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob" + ) + + +@mock_aws +def test_list_model_customization_jobs_name_contains(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_model_customization_jobs(nameContains="testjob2") + assert len(resp["modelCustomizationJobSummaries"]) == 1 + assert ( + resp["modelCustomizationJobSummaries"][0]["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob2" + ) + + +@mock_aws +def test_list_model_customization_jobs_creation_time_before(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + with freeze_time("2022-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_model_customization_jobs( + creationTimeBefore=datetime(2022, 2, 1, 12, 0, 0) + ) + assert len(resp["modelCustomizationJobSummaries"]) == 1 + assert ( + resp["modelCustomizationJobSummaries"][0]["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob" + ) + + +@mock_aws +def test_list_model_customization_jobs_creation_time_after(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + with freeze_time("2022-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_model_customization_jobs( + creationTimeAfter=datetime(2022, 2, 1, 12, 0, 0) + ) + assert len(resp["modelCustomizationJobSummaries"]) == 1 + assert ( + resp["modelCustomizationJobSummaries"][0]["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob2" + ) + + +@mock_aws +def test_list_model_customization_jobs_status(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_model_customization_jobs(statusEquals="InProgress") + assert len(resp["modelCustomizationJobSummaries"]) == 2 + assert ( + resp["modelCustomizationJobSummaries"][0]["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob" + ) + assert ( + resp["modelCustomizationJobSummaries"][1]["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob2" + ) + + +@mock_aws +def test_list_model_customization_jobs_ascending_sort(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + with freeze_time("2022-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob3", + customModelName="testmodel3", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with freeze_time("2023-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_model_customization_jobs( + sortBy="CreationTime", sortOrder="Ascending" + ) + assert len(resp["modelCustomizationJobSummaries"]) == 3 + assert ( + resp["modelCustomizationJobSummaries"][0]["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob" + ) + assert ( + resp["modelCustomizationJobSummaries"][1]["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob2" + ) + assert ( + resp["modelCustomizationJobSummaries"][2]["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob3" + ) + + +@mock_aws +def test_list_model_customization_jobs_descending_sort(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + with freeze_time("2022-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob3", + customModelName="testmodel3", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with freeze_time("2023-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_model_customization_jobs( + sortBy="CreationTime", sortOrder="Descending" + ) + assert len(resp["modelCustomizationJobSummaries"]) == 3 + assert ( + resp["modelCustomizationJobSummaries"][0]["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob3" + ) + assert ( + resp["modelCustomizationJobSummaries"][1]["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob2" + ) + assert ( + resp["modelCustomizationJobSummaries"][2]["jobArn"] + == "arn:aws:bedrock:us-east-1:123456789012:model-customization-job/testjob" + ) + + +@mock_aws +def test_list_model_customization_jobs_bad_sort_order(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + with freeze_time("2022-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob3", + customModelName="testmodel3", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with freeze_time("2023-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with pytest.raises(ClientError) as ex: + client.list_model_customization_jobs( + sortBy="CreationTime", sortOrder="decending" + ) + assert ex.value.response["Error"]["Code"] == "ValidationException" + + +@mock_aws +def test_list_model_customization_jobs_bad_sort_by(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + with freeze_time("2022-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob3", + customModelName="testmodel3", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with freeze_time("2023-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with pytest.raises(ClientError) as ex: + client.list_model_customization_jobs( + sortBy="Creationime", sortOrder="Descending" + ) + assert ex.value.response["Error"]["Code"] == "ValidationException" + + +@mock_aws +def test_get_model_invocation_logging_configuration_empty(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + response = client.get_model_invocation_logging_configuration() + assert response["loggingConfig"] == {} + + +@mock_aws +def test_list_custom_models_max_results(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_custom_models(maxResults=1) + assert len(resp["modelSummaries"]) == 1 + assert ( + resp["modelSummaries"][0]["modelArn"] + == "arn:aws:bedrock:us-east-1:123456789012:custom-model/testmodel" + ) + + +@mock_aws +def test_list_custom_models_name_contains(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_custom_models(nameContains="testjob2") + assert len(resp["modelSummaries"]) == 1 + assert ( + resp["modelSummaries"][0]["modelArn"] + == "arn:aws:bedrock:us-east-1:123456789012:custom-model/testmodel2" + ) + + +@mock_aws +def test_list_custom_models_creation_time_before(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + with freeze_time("2022-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_custom_models(creationTimeBefore=datetime(2022, 2, 1, 12, 0, 0)) + assert len(resp["modelSummaries"]) == 1 + assert ( + resp["modelSummaries"][0]["modelArn"] + == "arn:aws:bedrock:us-east-1:123456789012:custom-model/testmodel" + ) + + +@mock_aws +def test_list_custom_models_creation_time_after(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + with freeze_time("2022-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_custom_models(creationTimeAfter=datetime(2022, 2, 1, 12, 0, 0)) + assert len(resp["modelSummaries"]) == 1 + assert ( + resp["modelSummaries"][0]["modelArn"] + == "arn:aws:bedrock:us-east-1:123456789012:custom-model/testmodel2" + ) + + +@mock_aws +def test_list_custom_models_ascending_sort(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + with freeze_time("2022-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob3", + customModelName="testmodel3", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with freeze_time("2023-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_custom_models(sortBy="CreationTime", sortOrder="Ascending") + assert len(resp["modelSummaries"]) == 3 + assert ( + resp["modelSummaries"][0]["modelArn"] + == "arn:aws:bedrock:us-east-1:123456789012:custom-model/testmodel" + ) + assert ( + resp["modelSummaries"][1]["modelArn"] + == "arn:aws:bedrock:us-east-1:123456789012:custom-model/testmodel2" + ) + assert ( + resp["modelSummaries"][2]["modelArn"] + == "arn:aws:bedrock:us-east-1:123456789012:custom-model/testmodel3" + ) + + +@mock_aws +def test_list_custom_models_descending_sort(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + with freeze_time("2022-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob3", + customModelName="testmodel3", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with freeze_time("2023-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_custom_models(sortBy="CreationTime", sortOrder="Descending") + assert len(resp["modelSummaries"]) == 3 + assert ( + resp["modelSummaries"][0]["modelArn"] + == "arn:aws:bedrock:us-east-1:123456789012:custom-model/testmodel3" + ) + assert ( + resp["modelSummaries"][1]["modelArn"] + == "arn:aws:bedrock:us-east-1:123456789012:custom-model/testmodel2" + ) + assert ( + resp["modelSummaries"][2]["modelArn"] + == "arn:aws:bedrock:us-east-1:123456789012:custom-model/testmodel" + ) + + +@mock_aws +def test_list_custom_models_bad_sort_order(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + with freeze_time("2022-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob3", + customModelName="testmodel3", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with freeze_time("2023-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with pytest.raises(ClientError) as ex: + client.list_custom_models(sortBy="CreationTime", sortOrder="decending") + assert ex.value.response["Error"]["Code"] == "ValidationException" + + +@mock_aws +def test_list_custom_models_bad_sort_by(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + with freeze_time("2022-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel1", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob3", + customModelName="testmodel3", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with freeze_time("2023-01-01 12:00:00"): + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with pytest.raises(ClientError) as ex: + client.list_custom_models(sortBy="Creationime", sortOrder="Descending") + assert ex.value.response["Error"]["Code"] == "ValidationException" + + +@mock_aws +def test_list_custom_models_base_model_arn_equals(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.create_model_customization_job( + jobName="testjob2", + customModelName="testmodel2", + roleArn="testrole", + baseModelIdentifier="amazon.titan-text-lite-v1", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + resp = client.list_custom_models( + baseModelArnEquals="arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0" + ) + assert len(resp["modelSummaries"]) == 1 + assert ( + resp["modelSummaries"][0]["modelArn"] + == "arn:aws:bedrock:us-east-1:123456789012:custom-model/testmodel" + ) + + +@mock_aws +def test_tag_resource_not_found(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + job_arn = client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with pytest.raises(ClientError) as ex: + client.tag_resource( + resourceARN=job_arn["jobArn"] + "no", + tags=[ + {"key": "testkey", "value": "testvalue"}, + {"key": "testkey2", "value": "testvalue2"}, + ], + ) + assert ex.value.response["Error"]["Code"] == "ResourceNotFoundException" + + +@mock_aws +def test_tag_resource_too_many(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + job_arn = client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + with pytest.raises(ClientError) as ex: + client.tag_resource( + resourceARN=job_arn["jobArn"], + tags=[{"key": f"testkey{i}", "value": f"testvalue{i}"} for i in range(51)], + ) + assert ex.value.response["Error"]["Code"] == "TooManyTagsException" + + +@mock_aws +def test_untag_resource_not_found(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + job_arn = client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + + client.tag_resource( + resourceARN=job_arn["jobArn"], + tags=[ + {"key": "testkey", "value": "testvalue"}, + {"key": "testkey2", "value": "testvalue2"}, + ], + ) + with pytest.raises(ClientError) as ex: + client.untag_resource(resourceARN=job_arn["jobArn"] + "no", tagKeys=["testkey"]) + assert ex.value.response["Error"]["Code"] == "ResourceNotFoundException" + + +@mock_aws +def test_list_tags_for_resource_not_found(): + client = boto3.client("bedrock", region_name=DEFAULT_REGION) + s3_client = boto3.client("s3", region_name=DEFAULT_REGION) + s3_client.create_bucket(Bucket="training_bucket") + s3_client.create_bucket(Bucket="output_bucket") + job_arn = client.create_model_customization_job( + jobName="testjob", + customModelName="testmodel", + roleArn="testrole", + baseModelIdentifier="anthropic.claude-3-sonnet-20240229-v1:0", + trainingDataConfig={"s3Uri": "s3://training_bucket"}, + outputDataConfig={"s3Uri": "s3://output_bucket"}, + hyperParameters={"learning_rate": "0.01"}, + ) + client.tag_resource( + resourceARN=job_arn["jobArn"], + tags=[ + {"key": "testkey", "value": "testvalue"}, + {"key": "testkey2", "value": "testvalue2"}, + ], + ) + with pytest.raises(ClientError) as ex: + client.list_tags_for_resource(resourceARN=job_arn["jobArn"] + "no") + assert ex.value.response["Error"]["Code"] == "ResourceNotFoundException" diff --git a/tests/test_bedrockagent/__init__.py b/tests/test_bedrockagent/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_bedrockagent/test_bedrockagent.py b/tests/test_bedrockagent/test_bedrockagent.py new file mode 100644 index 000000000000..3fa9cc36fe0f --- /dev/null +++ b/tests/test_bedrockagent/test_bedrockagent.py @@ -0,0 +1,726 @@ +"""Unit tests for bedrockagent-supported APIs.""" + +import boto3 +import pytest +from botocore.exceptions import ClientError + +from moto import mock_aws + +DEFAULT_REGION = "us-east-1" + +# See our Development Tips on writing tests for hints on how to write good tests: +# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html + + +@mock_aws +def test_create_agent(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + # pytest.set_trace() + + resp = client.create_agent( + agentName="agent_name", + agentResourceRoleArn="test-agent-arn", + tags={ + "Key": "test-tag-key", + }, + ) + # resp = client.create_agent( + # agentName="agent_name", + # clientToken="client_tokenclient_tokenclient_token", + # instruction="instructioninstructioninstructioninstruction", + # foundationModel="foundation_model", + # description="description", + # idleSessionTTLInSeconds=60, + # agentResourceRoleArn="agent_resource_role_arn", + # customerEncryptionKeyArn="customer_encryption_key_arn", + # tags={ + # "Key": "test-tag-key", + # }, + # promptOverrideConfiguration={ + # 'promptConfigurations': [ + # { + # 'promptType': 'PRE_PROCESSING', + # 'promptCreationMode': 'DEFAULT', + # 'promptState': 'ENABLED', + # 'basePromptTemplate': 'string', + # 'inferenceConfiguration': { + # 'temperature': 1.0, + # 'topP': 1.0, + # 'topK': 123, + # 'maximumLength': 123, + # 'stopSequences': [ + # 'string', + # ] + # }, + # 'parserMode': 'DEFAULT' + # }, + # ], + # 'overrideLambda': 'overrideLambdaoverrideLambdaoverrideLambdaoverrideLambda' + # } + # ) + assert resp["agent"]["agentName"] == "agent_name" + + +@mock_aws +def test_get_agent(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + agent = client.create_agent( + agentName="testname", + agentResourceRoleArn="test-agent-arn", + tags={ + "Key": "test-tag-key", + }, + ) + resp = client.get_agent(agentId=agent["agent"]["agentId"]) + assert resp["agent"]["agentName"] == "testname" + + +@mock_aws +def test_get_agent_not_found(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + client.create_agent( + agentName="testname", + agentResourceRoleArn="test-agent-arn", + tags={ + "Key": "test-tag-key", + }, + ) + with pytest.raises(ClientError) as e: + client.get_agent(agentId="non-existent-agent-id") + assert e.value.response["Error"]["Code"] == "ResourceNotFoundException" + + +@mock_aws +def test_list_agents(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + client.create_agent( + agentName="testname1", + agentResourceRoleArn="test-agent-arn", + tags={ + "Key": "test-tag-key", + }, + ) + client.create_agent( + agentName="testname2", + agentResourceRoleArn="test-agent-arn", + tags={ + "Key": "test-tag-key", + }, + ) + resp = client.list_agents() + # assert resp['agent']['agentName'] == "testname" + + assert len(resp["agentSummaries"]) == 2 + assert resp["agentSummaries"][0]["agentName"] == "testname1" + assert resp["agentSummaries"][1]["agentName"] == "testname2" + + +@mock_aws +def test_delete_agent(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + agent = client.create_agent( + agentName="testname", + agentResourceRoleArn="test-agent-arn", + tags={"Key": "test-tag-key"}, + ) + agent_id = agent["agent"]["agentId"] + resp = client.delete_agent(agentId=agent_id, skipResourceInUseCheck=True) + + assert resp["agentId"] == agent_id + assert resp["agentStatus"] == "DELETING" + + +@mock_aws +def test_create_knowledge_base(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + resp = client.create_knowledge_base( + name="testkb", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + tags={"Key": "test-tag"}, + ) + + assert resp["knowledgeBase"]["name"] == "testkb" + + +@mock_aws +def test_list_knowledge_bases(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + client.create_knowledge_base( + name="testkb1", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + tags={"Key": "test-tag"}, + ) + client.create_knowledge_base( + name="testkb2", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + tags={"Key": "test-tag"}, + ) + resp = client.list_knowledge_bases() + + assert len(resp["knowledgeBaseSummaries"]) == 2 + assert resp["knowledgeBaseSummaries"][0]["name"] == "testkb1" + assert resp["knowledgeBaseSummaries"][1]["name"] == "testkb2" + + +@mock_aws +def test_delete_knowledge_base(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + kb = client.create_knowledge_base( + name="testkb", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + tags={"Key": "test-tag"}, + ) + kb_id = kb["knowledgeBase"]["knowledgeBaseId"] + resp = client.delete_knowledge_base(knowledgeBaseId=kb_id) + assert resp["knowledgeBaseId"] == kb_id + assert resp["status"] == "DELETING" + + +@mock_aws +def test_delete_knowledge_base_not_found(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + client.create_knowledge_base( + name="testkb", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + tags={"Key": "test-tag"}, + ) + with pytest.raises(ClientError) as e: + client.delete_knowledge_base(knowledgeBaseId="non-existent-kb-id") + assert e.value.response["Error"]["Code"] == "ResourceNotFoundException" + + +@mock_aws +def test_get_knowledge_base(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + kb = client.create_knowledge_base( + name="testkb", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + tags={"Key": "test-tag"}, + ) + kb_id = kb["knowledgeBase"]["knowledgeBaseId"] + resp = client.get_knowledge_base(knowledgeBaseId=kb_id) + assert resp["knowledgeBase"]["name"] == "testkb" + + +@mock_aws +def test_get_knowledge_base_not_found(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + client.create_knowledge_base( + name="testkb", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + tags={"Key": "test-tag"}, + ) + with pytest.raises(ClientError) as e: + client.get_knowledge_base(knowledgeBaseId="non-existent-kb-id") + assert e.value.response["Error"]["Code"] == "ResourceNotFoundException" + + +@mock_aws +def test_tag_resource_agent(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + agent = client.create_agent( + agentName="testname", + agentResourceRoleArn="test-agent-arn", + ) + agent_arn = agent["agent"]["agentArn"] + resp = client.tag_resource(resourceArn=agent_arn, tags={"Key": "test-tag"}) + resp = client.list_tags_for_resource(resourceArn=agent_arn) + assert resp["tags"]["Key"] == "test-tag" + + +@mock_aws +def test_tag_resource_knowledge_base(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + kb = client.create_knowledge_base( + name="testkb", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + ) + kb_arn = kb["knowledgeBase"]["knowledgeBaseArn"] + resp = client.tag_resource(resourceArn=kb_arn, tags={"Key": "test-tag"}) + resp = client.list_tags_for_resource(resourceArn=kb_arn) + assert resp["tags"]["Key"] == "test-tag" + + +@mock_aws +def test_untag_resource_agent(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + agent = client.create_agent( + agentName="testname", + agentResourceRoleArn="test-agent-arn", + ) + agent_arn = agent["agent"]["agentArn"] + resp = client.tag_resource( + resourceArn=agent_arn, tags={"Key1": "test-tag", "Key2": "test-tag2"} + ) + resp = client.untag_resource(resourceArn=agent_arn, tagKeys=["Key1"]) + resp = client.list_tags_for_resource(resourceArn=agent_arn) + assert len(resp["tags"]) == 1 + assert resp["tags"]["Key2"] == "test-tag2" + + +@mock_aws +def test_untag_resource_knowledge_base(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + kb = client.create_knowledge_base( + name="testkb", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + ) + kb_arn = kb["knowledgeBase"]["knowledgeBaseArn"] + resp = client.tag_resource( + resourceArn=kb_arn, tags={"Key1": "test-tag", "Key2": "test-tag2"} + ) + resp = client.untag_resource(resourceArn=kb_arn, tagKeys=["Key1", "Key2"]) + resp = client.list_tags_for_resource(resourceArn=kb_arn) + assert len(resp["tags"]) == 0 + + +@mock_aws +def test_list_tags_for_resource_agent(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + agent = client.create_agent( + agentName="testname", + agentResourceRoleArn="test-agent-arn", + ) + agent_arn = agent["agent"]["agentArn"] + resp = client.tag_resource( + resourceArn=agent_arn, tags={"Key1": "test-tag", "Key2": "test-tag2"} + ) + resp = client.list_tags_for_resource(resourceArn=agent_arn) + assert resp["tags"]["Key1"] == "test-tag" + assert resp["tags"]["Key2"] == "test-tag2" + + +@mock_aws +def test_list_tags_for_resource_knowledge_base(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + kb = client.create_knowledge_base( + name="testkb", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + ) + kb_arn = kb["knowledgeBase"]["knowledgeBaseArn"] + resp = client.tag_resource( + resourceArn=kb_arn, tags={"Key1": "test-tag", "Key2": "test-tag2"} + ) + resp = client.list_tags_for_resource(resourceArn=kb_arn) + assert resp["tags"]["Key1"] == "test-tag" + assert resp["tags"]["Key2"] == "test-tag2" + + +@mock_aws +def test_create_knowledge_base_bad_knowledge_base_config(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + with pytest.raises(ClientError) as e: + client.create_knowledge_base( + name="testkb", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "vECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + tags={"Key": "test-tag"}, + ) + assert e.value.response["Error"]["Code"] == "ValidationException" + + +@mock_aws +def test_create_knowledge_base_bad_storage_config(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + with pytest.raises(ClientError) as e: + client.create_knowledge_base( + name="testkb", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "oPENSEARCH_SERVERLES", + }, + tags={"Key": "test-tag"}, + ) + assert e.value.response["Error"]["Code"] == "ValidationException" + + +# @mock_aws +# def test_list_agents_token(): +# client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) +# client.create_agent( +# agentName="testname1", +# agentResourceRoleArn="test-agent-arn", +# tags={ +# "Key": "test-tag-key", +# }, +# ) +# client.create_agent( +# agentName="testname2", +# agentResourceRoleArn="test-agent-arn", +# tags={ +# "Key": "test-tag-key", +# }, +# ) +# resp = client.list_agents(nextToken="1") + +# assert len(resp["agentSummaries"]) == 1 +# assert resp["agentSummaries"][0]["agentName"] == "testname2" + + +# @mock_aws +# def test_list_agents_bad_token(): +# client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) +# client.create_agent( +# agentName="testname1", +# agentResourceRoleArn="test-agent-arn", +# tags={ +# "Key": "test-tag-key", +# }, +# ) +# client.create_agent( +# agentName="testname2", +# agentResourceRoleArn="test-agent-arn", +# tags={ +# "Key": "test-tag-key", +# }, +# ) +# with pytest.raises(ClientError) as e: +# client.list_agents(nextToken="3") +# assert e.value.response["Error"]["Code"] == "ValidationException" + + +@mock_aws +def test_list_agents_max_results(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + client.create_agent( + agentName="testname1", + agentResourceRoleArn="test-agent-arn", + tags={ + "Key": "test-tag-key", + }, + ) + client.create_agent( + agentName="testname2", + agentResourceRoleArn="test-agent-arn", + tags={ + "Key": "test-tag-key", + }, + ) + + resp = client.list_agents(maxResults=1) + + assert len(resp["agentSummaries"]) == 1 + assert resp["agentSummaries"][0]["agentName"] == "testname1" + + +@mock_aws +def test_list_agents_big_max_results(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + client.create_agent( + agentName="testname1", + agentResourceRoleArn="test-agent-arn", + tags={ + "Key": "test-tag-key", + }, + ) + client.create_agent( + agentName="testname2", + agentResourceRoleArn="test-agent-arn", + tags={ + "Key": "test-tag-key", + }, + ) + + resp = client.list_agents(maxResults=4) + + assert len(resp["agentSummaries"]) == 2 + assert resp["agentSummaries"][0]["agentName"] == "testname1" + assert resp["agentSummaries"][1]["agentName"] == "testname2" + + +@mock_aws +def test_delete_agent_not_found(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + client.create_agent( + agentName="testname", + agentResourceRoleArn="test-agent-arn", + tags={"Key": "test-tag-key"}, + ) + with pytest.raises(ClientError) as e: + client.delete_agent( + agentId="non-existent-agent-id", skipResourceInUseCheck=True + ) + assert e.value.response["Error"]["Code"] == "ResourceNotFoundException" + + +# @mock_aws +# def test_delete_agent_in_use(): +# client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) +# agent = client.create_agent( +# agentName="testname", +# agentResourceRoleArn="test-agent-arn", +# tags={"Key": "test-tag-key"}, +# ) +# agent_id = agent["agent"]["agentId"] +# with patch("moto.bedrockagent.models.Agent.agent_status", return_value="IN_USE"): +# # mock = Mock(spec = Agent) +# # mock.agent_status.return_value = "IN_USE" +# with pytest.raises(ClientError) as e: +# resp = client.delete_agent(agentId=agent_id, skipResourceInUseCheck=False) +# assert e.value.response["Error"]["Code"] == "ConflictException" + + +# @mock_aws +# def test_list_knowledge_bases_token(): +# client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) +# resp = client.create_knowledge_base( +# name="testkb", +# description="description", +# roleArn="test_role_arn", +# knowledgeBaseConfiguration={ +# "type": "VECTOR", +# }, +# storageConfiguration={ +# "type": "OPENSEARCH_SERVERLESS", +# }, +# tags={"Key": "test-tag"}, +# ) +# resp = client.create_knowledge_base( +# name="testkb2", +# description="description", +# roleArn="test_role_arn", +# knowledgeBaseConfiguration={ +# "type": "VECTOR", +# }, +# storageConfiguration={ +# "type": "OPENSEARCH_SERVERLESS", +# }, +# tags={"Key": "test-tag"}, +# ) +# resp = client.list_knowledge_bases(nextToken="1") + +# assert len(resp["knowledgeBaseSummaries"]) == 1 +# assert resp["knowledgeBaseSummaries"][0]["name"] == "testkb2" + + +# @mock_aws +# def test_list_knowledge_bases_bad_token(): +# client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) +# client.create_knowledge_base( +# name="testkb", +# description="description", +# roleArn="test_role_arn", +# knowledgeBaseConfiguration={ +# "type": "VECTOR", +# }, +# storageConfiguration={ +# "type": "OPENSEARCH_SERVERLESS", +# }, +# tags={"Key": "test-tag"}, +# ) +# client.create_knowledge_base( +# name="testkb2", +# description="description", +# roleArn="test_role_arn", +# knowledgeBaseConfiguration={ +# "type": "VECTOR", +# }, +# storageConfiguration={ +# "type": "OPENSEARCH_SERVERLESS", +# }, +# tags={"Key": "test-tag"}, +# ) +# with pytest.raises(ClientError) as e: +# client.list_knowledge_bases(nextToken="3") +# assert e.value.response["Error"]["Code"] == "ValidationException" + + +@mock_aws +def test_list_knowledge_bases_max_results(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + resp = client.create_knowledge_base( + name="testkb", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + tags={"Key": "test-tag"}, + ) + resp = client.create_knowledge_base( + name="testkb2", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + tags={"Key": "test-tag"}, + ) + + resp = client.list_knowledge_bases(maxResults=1) + + assert len(resp["knowledgeBaseSummaries"]) == 1 + assert resp["knowledgeBaseSummaries"][0]["name"] == "testkb" + + +@mock_aws +def test_list_knowledge_bases_big_max_results(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + resp = client.create_knowledge_base( + name="testkb", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + tags={"Key": "test-tag"}, + ) + resp = client.create_knowledge_base( + name="testkb2", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + tags={"Key": "test-tag"}, + ) + + resp = client.list_knowledge_bases(maxResults=4) + + assert len(resp["knowledgeBaseSummaries"]) == 2 + assert resp["knowledgeBaseSummaries"][0]["name"] == "testkb" + assert resp["knowledgeBaseSummaries"][1]["name"] == "testkb2" + + +@mock_aws +def test_tag_resource_knowledge_base_not_found(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + kb = client.create_knowledge_base( + name="testkb", + description="description", + roleArn="test_role_arn", + knowledgeBaseConfiguration={ + "type": "VECTOR", + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + }, + ) + kb_arn = kb["knowledgeBase"]["knowledgeBaseArn"] + with pytest.raises(ClientError) as e: + client.tag_resource(resourceArn=kb_arn + "no", tags={"Key": "test-tag"}) + assert e.value.response["Error"]["Code"] == "ResourceNotFoundException" + + +@mock_aws +def test_untag_resource_agent_not_found(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + agent = client.create_agent( + agentName="testname", + agentResourceRoleArn="test-agent-arn", + ) + agent_arn = agent["agent"]["agentArn"] + client.tag_resource( + resourceArn=agent_arn, tags={"Key1": "test-tag", "Key2": "test-tag2"} + ) + with pytest.raises(ClientError) as e: + client.untag_resource(resourceArn=agent_arn + "no", tagKeys=["Key1"]) + assert e.value.response["Error"]["Code"] == "ResourceNotFoundException" + + +@mock_aws +def test_list_tags_for_resource_agent_not_found(): + client = boto3.client("bedrock-agent", region_name=DEFAULT_REGION) + agent = client.create_agent( + agentName="testname", + agentResourceRoleArn="test-agent-arn", + ) + agent_arn = agent["agent"]["agentArn"] + client.tag_resource( + resourceArn=agent_arn, tags={"Key1": "test-tag", "Key2": "test-tag2"} + ) + with pytest.raises(ClientError) as e: + client.list_tags_for_resource(resourceArn=agent_arn + "no") + assert e.value.response["Error"]["Code"] == "ResourceNotFoundException"