Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
add most recent task in family as cache miss fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
kevingrismore committed Jan 22, 2024
1 parent 7a25492 commit 61c6abc
Showing 1 changed file with 50 additions and 8 deletions.
58 changes: 50 additions & 8 deletions prefect_aws/workers/ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ class ECSJobConfiguration(BaseJobConfiguration):
vpc_id: Optional[str] = Field(default=None)
container_name: Optional[str] = Field(default=None)
cluster: Optional[str] = Field(default=None)
match_latest_revision_in_family: bool = Field(default=False)

@root_validator
def task_run_request_requires_arn_if_no_task_definition_given(cls, values) -> dict:
Expand Down Expand Up @@ -551,6 +552,16 @@ class ECSVariables(BaseVariables):
"your AWS account, instead it will be marked as INACTIVE."
),
)
match_latest_revision_in_family: bool = Field(
default=False,
description=(
"If enabled, the most recent active revision in the task definition "
"family will be compared against the desired ECS task configuration. "
"If they are equal, the existing task definition will be used instead "
"of registering a new one. If no family is specified the default family "
f'"{ECS_DEFAULT_FAMILY}" will be used.'
),
)


class ECSWorkerResult(BaseWorkerResult):
Expand Down Expand Up @@ -668,13 +679,14 @@ def _create_task_and_wait_for_start(
new_task_definition_registered = False

if not task_definition_arn:
cached_task_definition_arn = _TASK_DEFINITION_CACHE.get(
flow_run.deployment_id
)
task_definition = self._prepare_task_definition(
configuration, region=ecs_client.meta.region_name
)

cached_task_definition_arn = _TASK_DEFINITION_CACHE.get(
flow_run.deployment_id
)

if cached_task_definition_arn:
# Read the task definition to see if the cached task definition is valid
try:
Expand Down Expand Up @@ -710,13 +722,39 @@ def _create_task_and_wait_for_start(
_TASK_DEFINITION_CACHE.pop(flow_run.deployment_id, None)
cached_task_definition_arn = None

# use the family as a fallback if we don't have a local cached definition
if (
configuration.match_latest_revision_in_family
and not cached_task_definition_arn
):
try:
task_definition_from_family = self._retrieve_task_definition(
logger,
ecs_client,
task_definition.get("family", ECS_DEFAULT_FAMILY),
)
except Exception as exc:
logger.warning(
"Failed to retrieve a definition for task family "
f"{task_definition.get('family', ECS_DEFAULT_FAMILY)!r}: "
f"{exc!r}"
)
else:
if self._task_definitions_equal(
task_definition, task_definition_from_family
):
cached_task_definition_arn = task_definition_from_family[
"taskDefinitionArn"
]

if not cached_task_definition_arn:
task_definition_arn = self._register_task_definition(
logger, ecs_client, task_definition
)
new_task_definition_registered = True
else:
task_definition_arn = cached_task_definition_arn

else:
task_definition = self._retrieve_task_definition(
logger, ecs_client, task_definition_arn
Expand Down Expand Up @@ -938,15 +976,19 @@ def _retrieve_task_definition(
self,
logger: logging.Logger,
ecs_client: _ECSClient,
task_definition_arn: str,
task_definition: str,
):
"""
Retrieve an existing task definition from AWS.
"""
logger.info(f"Retrieving ECS task definition {task_definition_arn!r}...")
response = ecs_client.describe_task_definition(
taskDefinition=task_definition_arn
)
if task_definition.startswith("arn:aws:ecs:"):
logger.info(f"Retrieving ECS task definition {task_definition!r}...")
else:
logger.info(
"Retrieving most recent active revision from "
f"ECS task family {task_definition!r}..."
)
response = ecs_client.describe_task_definition(taskDefinition=task_definition)
return response["taskDefinition"]

def _wait_for_task_start(
Expand Down

0 comments on commit 61c6abc

Please sign in to comment.