From baab00fa8950e17b48a8c337a751832bc9210b01 Mon Sep 17 00:00:00 2001
From: Florian Jetter
Date: Mon, 10 Jul 2023 14:24:07 +0200
Subject: [PATCH 1/2] Do not require distributed (#37)

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 40950ff..65cf0f9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-dask[dataframe,distribuited]
+dask[dataframe]
 deltalake
 fsspec
 pyarrow

From 08cc5b467ba3137d36a5476bf226991e0dad41ea Mon Sep 17 00:00:00 2001
From: Florian Jetter
Date: Mon, 10 Jul 2023 16:45:47 +0200
Subject: [PATCH 2/2] Write operation (#29)

---
 .flake8                    |   2 +-
 .pre-commit-config.yaml    |   5 +-
 dask_deltatable/_schema.py | 325 +++++++++++++++++++++++++++++++++++++
 dask_deltatable/core.py    |   4 +-
 dask_deltatable/write.py   | 261 +++++++++++++++++++++++++++++
 pyproject.toml             |   1 +
 tests/test_write.py        |  64 ++++++++
 7 files changed, 657 insertions(+), 5 deletions(-)
 create mode 100644 dask_deltatable/_schema.py
 create mode 100644 dask_deltatable/write.py
 create mode 100644 tests/test_write.py

diff --git a/.flake8 b/.flake8
index 7f39e8d..7a214b1 100644
--- a/.flake8
+++ b/.flake8
@@ -1,3 +1,3 @@
 # flake8 doesn't support pyproject.toml yet https://github.com/PyCQA/flake8/issues/234
 [flake8]
-max-line-length = 104
+max-line-length = 120

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c691a09..9f691b6 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -30,11 +30,12 @@ repos:
         args: [--warn-unused-configs]
         additional_dependencies:
           # Type stubs
-          - types-setuptools
           - boto3-stubs
-          - pytest
           - dask
           - deltalake
+          - pandas-stubs
+          - pytest
+          - types-setuptools
   - repo: https://github.com/pycqa/flake8
     rev: 6.0.0
     hooks:

diff --git a/dask_deltatable/_schema.py b/dask_deltatable/_schema.py
new file mode 100644
index 0000000..c85d7d5
--- /dev/null
+++ b/dask_deltatable/_schema.py
@@ -0,0 +1,325 @@
+from __future__ import annotations
+
+"""
+Most of this code was taken from
+
+https://github.com/data-engineering-collective/plateau
+
+https://github.com/data-engineering-collective/plateau/blob/d4c4522f5a829d43e3368fc82e1568c91fa352f3/plateau/core/common_metadata.py
+
+and adapted to this project
+
+under the original license
+
+MIT License
+
+Copyright (c) 2022 The plateau contributors.
+Copyright (c) 2020-2021 The kartothek contributors.
+Copyright (c) 2019 JDA Software, Inc
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ +""" +import difflib +import json +import logging +import pprint +from copy import deepcopy +from typing import Iterable + +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq + +_logger = logging.getLogger() + + +class SchemaWrapper: + def __init__(self, schema: pa.Schema): + self.schema = schema + + def __hash__(self): + # FIXME: pyarrow raises a "cannot hash type dict" error + return hash(_schema2bytes(self.schema)) + + +def _pandas_in_schemas(schemas): + """Check if any schema contains pandas metadata.""" + has_pandas = False + for schema in schemas: + if schema.metadata and b"pandas" in schema.metadata: + has_pandas = True + return has_pandas + + +def _determine_schemas_to_compare( + schemas: Iterable[pa.Schema], ignore_pandas: bool +) -> tuple[pa.Schema | None, list[tuple[pa.Schema, list[str]]]]: + """Iterate over a list of `pyarrow.Schema` objects and prepares them for + comparison by picking a reference and determining all null columns. + + .. note:: + + If pandas metadata exists, the version stored in the metadata is overwritten with the currently + installed version since we expect to stay backwards compatible + + Returns + ------- + reference: Schema + A reference schema which is picked from the input list. The reference schema is guaranteed + to be a schema having the least number of null columns of all input columns. The set of null + columns is guaranteed to be a true subset of all null columns of all input schemas. If no such + schema can be found, an Exception is raised + list_of_schemas: List[Tuple[Schema, List]] + A list holding pairs of (Schema, null_columns) where the null_columns are all columns which are null and + must be removed before comparing the schemas + """ + has_pandas = _pandas_in_schemas(schemas) and not ignore_pandas + schemas_to_evaluate: list[tuple[pa.Schema, list[str]]] = [] + reference = None + null_cols_in_reference = set() + # Hashing the schemas is a very fast way to reduce the number of schemas to + # actually compare since in most circumstances this reduces to very few + # (which differ in e.g. null columns) + for schema_wrapped in set(map(SchemaWrapper, schemas)): + schema = schema_wrapped.schema + del schema_wrapped + if has_pandas: + metadata = schema.metadata + if metadata is None or b"pandas" not in metadata: + raise ValueError( + "Pandas and non-Pandas schemas are not comparable. " + "Use ignore_pandas=True if you only want to compare " + "on Arrow level." + ) + pandas_metadata = json.loads(metadata[b"pandas"].decode("utf8")) + + # we don't care about the pandas version, since we assume it's safe + # to read datasets that were written by older or newer versions. + pandas_metadata["pandas_version"] = f"{pd.__version__}" + + metadata_clean = deepcopy(metadata) + metadata_clean[b"pandas"] = _dict_to_binary(pandas_metadata) + current = pa.schema(schema, metadata_clean) + else: + current = schema + + # If a field is null we cannot compare it and must therefore reject it + null_columns = {field.name for field in current if field.type == pa.null()} + + # Determine a valid reference schema. A valid reference schema is considered to be the schema + # of all input schemas with the least empty columns. + # The reference schema ought to be a schema whose empty columns are a true subset for all sets + # of empty columns. This ensures that the actual reference schema is the schema with the most + # information possible. 
+        # A schema which doesn't fulfil this requirement would weaken the
+        # comparison and would allow for false positives.
+
+        # Trivial case
+        if reference is None:
+            reference = current
+            null_cols_in_reference = null_columns
+        # The reference has enough information to validate against current schema.
+        # Append it to the list of schemas to be verified
+        elif null_cols_in_reference.issubset(null_columns):
+            schemas_to_evaluate.append((current, null_columns))
+        # current schema includes all information of reference and more.
+        # Add reference to schemas_to_evaluate and update reference
+        elif null_columns.issubset(null_cols_in_reference):
+            schemas_to_evaluate.append((reference, list(null_cols_in_reference)))
+            reference = current
+            null_cols_in_reference = null_columns
+        # If there is no clear subset available elect the schema with the least null columns as `reference`.
+        # Iterate over the null columns of `reference` and replace it with a non-null field of the `current`
+        # schema which recovers the loop invariant (null columns of `reference` is subset of `current`)
+        else:
+            if len(null_columns) < len(null_cols_in_reference):
+                reference, current = current, reference
+                null_cols_in_reference, null_columns = (
+                    null_columns,
+                    null_cols_in_reference,
+                )
+
+            for col in null_cols_in_reference - null_columns:
+                # Enrich the information in the reference by grabbing the missing fields
+                # from the current iteration. This assumes that we only check for global validity and
+                # isn't relevant where the reference comes from.
+                reference = _swap_fields_by_name(reference, current, col)
+                null_cols_in_reference.remove(col)
+            schemas_to_evaluate.append((current, null_columns))
+
+    assert (reference is not None) or (not schemas_to_evaluate)
+
+    return reference, schemas_to_evaluate
+
+
+def _swap_fields_by_name(reference, current, field_name):
+    current_field = current.field(field_name)
+    reference_index = reference.get_field_index(field_name)
+    return reference.set(reference_index, current_field)
+
+
+def _strip_columns_from_schema(schema, field_names):
+    stripped_schema = schema
+
+    for name in field_names:
+        ix = stripped_schema.get_field_index(name)
+        if ix >= 0:
+            stripped_schema = stripped_schema.remove(ix)
+        else:
+            # If the returned index is negative, the field doesn't exist in the schema.
+            # This is most likely an indicator for incompatible schemas and we refuse to strip the schema
+            # to not obfuscate the validation result
+            _logger.warning(
+                "Unexpected field `%s` encountered while trying to strip `null` columns.\n"
+                "Schema was:\n\n`%s`" % (name, schema)
+            )
+            return schema
+    return stripped_schema
+
+
+def _schema2bytes(schema: SchemaWrapper) -> bytes:
+    buf = pa.BufferOutputStream()
+    pq.write_metadata(schema, buf, coerce_timestamps="us")
+    return buf.getvalue().to_pybytes()
+
+
+def _remove_diff_header(diff):
+    diff = list(diff)
+    for ix, el in enumerate(diff):
+        # This marks the first actual entry of the diff,
+        # e.g.
+        # @@ -1,5 + 2,5 @@
+        if el.startswith("@"):
+            return diff[ix:]
+    return diff
+
+
+def _diff_schemas(first, second):
+    # see https://issues.apache.org/jira/browse/ARROW-4176
+
+    first_pyarrow_info = str(first.remove_metadata())
+    second_pyarrow_info = str(second.remove_metadata())
+    pyarrow_diff = _remove_diff_header(
+        difflib.unified_diff(
+            str(first_pyarrow_info).splitlines(), str(second_pyarrow_info).splitlines()
+        )
+    )
+
+    first_pandas_info = first.pandas_metadata
+    second_pandas_info = second.pandas_metadata
+    pandas_meta_diff = _remove_diff_header(
+        difflib.unified_diff(
+            pprint.pformat(first_pandas_info).splitlines(),
+            pprint.pformat(second_pandas_info).splitlines(),
+        )
+    )
+
+    diff_string = (
+        "Arrow schema:\n"
+        + "\n".join(pyarrow_diff)
+        + "\n\nPandas_metadata:\n"
+        + "\n".join(pandas_meta_diff)
+    )
+
+    return diff_string
+
+
+def validate_compatible(
+    schemas: Iterable[pa.Schema], ignore_pandas: bool = False
+) -> pa.Schema:
+    """Validate that all schemas in a given list are compatible.
+
+    Apart from the pandas version preserved in the schema metadata, schemas must be completely identical. That includes
+    a perfect match of the whole metadata (except the pandas version) and pyarrow types.
+
+    In the case that all schemas don't contain any pandas metadata, we will check the Arrow
+    schemas directly for compatibility.
+
+    Parameters
+    ----------
+    schemas: List[Schema]
+        Schema information from multiple sources, e.g. multiple partitions. List may be empty.
+    ignore_pandas: bool
+        Ignore the schema information given by Pandas and always use the Arrow schema.
+
+    Returns
+    -------
+    schema: SchemaWrapper
+        The reference schema which was tested against
+
+    Raises
+    ------
+    ValueError
+        At least two schemas are incompatible.
+    """
+    reference, schemas_to_evaluate = _determine_schemas_to_compare(
+        schemas, ignore_pandas
+    )
+
+    for current, null_columns in schemas_to_evaluate:
+        # We have schemas so the reference schema should be non-None.
+        assert reference is not None
+        # Compare each schema to the reference but ignore the null_cols and the Pandas schema information.
+        reference_to_compare = _strip_columns_from_schema(
+            reference, null_columns
+        ).remove_metadata()
+        current_to_compare = _strip_columns_from_schema(
+            current, null_columns
+        ).remove_metadata()
+
+        def _fmt_origin(origin):
+            origin = sorted(origin)
+            # dask cuts off exception messages at 1k chars:
+            # https://github.com/dask/distributed/blob/6e0c0a6b90b1d3c/distributed/core.py#L964
+            # therefore, we cut to the maximum length
+            max_len = 200
+            inner_msg = ", ".join(origin)
+            ellipsis = "..."
+            if len(inner_msg) > max_len + len(ellipsis):
+                inner_msg = inner_msg[:max_len] + ellipsis
+            return f"{{{inner_msg}}}"
+
+        if reference_to_compare != current_to_compare:
+            schema_diff = _diff_schemas(reference, current)
+            exception_message = """Schema violation
+
+Origin schema: {origin_schema}
+Origin reference: {origin_reference}
+
+Diff:
+{schema_diff}
+
+Reference schema:
+{reference}""".format(
+                schema_diff=schema_diff,
+                reference=str(reference),
+                origin_schema=_fmt_origin(current.origin),
+                origin_reference=_fmt_origin(reference.origin),
+            )
+            raise ValueError(exception_message)
+
+    # add all origins to result AFTER error checking, otherwise the error message would be pretty misleading due to the
+    # reference containing all origins.
+    if reference is None:
+        return None
+    else:
+        return reference
+
+
+def _dict_to_binary(dct):
+    return json.dumps(dct, sort_keys=True).encode("utf8")
diff --git a/dask_deltatable/core.py b/dask_deltatable/core.py
index ea875a9..7f080bb 100644
--- a/dask_deltatable/core.py
+++ b/dask_deltatable/core.py
@@ -8,12 +8,12 @@
 import dask
 import dask.dataframe as dd
-import pyarrow.parquet as pq  # type: ignore[import]
+import pyarrow.parquet as pq
 from dask.base import tokenize
 from dask.dataframe.utils import make_meta
 from dask.delayed import delayed
 from deltalake import DataCatalog, DeltaTable
-from fsspec.core import get_fs_token_paths  # type: ignore[import]
+from fsspec.core import get_fs_token_paths
 from pyarrow import dataset as pa_ds
diff --git a/dask_deltatable/write.py b/dask_deltatable/write.py
new file mode 100644
index 0000000..80bdeeb
--- /dev/null
+++ b/dask_deltatable/write.py
@@ -0,0 +1,261 @@
+from __future__ import annotations
+
+import json
+import uuid
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Literal, Mapping
+
+import dask.dataframe as dd
+import pyarrow as pa
+import pyarrow.dataset as ds
+import pyarrow.fs as pa_fs
+from dask.core import flatten
+from dask.dataframe.core import Scalar
+from dask.highlevelgraph import HighLevelGraph
+from deltalake import DeltaTable
+from deltalake.writer import (
+    MAX_SUPPORTED_WRITER_VERSION,
+    PYARROW_MAJOR_VERSION,
+    AddAction,
+    DeltaJSONEncoder,
+    DeltaProtocolError,
+    DeltaStorageHandler,
+    __enforce_append_only,
+    _write_new_deltalake,
+    get_file_stats_from_metadata,
+    get_partitions_from_path,
+    try_get_table_and_table_uri,
+)
+from toolz.itertoolz import pluck
+
+from ._schema import validate_compatible
+
+
+def to_deltalake(
+    table_or_uri: str | Path | DeltaTable,
+    df: dd.DataFrame,
+    *,
+    schema: pa.Schema | None = None,
+    partition_by: list[str] | str | None = None,
+    filesystem: pa_fs.FileSystem | None = None,
+    mode: Literal["error", "append", "overwrite", "ignore"] = "error",
+    file_options: ds.ParquetFileWriteOptions | None = None,
+    max_partitions: int | None = None,
+    max_open_files: int = 1024,
+    max_rows_per_file: int = 10 * 1024 * 1024,
+    min_rows_per_group: int = 64 * 1024,
+    max_rows_per_group: int = 128 * 1024,
+    name: str | None = None,
+    description: str | None = None,
+    configuration: Mapping[str, str | None] | None = None,
+    overwrite_schema: bool = False,
+    storage_options: dict[str, str] | None = None,
+    partition_filters: list[tuple[str, str, Any]] | None = None,
+):
+    """Write a given dask.DataFrame to a delta table
+
+    TODO:
+    """
+    table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options)
+
+    # We need to write against the latest table version
+    if table:
+        table.update_incremental()
+
+    __enforce_append_only(table=table, configuration=configuration, mode=mode)
+
+    if filesystem is None:
+        if table is not None:
+            storage_options = table._storage_options or {}
+            storage_options.update(storage_options or {})
+
+        filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options))
+
+    if isinstance(partition_by, str):
+        partition_by = [partition_by]
+
+    if table:  # already exists
+        if schema != table.schema().to_pyarrow() and not (
+            mode == "overwrite" and overwrite_schema
+        ):
+            raise ValueError(
+                "Schema of data does not match table schema\n"
+                f"Data schema:\n{schema}\nTable schema:\n{table.schema().to_pyarrow()}"
+            )
+
+        if mode == "error":
+            raise AssertionError("DeltaTable already exists.")
+        elif mode == "ignore":
+            return
+
+        current_version = table.version()
+
+        if partition_by:
+            assert partition_by == table.metadata().partition_columns
+        else:
+            partition_by = table.metadata().partition_columns
+
+        if table.protocol().min_writer_version > MAX_SUPPORTED_WRITER_VERSION:
+            raise DeltaProtocolError(
+                "This table's min_writer_version is "
+                f"{table.protocol().min_writer_version}, "
+                f"but this method only supports version {MAX_SUPPORTED_WRITER_VERSION}."
+            )
+    else:  # creating a new table
+        current_version = -1
+
+    # FIXME: schema is only known at this point if provided by the user
+    if partition_by and schema:
+        partition_schema = pa.schema([schema.field(name) for name in partition_by])
+        partitioning = ds.partitioning(partition_schema, flavor="hive")
+    else:
+        if partition_by:
+            raise NotImplementedError("Have to provide schema when using partition_by")
+        partitioning = None
+    if mode == "overwrite":
+        # FIXME: There are a couple of checks that are not migrated yet
+        raise NotImplementedError("mode='overwrite' is not implemented")
+
+    written = df.map_partitions(
+        _write_partition,
+        schema=schema,
+        partitioning=partitioning,
+        current_version=current_version,
+        file_options=file_options,
+        max_open_files=max_open_files,
+        max_rows_per_file=max_rows_per_file,
+        min_rows_per_group=min_rows_per_group,
+        max_rows_per_group=max_rows_per_group,
+        filesystem=filesystem,
+        max_partitions=max_partitions,
+        meta=(None, object),
+    )
+    final_name = "delta-commit"
+    dsk = {
+        (final_name, 0): (
+            _commit,
+            table,
+            written.__dask_keys__(),
+            table_uri,
+            schema,
+            mode,
+            partition_by,
+            name,
+            description,
+            configuration,
+            storage_options,
+            partition_filters,
+        )
+    }
+    graph = HighLevelGraph.from_collections(final_name, dsk, dependencies=(written,))
+    return Scalar(graph, final_name, "")
+
+
+def _commit(
+    table,
+    schemas_add_actions_nested,
+    table_uri,
+    schema,
+    mode,
+    partition_by,
+    name,
+    description,
+    configuration,
+    storage_options,
+    partition_filters,
+):
+    schemas = list(flatten(pluck(0, schemas_add_actions_nested)))
+    add_actions = list(flatten(pluck(1, schemas_add_actions_nested)))
+    # TODO: What should the behavior be if the schema is provided? Cast the data?
+    if schema:
+        schemas.append(schema)
+
+    # TODO: This is applying a potentially stricter schema control than what
+    # Delta requires but if this passes, it should be good to go
+    schema = validate_compatible(schemas)
+    assert schema
+    if table is None:
+        _write_new_deltalake(
+            table_uri,
+            schema,
+            add_actions,
+            mode,
+            partition_by or [],
+            name,
+            description,
+            configuration,
+            storage_options,
+        )
+    else:
+        table._table.create_write_transaction(
+            add_actions,
+            mode,
+            partition_by or [],
+            schema,
+            partition_filters,
+        )
+        table.update_incremental()
+
+
+def _write_partition(
+    df,
+    *,
+    schema,
+    partitioning,
+    current_version,
+    file_options,
+    max_open_files,
+    max_rows_per_file,
+    min_rows_per_group,
+    max_rows_per_group,
+    filesystem,
+    max_partitions,
+) -> tuple[pa.Schema, list[AddAction]]:
+    # TODO: what to do with the schema, if provided
+    data = pa.Table.from_pandas(df)
+    schema = schema or data.schema
+
+    add_actions: list[AddAction] = []
+
+    def visitor(written_file: Any) -> None:
+        path, partition_values = get_partitions_from_path(written_file.path)
+        stats = get_file_stats_from_metadata(written_file.metadata)
+
+        # PyArrow added support for written_file.size in 9.0.0
+        if PYARROW_MAJOR_VERSION >= 9:
+            size = written_file.size
+        else:
+            size = filesystem.get_file_info([path])[0].size
+
+        add_actions.append(
+            AddAction(
+                path,
+                size,
+                partition_values,
+                int(datetime.now().timestamp() * 1000),
+                True,
+                json.dumps(stats, cls=DeltaJSONEncoder),
+            )
+        )
+
+    ds.write_dataset(
+        data,
+        base_dir="/",
+        basename_template=f"{current_version + 1}-{uuid.uuid4()}-{{i}}.parquet",
+        format="parquet",
+        partitioning=partitioning,
+        # It will not accept a schema if using a RBR
+        schema=schema,
+        existing_data_behavior="overwrite_or_ignore",
+        file_options=file_options,
+        max_open_files=max_open_files,
+        file_visitor=visitor,
+        max_rows_per_file=max_rows_per_file,
+        min_rows_per_group=min_rows_per_group,
+        max_rows_per_group=max_rows_per_group,
+        filesystem=filesystem,
+        max_partitions=max_partitions,
+    )
+    return schema, add_actions
diff --git a/pyproject.toml b/pyproject.toml
index 7cd19ee..2565cd9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,6 +5,7 @@ allow_incomplete_defs = true
 allow_untyped_defs = true
 warn_return_any = false
 disallow_untyped_calls = false
+ignore_missing_imports = true
 
 [tool.isort]
 profile = "black"
diff --git a/tests/test_write.py b/tests/test_write.py
new file mode 100644
index 0000000..314092b
--- /dev/null
+++ b/tests/test_write.py
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+import os
+
+import pytest
+from dask.dataframe.utils import assert_eq
+from dask.datasets import timeseries
+
+from dask_deltatable import read_delta_table
+from dask_deltatable.write import to_deltalake
+
+
+@pytest.mark.parametrize(
+    "with_index",
+    [
+        pytest.param(
+            True,
+            marks=[
+                pytest.mark.xfail(
+                    reason="TS index is always ns resolution but delta can only handle us"
+                )
+            ],
+        ),
+        False,
+    ],
+)
+def test_roundtrip(tmpdir, with_index):
+    dtypes = {
+        "str": object,
+        # FIXME: Categorical data does not work
+        # "category": "category",
+        "float": float,
+        "int": int,
+    }
+    tmpdir = str(tmpdir)
+    ddf = timeseries(
+        start="2023-01-01",
+        end="2023-01-15",
+        # FIXME: Setting the partition frequency destroys the roundtrip for some reason
+        # partition_freq="1w",
+        dtypes=dtypes,
+    )
+    # FIXME: us is the only precision delta supports.
+    # This lib should likely cast this itself.
+
+    ddf = ddf.reset_index()
+    ddf.timestamp = ddf.timestamp.astype("datetime64[us]")
+    if with_index:
+        ddf = ddf.set_index("timestamp")
+
+    out = to_deltalake(tmpdir, ddf)
+    assert not os.listdir(tmpdir)
+    out.compute()
+    assert len(os.listdir(tmpdir)) > 0
+
+    ddf_read = read_delta_table(tmpdir)
+    # FIXME: The index is not recovered
+    if with_index:
+        ddf = ddf.reset_index()
+
+    # By default, arrow reads with ns resolution
+    ddf.timestamp = ddf.timestamp.astype("datetime64[ns]")
+    assert_eq(ddf, ddf_read)
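
The following sketch shows how the write path added in patch 2/2 is meant to be used end to end. It is illustrative rather than part of the patch: it mirrors tests/test_write.py, and the temporary directory and the dask demo timeseries data are assumptions made for the example.

# Usage sketch (illustrative, not part of the patch).
import tempfile

from dask.datasets import timeseries

from dask_deltatable import read_delta_table
from dask_deltatable.write import to_deltalake

path = tempfile.mkdtemp()

ddf = timeseries(
    start="2023-01-01",
    end="2023-01-03",
    dtypes={"str": object, "float": float, "int": int},
).reset_index()
# Delta only handles microsecond precision, so cast the timestamps before writing.
ddf.timestamp = ddf.timestamp.astype("datetime64[us]")

# to_deltalake is lazy: it returns a dask Scalar and nothing is written
# until the final commit task is computed.
result = to_deltalake(path, ddf)
result.compute()

# Read the table back with the existing reader; Arrow reads timestamps as nanoseconds.
ddf_read = read_delta_table(path)
print(ddf_read.head())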
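
The schema handling in dask_deltatable/_schema.py can be summarized with a short sketch: partition schemas that differ only in all-null columns (or in the recorded pandas version) validate against a common reference, while real dtype mismatches are rejected. The column names below are made up for illustration, and a reasonably recent pyarrow is assumed.

# Schema-validation sketch (illustrative, hypothetical column names).
import pyarrow as pa

from dask_deltatable._schema import validate_compatible

full = pa.schema([("x", pa.int64()), ("y", pa.string())])
partial = pa.schema([("x", pa.int64()), ("y", pa.null())])  # "y" is all-null in this partition

# Compatible: the all-null column is ignored and the schema carrying the most
# information is returned as the reference.
reference = validate_compatible([full, partial])
assert reference.equals(full)

# A genuine mismatch, e.g. "x" as float64 in one partition, does not pass this check.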