From 14e4c938f41a889e3ce23d9d86e62c19c6718663 Mon Sep 17 00:00:00 2001
From: Miles Granger
Date: Fri, 15 Mar 2024 11:51:15 +0100
Subject: [PATCH] Support auto-setting AWS credentials for storage options

---
 dask_deltatable/core.py  | 15 ++++++---
 dask_deltatable/utils.py | 70 +++++++++++++++++++++++++++++++++++++++-
 dask_deltatable/write.py | 13 +++++++-
 tests/test_acceptance.py |  4 ++-
 4 files changed, 95 insertions(+), 7 deletions(-)

diff --git a/dask_deltatable/core.py b/dask_deltatable/core.py
index b89e6d8..1fdd955 100644
--- a/dask_deltatable/core.py
+++ b/dask_deltatable/core.py
@@ -18,7 +18,7 @@
 from pyarrow import dataset as pa_ds
 
 from .types import Filters
-from .utils import get_partition_filters
+from .utils import get_partition_filters, maybe_set_aws_credentials
 
 if Version(pa.__version__) >= Version("10.0.0"):
     filters_to_expression = pq.filters_to_expression
@@ -94,6 +94,9 @@ def _read_from_filesystem(
     """
     Reads the list of parquet files in parallel
     """
+    storage_options = maybe_set_aws_credentials(path, storage_options)  # type: ignore
+    delta_storage_options = maybe_set_aws_credentials(path, delta_storage_options)  # type: ignore
+
     fs, fs_token, _ = get_fs_token_paths(path, storage_options=storage_options)
     dt = DeltaTable(
         table_uri=path, version=version, storage_options=delta_storage_options
@@ -116,12 +119,14 @@ def _read_from_filesystem(
     if columns:
         meta = meta[columns]
 
+    kws = dict(meta=meta, label="read-delta-table")
+    if not dd._dask_expr_enabled():
+        # Setting token is not supported in dask-expr
+        kws["token"] = tokenize(path, fs_token, **kwargs)  # type: ignore
     return dd.from_map(
         partial(_read_delta_partition, fs=fs, columns=columns, schema=schema, **kwargs),
         pq_files,
-        meta=meta,
-        label="read-delta-table",
-        token=tokenize(path, fs_token, **kwargs),
+        **kws,
     )
 
 
@@ -270,6 +275,8 @@ def read_deltalake(
     else:
         if path is None:
             raise ValueError("Please Provide Delta Table path")
+
+        delta_storage_options = maybe_set_aws_credentials(path, delta_storage_options)  # type: ignore
         resultdf = _read_from_filesystem(
             path=path,
             version=version,
diff --git a/dask_deltatable/utils.py b/dask_deltatable/utils.py
index 3901f63..dabb6b4 100644
--- a/dask_deltatable/utils.py
+++ b/dask_deltatable/utils.py
@@ -1,10 +1,78 @@
 from __future__ import annotations
 
-from typing import cast
+from typing import Any, cast
 
 from .types import Filter, Filters
 
 
+def get_bucket_region(path: str):
+    import boto3
+
+    if not path.startswith("s3://"):
+        raise ValueError(f"'{path}' is not an S3 path")
+    bucket = path.replace("s3://", "").split("/")[0]
+    resp = boto3.client("s3").get_bucket_location(Bucket=bucket)
+    # Buckets in region 'us-east-1' return None rather than the region name.
+    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/get_bucket_location.html#S3.Client.get_bucket_location
+    return resp["LocationConstraint"] or "us-east-1"
+
+
+def maybe_set_aws_credentials(path: Any, options: dict[str, Any]) -> dict[str, Any]:
+    """
+    Maybe set AWS credentials in ``options`` if no AWS-specific keys are
+    already present and ``path`` is an ``s3://`` URI.
+
+    Parameters
+    ----------
+    path : Any
+        If it is a string starting with ``s3://``, the bucket region is
+        determined and AWS credentials may be set.
+    options : dict[str, Any]
+        Options/kwargs to be supplied to e.g. ``S3FileSystem`` or similar that
+        may accept AWS credentials. A copy is made and returned if modified.
+
+    Returns
+    -------
+    dict
+        Either the original options if not modified, or a copy updated with
+        the AWS credentials inserted.
+    """
+
+    is_s3_path = getattr(path, "startswith", lambda _: False)("s3://")
+    if not is_s3_path:
+        return options
+
+    # Avoid overwriting already provided credentials
+    keys = ("AWS_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY", "access_key", "secret_key")
+    if not any(k in (options or ()) for k in keys):
+        # Defer requiring boto3 until it's actually needed, xref _read_from_catalog
+        import boto3
+
+        session = boto3.session.Session()
+        credentials = session.get_credentials()
+        if credentials is None:
+            return options
+        region = get_bucket_region(path)
+
+        options = (options or {}).copy()
+        options.update(
+            # Capitalized keys are used by the delta-specific API; lowercase keys are for S3FileSystem
+            dict(
+                # TODO: w/o this, we need to configure a LockClient, which seems to require DynamoDB.
+                AWS_S3_ALLOW_UNSAFE_RENAME="true",
+                AWS_SECRET_ACCESS_KEY=credentials.secret_key,
+                AWS_ACCESS_KEY_ID=credentials.access_key,
+                AWS_SESSION_TOKEN=credentials.token,
+                AWS_REGION=region,
+                secret_key=credentials.secret_key,
+                access_key=credentials.access_key,
+                token=credentials.token,
+                region=region,
+            )
+        )
+    return options
+
+
 def get_partition_filters(
     partition_columns: list[str], filters: Filters
 ) -> list[list[Filter]] | None:
diff --git a/dask_deltatable/write.py b/dask_deltatable/write.py
index 75eca45..add512d 100644
--- a/dask_deltatable/write.py
+++ b/dask_deltatable/write.py
@@ -15,8 +15,15 @@
 from dask.dataframe.core import Scalar
 from dask.highlevelgraph import HighLevelGraph
 from deltalake import DeltaTable
+
+try:
+    from deltalake.writer import MAX_SUPPORTED_WRITER_VERSION  # type: ignore
+except ImportError:
+    from deltalake.writer import (
+        MAX_SUPPORTED_PYARROW_WRITER_VERSION as MAX_SUPPORTED_WRITER_VERSION,
+    )
+
 from deltalake.writer import (
-    MAX_SUPPORTED_WRITER_VERSION,
     PYARROW_MAJOR_VERSION,
     AddAction,
     DeltaJSONEncoder,
@@ -31,6 +38,7 @@
 from toolz.itertoolz import pluck
 
 from ._schema import pyarrow_to_deltalake, validate_compatible
+from .utils import maybe_set_aws_credentials
 
 
 def to_deltalake(
@@ -123,6 +131,7 @@ def to_deltalake(
     -------
     dask.Scalar
     """
+    storage_options = maybe_set_aws_credentials(table_or_uri, storage_options)  # type: ignore
     table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options)
 
     # We need to write against the latest table version
@@ -136,6 +145,7 @@ def to_deltalake(
     storage_options = table._storage_options or {}
     storage_options.update(storage_options or {})
 
+    storage_options = maybe_set_aws_credentials(table_uri, storage_options)
     filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options))
 
     if isinstance(partition_by, str):
@@ -253,6 +263,7 @@ def _commit(
     schema = validate_compatible(schemas)
     assert schema
     if table is None:
+        storage_options = maybe_set_aws_credentials(table_uri, storage_options)
         write_deltalake_pyarrow(
             table_uri,
             schema,
diff --git a/tests/test_acceptance.py b/tests/test_acceptance.py
index 17e99ff..5d1b57c 100644
--- a/tests/test_acceptance.py
+++ b/tests/test_acceptance.py
@@ -50,7 +50,9 @@ def test_reader_all_primitive_types():
     # Dask and delta go through different parquet parsers which read the
     # timestamp differently. This is likely a bug in arrow but the delta result
     # is "more correct".
- expected_ddf["timestamp"] = expected_ddf["timestamp"].astype("datetime64[us]") + expected_ddf["timestamp"] = ( + expected_ddf["timestamp"].astype("datetime64[us]").dt.tz_localize("UTC") + ) assert_eq(actual_ddf, expected_ddf)
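A minimal usage sketch (not part of the patch) of the behaviour this change is meant to enable. The bucket and table path below are hypothetical, and it assumes read_deltalake is exposed at the package level with a delta_storage_options keyword, as shown in the core.py hunks above:

    import dask_deltatable as ddt

    # With no credentials passed, maybe_set_aws_credentials fills storage_options /
    # delta_storage_options from boto3's default credential chain (environment
    # variables, ~/.aws/credentials, instance profile) and resolves the bucket
    # region via get_bucket_region.
    ddf = ddt.read_deltalake("s3://my-bucket/path/to/delta-table")

    # Explicitly supplied credentials are left untouched: the helper only injects
    # values when none of the AWS-specific keys are already present.
    ddf = ddt.read_deltalake(
        "s3://my-bucket/path/to/delta-table",
        delta_storage_options={
            "AWS_ACCESS_KEY_ID": "...",
            "AWS_SECRET_ACCESS_KEY": "...",
        },
    )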