def get_s3_client(
    credentials: Optional[Dict] = None,
    client_parameters: Optional[Dict] = None,
):
    """
    Build an S3 client from the inputs of a deployment step.

    A `boto3.Session` is created first and the client is created from that
    session, so `profile_name` and `region_name` are applied at the session
    level instead of being passed to the client constructor. Neither input
    mapping is modified.

    Args:
        credentials: Optional mapping. The keys read from it are
            `aws_access_key_id`, `aws_secret_access_key`, `aws_session_token`,
            `profile_name`, `region_name` and a nested `aws_client_parameters`
            mapping.
        client_parameters: Optional mapping. `profile_name` and `region_name`
            are read from it when `credentials` does not supply them. When
            `credentials` carries no `aws_client_parameters`, the client
            options `api_version`, `endpoint_url`, `use_ssl`, `verify` and
            `config` (keyword arguments for `Config`) are read from it too.

    Returns:
        The client object returned by `session.client("s3", ...)`.
    """
    if credentials is None:
        credentials = {}
    if client_parameters is None:
        client_parameters = {}

    def _session_setting(key):
        # `credentials` wins, but a key that is present with a None value must
        # not hide a usable value supplied through `client_parameters`.
        value = credentials.get(key)
        if value is None:
            value = client_parameters.get(key)
        return value

    # Client-level options: the mapping nested inside `credentials` takes
    # precedence; `client_parameters` is used only when that mapping is
    # missing (or explicitly None).
    aws_client_parameters = credentials.get("aws_client_parameters")
    if aws_client_parameters is None:
        aws_client_parameters = client_parameters

    # An explicit `"config": None` is treated like a missing entry so that
    # `Config(**...)` always receives a mapping.
    config = Config(**(aws_client_parameters.get("config") or {}))

    session = boto3.Session(
        aws_access_key_id=credentials.get("aws_access_key_id"),
        aws_secret_access_key=credentials.get("aws_secret_access_key"),
        aws_session_token=credentials.get("aws_session_token"),
        profile_name=_session_setting("profile_name"),
        region_name=_session_setting("region_name"),
    )

    # NOTE(review): when "use_ssl" is not supplied this forwards use_ssl=None,
    # while boto3 documents True as the default for that argument. Confirm a
    # None value is still treated as "use SSL"; if not, default it to True here.
    return session.client(
        "s3",
        api_version=aws_client_parameters.get("api_version"),
        endpoint_url=aws_client_parameters.get("endpoint_url"),
        use_ssl=aws_client_parameters.get("use_ssl"),
        verify=aws_client_parameters.get("verify"),
        config=config,
    )
def test_s3_session_with_params():
    """
    Verify what `get_s3_client` passes to `boto3.Session` and to `.client`.

    Two invocations are made against a patched `boto3.Session`: one where the
    client options are nested in `credentials`, and one where they arrive via
    `client_parameters`. Each invocation records a session call followed by a
    client call, giving four recorded calls in total.
    """
    nested_credentials = {
        "aws_access_key_id": "THE_KEY",
        "aws_secret_access_key": "SHHH!",
        "profile_name": "foo",
        "region_name": "us-weast-1",
        "aws_client_parameters": {
            "api_version": "v1",
            "config": {"connect_timeout": 300},
        },
    }
    key_only_credentials = {
        "aws_access_key_id": "THE_KEY",
        "aws_secret_access_key": "SHHH!",
    }
    separate_client_parameters = {
        "region_name": "us-west-1",
        "config": {"signature_version": "s3v4"},
    }

    with patch("boto3.Session") as mock_session:
        get_s3_client(credentials=nested_credentials)
        get_s3_client(
            credentials=key_only_credentials,
            client_parameters=separate_client_parameters,
        )

        recorded = mock_session.mock_calls
        assert len(recorded) == 4
        first_session, first_client, second_session, second_client = recorded

        # First invocation: everything comes from `credentials`.
        assert first_session.kwargs == {
            "aws_access_key_id": "THE_KEY",
            "aws_secret_access_key": "SHHH!",
            "aws_session_token": None,
            "profile_name": "foo",
            "region_name": "us-weast-1",
        }
        assert first_client.args[0] == "s3"
        expected_first_options = {
            "api_version": "v1",
            "endpoint_url": None,
            "use_ssl": None,
            "verify": None,
        }
        assert expected_first_options.items() <= first_client.kwargs.items()
        first_config = first_client.kwargs.get("config")
        assert first_config.connect_timeout == 300
        assert first_config.signature_version is None

        # Second invocation: region and config come from `client_parameters`.
        assert second_session.kwargs == {
            "aws_access_key_id": "THE_KEY",
            "aws_secret_access_key": "SHHH!",
            "aws_session_token": None,
            "profile_name": None,
            "region_name": "us-west-1",
        }
        assert second_client.args[0] == "s3"
        expected_second_options = {
            "api_version": None,
            "endpoint_url": None,
            "use_ssl": None,
            "verify": None,
        }
        assert expected_second_options.items() <= second_client.kwargs.items()
        second_config = second_client.kwargs.get("config")
        assert second_config.connect_timeout == 60
        assert second_config.signature_version == "s3v4"