Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use cached boto3 clients in ECSWorker #375

Merged
merged 5 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Use cached boto3 clients in `ECSWorker` - [#375](https://github.com/PrefectHQ/prefect-aws/pull/375)

### Fixed

### Deprecated

### Removed

## 0.4.9
Released January 24rd, 2024;

Released January 24rd, 2024.

### Fixed

- Hashing of nested objects within `AwsClientParameters` - [#373](https://github.com/PrefectHQ/prefect-aws/pull/373)


## 0.4.8
Released January 23rd, 2024;

Released January 23rd, 2024.

### Added

Expand Down
49 changes: 20 additions & 29 deletions prefect_aws/workers/ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,11 @@
import sys
import time
from copy import deepcopy
from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple
from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
from uuid import UUID

import anyio
import anyio.abc
import boto3
import yaml
from prefect.exceptions import InfrastructureNotAvailable, InfrastructureNotFound
from prefect.server.schemas.core import FlowRun
Expand All @@ -79,7 +78,7 @@
from tenacity import retry, stop_after_attempt, wait_fixed, wait_random
from typing_extensions import Literal

from prefect_aws import AwsCredentials
from prefect_aws.credentials import AwsCredentials, ClientType

# Internal type alias for ECS clients which are generated dynamically in botocore
_ECSClient = Any
Expand Down Expand Up @@ -584,8 +583,8 @@ async def run(
"""
Runs a given flow run on the current worker.
"""
boto_session, ecs_client = await run_sync_in_worker_thread(
self._get_session_and_client, configuration
ecs_client = await run_sync_in_worker_thread(
self._get_client, configuration, "ecs"
)

logger = self.get_flow_run_logger(flow_run)
Expand All @@ -598,7 +597,6 @@ async def run(
) = await run_sync_in_worker_thread(
self._create_task_and_wait_for_start,
logger,
boto_session,
ecs_client,
configuration,
flow_run,
Expand All @@ -625,7 +623,6 @@ async def run(
cluster_arn,
task_definition,
is_new_task_definition and configuration.auto_deregister_task_definition,
boto_session,
ecs_client,
)

Expand All @@ -636,21 +633,17 @@ async def run(
status_code=status_code if status_code is not None else -1,
)

def _get_session_and_client(
self,
configuration: ECSJobConfiguration,
) -> Tuple[boto3.Session, _ECSClient]:
def _get_client(
self, configuration: ECSJobConfiguration, client_type: Union[str, ClientType]
) -> _ECSClient:
"""
Retrieve a boto3 session and ECS client
Get a boto3 client of client_type. Will use a cached client if one exists.
"""
boto_session = configuration.aws_credentials.get_boto3_session()
ecs_client = boto_session.client("ecs")
return boto_session, ecs_client
return configuration.aws_credentials.get_client(client_type)

def _create_task_and_wait_for_start(
self,
logger: logging.Logger,
boto_session: boto3.Session,
ecs_client: _ECSClient,
configuration: ECSJobConfiguration,
flow_run: FlowRun,
Expand Down Expand Up @@ -741,7 +734,6 @@ def _create_task_and_wait_for_start(

# Prepare the task run request
task_run_request = self._prepare_task_run_request(
boto_session,
configuration,
task_definition,
task_definition_arn,
Expand Down Expand Up @@ -782,7 +774,6 @@ def _watch_task_and_get_exit_code(
cluster_arn: str,
task_definition: dict,
deregister_task_definition: bool,
boto_session: boto3.Session,
ecs_client: _ECSClient,
) -> Optional[int]:
"""
Expand All @@ -798,7 +789,6 @@ def _watch_task_and_get_exit_code(
cluster_arn,
task_definition,
ecs_client,
boto_session,
)

if deregister_task_definition:
Expand Down Expand Up @@ -992,7 +982,6 @@ def _wait_for_task_finish(
cluster_arn: str,
task_definition: dict,
ecs_client: _ECSClient,
boto_session: boto3.Session,
):
"""
Watch an ECS task until it reaches a STOPPED status.
Expand Down Expand Up @@ -1031,7 +1020,7 @@ def _wait_for_task_finish(
else:
# Prepare to stream the output
log_config = container_def["logConfiguration"]["options"]
logs_client = boto_session.client("logs")
logs_client = self._get_client(configuration, "logs")
can_stream_output = True
# Track the last log timestamp to prevent double display
last_log_timestamp: Optional[int] = None
Expand Down Expand Up @@ -1300,13 +1289,13 @@ def _prepare_task_definition(
return task_definition

def _load_network_configuration(
self, vpc_id: Optional[str], boto_session: boto3.Session
self, vpc_id: Optional[str], configuration: ECSJobConfiguration
) -> dict:
"""
Load settings from a specific VPC or the default VPC and generate a task
run request's network configuration.
"""
ec2_client = boto_session.client("ec2")
ec2_client = self._get_client(configuration, "ec2")
vpc_message = "the default VPC" if not vpc_id else f"VPC with ID {vpc_id}"

if not vpc_id:
Expand Down Expand Up @@ -1347,13 +1336,16 @@ def _load_network_configuration(
}

def _custom_network_configuration(
self, vpc_id: str, network_configuration: dict, boto_session: boto3.Session
self,
vpc_id: str,
network_configuration: dict,
configuration: ECSJobConfiguration,
) -> dict:
"""
Load settings from a specific VPC or the default VPC and generate a task
run request's network configuration.
"""
ec2_client = boto_session.client("ec2")
ec2_client = self._get_client(configuration, "ec2")
vpc_message = f"VPC with ID {vpc_id}"

vpcs = ec2_client.describe_vpcs(VpcIds=[vpc_id]).get("Vpcs")
Expand Down Expand Up @@ -1389,7 +1381,6 @@ def _custom_network_configuration(

def _prepare_task_run_request(
self,
boto_session: boto3.Session,
configuration: ECSJobConfiguration,
task_definition: dict,
task_definition_arn: str,
Expand Down Expand Up @@ -1422,7 +1413,7 @@ def _prepare_task_run_request(
and not configuration.network_configuration
):
task_run_request["networkConfiguration"] = self._load_network_configuration(
configuration.vpc_id, boto_session
configuration.vpc_id, configuration
)

# Use networkConfiguration if supplied by user
Expand All @@ -1435,7 +1426,7 @@ def _prepare_task_run_request(
self._custom_network_configuration(
configuration.vpc_id,
configuration.network_configuration,
boto_session,
configuration,
)
)

Expand Down Expand Up @@ -1628,7 +1619,7 @@ def _stop_task(
f"{cluster!r}."
)

_, ecs_client = self._get_session_and_client(configuration)
ecs_client = self._get_client(configuration, "ecs")
try:
ecs_client.stop_task(cluster=cluster, task=task)
except Exception as exc:
Expand Down
20 changes: 20 additions & 0 deletions tests/workers/test_ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from tenacity import RetryError

from prefect_aws.credentials import _get_client_cached
from prefect_aws.workers.ecs_worker import (
_TASK_DEFINITION_CACHE,
ECS_DEFAULT_CONTAINER_NAME,
Expand Down Expand Up @@ -2183,6 +2184,25 @@ async def test_retry_on_failed_task_start(
assert run_task_mock.call_count == 3


@pytest.mark.usefixtures("ecs_mocks")
async def test_worker_uses_cached_boto3_client(aws_credentials: AwsCredentials):
configuration = await construct_configuration(
aws_credentials=aws_credentials,
)

_get_client_cached.cache_clear()

assert _get_client_cached.cache_info().hits == 0, "Initial call count should be 0"

async with ECSWorker(work_pool_name="test") as worker:
worker._get_client(configuration, "ecs")
worker._get_client(configuration, "ecs")
worker._get_client(configuration, "ecs")

assert _get_client_cached.cache_info().misses == 1
assert _get_client_cached.cache_info().hits == 2


async def test_mask_sensitive_env_values():
task_run_request = {
"overrides": {
Expand Down
Loading