Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix s3 session creation in deployment steps push_to_s3 and pull_from_s3
Browse files Browse the repository at this point in the history
  • Loading branch information
markbruning committed Oct 10, 2023
1 parent 525c917 commit b3032e4
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 18 deletions.
68 changes: 51 additions & 17 deletions prefect_aws/deployments/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,7 @@ def push_to_s3(
```
"""
if credentials is None:
credentials = {}
if client_parameters is None:
client_parameters = {}
advanced_config = client_parameters.pop("config", {})
client = boto3.client(
"s3", **credentials, **client_parameters, config=Config(**advanced_config)
)
s3 = get_s3_client(credentials=credentials, client_parameters=client_parameters)

local_path = Path.cwd()

Expand All @@ -117,7 +110,7 @@ def push_to_s3(
continue
elif not local_file_path.is_dir():
remote_file_path = Path(folder) / local_file_path.relative_to(local_path)
client.upload_file(
s3.upload_file(
str(local_file_path), bucket, str(remote_file_path.as_posix())
)

Expand Down Expand Up @@ -174,14 +167,7 @@ def pull_from_s3(
credentials: "{{ prefect.blocks.aws-credentials.dev-credentials }}"
```
"""
if credentials is None:
credentials = {}
if client_parameters is None:
client_parameters = {}
advanced_config = client_parameters.pop("config", {})

session = boto3.Session(**credentials)
s3 = session.client("s3", **client_parameters, config=Config(**advanced_config))
s3 = get_s3_client(credentials=credentials, client_parameters=client_parameters)

local_path = Path.cwd()

Expand All @@ -206,3 +192,51 @@ def pull_from_s3(
"folder": folder,
"directory": str(local_path),
}


def get_s3_client(
    credentials: Optional[Dict] = None,
    client_parameters: Optional[Dict] = None,
) -> dict:
    """Create a boto3 S3 client from step credentials and client parameters.

    A ``boto3.Session`` is built first and the S3 client is created from it.

    Args:
        credentials: Optional mapping. The keys read from it are
            ``aws_access_key_id``, ``aws_secret_access_key``,
            ``aws_session_token``, ``profile_name``, ``region_name`` and
            ``aws_client_parameters``. Missing keys are passed on as ``None``.
        client_parameters: Optional mapping. ``profile_name`` and
            ``region_name`` are read from it only when ``credentials`` does
            not contain them. The client options (``api_version``,
            ``endpoint_url``, ``use_ssl``, ``verify``, ``config``) are read
            from it only when ``credentials`` carries no usable
            ``aws_client_parameters`` entry.

    Returns:
        Whatever ``boto3.Session(...).client("s3", ...)`` returns.
        NOTE(review): the signature is annotated ``dict`` although a boto3
        client object is returned; the annotation is kept as-is so the
        interface stays unchanged -- confirm and correct separately.
    """
    if credentials is None:
        credentials = {}
    if client_parameters is None:
        client_parameters = {}

    # Secrets only ever come from ``credentials`` (block-derived or plain dict).
    session_kwargs = {
        key: credentials.get(key)
        for key in (
            "aws_access_key_id",
            "aws_secret_access_key",
            "aws_session_token",
        )
    }
    # Profile and region: ``credentials`` wins, ``client_parameters`` is the
    # fallback.
    for key in ("profile_name", "region_name"):
        session_kwargs[key] = credentials.get(key, client_parameters.get(key))

    # Client options: a nested ``aws_client_parameters`` entry in
    # ``credentials`` takes precedence over ``client_parameters``. An entry
    # that is present but ``None`` is treated as absent instead of raising
    # AttributeError on the ``.get`` calls below.
    aws_client_parameters = credentials.get("aws_client_parameters")
    if aws_client_parameters is None:
        aws_client_parameters = client_parameters

    # NOTE(review): unset ``use_ssl`` / ``verify`` are forwarded as ``None``
    # rather than omitted. boto3 documents ``use_ssl`` as defaulting to True;
    # confirm that an explicit ``None`` is handled the same way.
    client_kwargs = {
        key: aws_client_parameters.get(key)
        for key in ("api_version", "endpoint_url", "use_ssl", "verify")
    }
    # ``config: None`` means "no advanced config"; ``Config(**None)`` would
    # raise TypeError.
    config = Config(**(aws_client_parameters.get("config") or {}))

    session = boto3.Session(**session_kwargs)
    return session.client("s3", **client_kwargs, config=config)
63 changes: 62 additions & 1 deletion tests/deploments/test_steps.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import sys
from pathlib import Path, PurePath, PurePosixPath
from unittest.mock import patch

import boto3
import pytest
from moto import mock_s3

from prefect_aws.deployments.steps import pull_from_s3, push_to_s3
from prefect_aws.deployments.steps import get_s3_client, pull_from_s3, push_to_s3


@pytest.fixture
Expand Down Expand Up @@ -173,6 +174,66 @@ def test_push_pull_empty_folders(s3_setup, tmp_path, mock_aws_credentials):
assert not (tmp_path / "empty2_copy").exists()


def test_s3_session_with_params():
    """Verify the kwargs ``get_s3_client`` hands to boto3.

    ``boto3.Session`` is patched, so each ``get_s3_client`` call records two
    mock calls: the ``Session(...)`` construction followed by the
    ``.client("s3", ...)`` call on the returned session.
    """
    # Case 1: everything, including nested client parameters, lives in
    # ``credentials`` (the shape produced by a credentials block).
    block_style_credentials = {
        "aws_access_key_id": "THE_KEY",
        "aws_secret_access_key": "SHHH!",
        "profile_name": "foo",
        "region_name": "us-weast-1",
        "aws_client_parameters": {
            "api_version": "v1",
            "config": {"connect_timeout": 300},
        },
    }
    # Case 2: bare keys in ``credentials``; region and config are supplied
    # through ``client_parameters``.
    plain_credentials = {
        "aws_access_key_id": "THE_KEY",
        "aws_secret_access_key": "SHHH!",
    }
    plain_client_parameters = {
        "region_name": "us-west-1",
        "config": {"signature_version": "s3v4"},
    }

    with patch("boto3.Session") as mock_session:
        get_s3_client(credentials=block_style_credentials)
        get_s3_client(
            credentials=plain_credentials,
            client_parameters=plain_client_parameters,
        )

        recorded = mock_session.mock_calls
        assert len(recorded) == 4
        first_session, first_client, second_session, second_client = recorded

        # --- case 1 ---------------------------------------------------------
        assert first_session.kwargs == {
            "aws_access_key_id": "THE_KEY",
            "aws_secret_access_key": "SHHH!",
            "aws_session_token": None,
            "profile_name": "foo",
            "region_name": "us-weast-1",
        }
        assert first_client.args[0] == "s3"
        expected_first_options = {
            "api_version": "v1",
            "endpoint_url": None,
            "use_ssl": None,
            "verify": None,
        }
        assert expected_first_options.items() <= first_client.kwargs.items()
        first_config = first_client.kwargs.get("config")
        assert first_config.connect_timeout == 300
        assert first_config.signature_version is None

        # --- case 2 ---------------------------------------------------------
        assert second_session.kwargs == {
            "aws_access_key_id": "THE_KEY",
            "aws_secret_access_key": "SHHH!",
            "aws_session_token": None,
            "profile_name": None,
            "region_name": "us-west-1",
        }
        assert second_client.args[0] == "s3"
        expected_second_options = {
            "api_version": None,
            "endpoint_url": None,
            "use_ssl": None,
            "verify": None,
        }
        assert expected_second_options.items() <= second_client.kwargs.items()
        second_config = second_client.kwargs.get("config")
        assert second_config.connect_timeout == 60
        assert second_config.signature_version == "s3v4"


def test_custom_credentials_and_client_parameters(s3_setup, tmp_files):
s3, bucket_name = s3_setup
folder = "my-project"
Expand Down

0 comments on commit b3032e4

Please sign in to comment.