Pass a dict instead of ParquetFileWriteOptions that can't be pickled.
j-bennet committed Jul 14, 2023
1 parent b6a675e commit f24ee92
Showing 3 changed files with 112 additions and 4 deletions.
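
For context on the commit message above, here is a minimal sketch of the pickling problem being worked around (it assumes only pyarrow is installed; whether an exception is raised, and which one, depends on the pyarrow version):

import pickle

import pyarrow.dataset as ds

# to_deltalake ships its arguments to Dask workers, so they must pickle.
# The pyarrow options object historically did not, while a plain dict does.
opts = ds.ParquetFileFormat().make_write_options(compression="gzip")
try:
    pickle.dumps(opts)
except Exception as exc:  # behaviour varies across pyarrow versions
    print(f"ParquetFileWriteOptions could not be pickled: {exc!r}")

pickle.dumps({"compression": "gzip"})  # the dict form round-trips without issue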
10 changes: 6 additions & 4 deletions dask_deltatable/write.py
@@ -41,7 +41,7 @@ def to_deltalake(
     partition_by: list[str] | str | None = None,
     filesystem: pa_fs.FileSystem | None = None,
     mode: Literal["error", "append", "overwrite", "ignore"] = "error",
-    file_options: ds.ParquetFileWriteOptions | None = None,
+    file_options: Mapping[str, Any] | None = None,
     max_partitions: int | None = None,
     max_open_files: int = 1024,
     max_rows_per_file: int = 10 * 1024 * 1024,
@@ -78,9 +78,8 @@ def to_deltalake(
         If 'append', will add new data.
         If 'overwrite', will replace table with new data.
         If 'ignore', will not write anything if table already exists.
-    file_options : ds.ParquetFileWriteOptions | None. Default None
-        Optional write options for Parquet (ParquetFileWriteOptions).
-        Can be provided with defaults using ParquetFileWriteOptions().make_write_options().
+    file_options : Mapping[str, Any] | None. Default None
+        Optional dict of options that can be used to initialize ParquetFileWriteOptions.
         Please refer to https://github.com/apache/arrow/blob/master/python/pyarrow/_dataset_parquet.pyx
         for the list of available options
     max_partitions : int | None. Default None
@@ -315,6 +314,9 @@ def visitor(written_file: Any) -> None:
             )
         )

+    if file_options is not None:
+        file_options = ds.ParquetFileFormat().make_write_options(**file_options)
+
     ds.write_dataset(
         data,
         base_dir="/",
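
A short usage sketch of the changed parameter, mirroring the tests added later in this commit (the target path and the timeseries shape are illustrative, not part of the diff):

import dask_deltatable as ddt
from dask.datasets import timeseries

ddf = timeseries(
    start="2023-01-01",
    end="2023-01-03",
    freq="1H",
    dtypes={"str": object, "float": float, "int": int},
).reset_index()

# file_options is now a plain, picklable dict; to_deltalake turns it into
# ParquetFileWriteOptions via ds.ParquetFileFormat().make_write_options(**file_options)
# immediately before calling ds.write_dataset.
ddt.to_deltalake("/tmp/example_delta_table", ddf, file_options={"compression": "gzip"})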
25 changes: 25 additions & 0 deletions setup.cfg
@@ -0,0 +1,25 @@
# flake8 doesn't support pyproject.toml yet https://github.com/PyCQA/flake8/issues/234
[flake8]
exclude = __init__.py
max-line-length = 120
ignore =
    # Extra space in brackets
    E20
    # Multiple spaces around ","
    E231,E241
    # Comments
    E26
    # Import formatting
    E4
    # Comparing types instead of isinstance
    E721
    # Assigning lambda expression
    E731
    # Ambiguous variable names
    E741
    # Line break before binary operator
    W503
    # Line break after binary operator
    W504
    # Redefinition of unused 'loop' from line 10
    F811
81 changes: 81 additions & 0 deletions tests/test_distributed.py
@@ -0,0 +1,81 @@
from __future__ import annotations

import pytest

distributed = pytest.importorskip("distributed")

import os  # noqa: E402
import sys  # noqa: E402

import pyarrow as pa  # noqa: E402
import pyarrow.dataset as pa_ds  # noqa: E402
import pyarrow.parquet as pq  # noqa: E402
from dask.datasets import timeseries  # noqa: E402
from distributed.utils_test import cleanup  # noqa F401
from distributed.utils_test import (  # noqa F401
    client,
    cluster,
    cluster_fixture,
    gen_cluster,
    loop,
    loop_in_thread,
    popen,
    varying,
)

import dask_deltatable as ddt  # noqa: E402

pytestmark = pytest.mark.skipif(
    sys.platform == "win32",
    reason=(
        "The teardown of distributed.utils_test.cluster_fixture "
        "fails on windows CI currently"
    ),
)


def test_write(client, tmpdir):
    ddf = timeseries(
        start="2023-01-01",
        end="2023-01-03",
        freq="1H",
        partition_freq="1D",
        dtypes={"str": object, "float": float, "int": int},
    ).reset_index()
    ddt.to_deltalake(f"{tmpdir}", ddf)


def test_write_with_options(client, tmpdir):
    file_options = dict(compression="gzip")
    ddf = timeseries(
        start="2023-01-01",
        end="2023-01-03",
        freq="1H",
        partition_freq="1D",
        dtypes={"str": object, "float": float, "int": int},
    ).reset_index()
    ddt.to_deltalake(f"{tmpdir}", ddf, file_options=file_options)
    parquet_filename = [f for f in os.listdir(tmpdir) if f.endswith(".parquet")][0]
    parquet_file = pq.ParquetFile(f"{tmpdir}/{parquet_filename}")
    assert parquet_file.metadata.row_group(0).column(0).compression == "GZIP"


def test_write_with_schema(client, tmpdir):
    ddf = timeseries(
        start="2023-01-01",
        end="2023-01-03",
        freq="1H",
        partition_freq="1D",
        dtypes={"str": object, "float": float, "int": int},
    ).reset_index()
    schema = pa.schema(
        [
            pa.field("timestamp", pa.timestamp("us")),
            pa.field("str", pa.string()),
            pa.field("float", pa.float32()),
            pa.field("int", pa.int32()),
        ]
    )
    ddt.to_deltalake(f"{tmpdir}", ddf, schema=schema)
    ds = pa_ds.dataset(str(tmpdir))
    assert ds.schema == schema
