Allow passing types_mapper and other options into to_pandas (#49)
j-bennet authored Jul 14, 2023
1 parent 2b30e31 commit 215bbde
Showing 3 changed files with 71 additions and 21 deletions.
28 changes: 24 additions & 4 deletions dask_deltatable/core.py
@@ -54,15 +54,19 @@ def __init__(
path, storage_options=storage_options
)
self.schema = self.dt.schema().to_pyarrow()
meta = make_meta(self.schema.empty_table().to_pandas())

def meta(self, **kwargs):
"""Pass kwargs to `to_pandas` call when creating the metadata"""
meta = make_meta(self.schema.empty_table().to_pandas(**kwargs))
if self.columns:
meta = meta[self.columns]
self.meta = meta
return meta

def read_delta_dataset(self, f: str, **kwargs: dict[Any, Any]):
schema = kwargs.pop("schema", None) or self.schema
filter = kwargs.pop("filter", None)
filter_expression = filters_to_expression(filter) if filter else None
to_pandas_kwargs = kwargs.pop("pyarrow_to_pandas", {})
return (
pa_ds.dataset(
source=f,
@@ -72,7 +76,7 @@ def read_delta_dataset(self, f: str, **kwargs: dict[Any, Any]):
partitioning="hive",
)
.to_table(filter=filter_expression, columns=self.columns)
.to_pandas()
.to_pandas(**to_pandas_kwargs)
)

def get_pq_files(self, filter: Filters = None) -> list[str]:
@@ -112,10 +116,12 @@ def read_delta_table(self, **kwargs) -> dd.core.DataFrame:
if len(pq_files) == 0:
raise RuntimeError("No Parquet files are available")

meta = self.meta(**kwargs.get("pyarrow_to_pandas", {}))

return dd.from_map(
partial(self.read_delta_dataset, **kwargs),
pq_files,
meta=self.meta,
meta=meta,
label="read-delta-table",
token=tokenize(self.fs_token, **kwargs),
)
@@ -198,6 +204,7 @@ def read_deltalake(
Some most used parameters can be passed here are:
1. schema
2. filter
3. pyarrow_to_pandas
schema : pyarrow.Schema
Used to maintain schema evolution in deltatable.
@@ -213,6 +220,19 @@ def read_deltalake(
example:
[("x",">",400)] --> pyarrow.dataset.field("x")>400
pyarrow_to_pandas: dict
Options to pass directly to pyarrow.Table.to_pandas.
Common options include:
* categories: list[str]
List of columns to treat as pandas.Categorical
* strings_to_categorical: bool
Encode string (UTF8) and binary types to pandas.Categorical.
* types_mapper: Callable
A function mapping a pyarrow DataType to a pandas ExtensionDtype
See https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
for more.
Returns
-------
Dask.DataFrame
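For reference, a minimal usage sketch of the new pyarrow_to_pandas option (the table path and the int64 "id" column below are hypothetical; the dict is forwarded to pyarrow.Table.to_pandas both when building the Dask metadata and when reading each partition, mirroring the new tests in tests/test_core.py):

import pandas as pd
import pyarrow as pa

import dask_deltatable as ddt


def types_mapper(pyarrow_dtype):
    # Read 64-bit integer columns back as the nullable pandas Int64 dtype;
    # returning None for other types keeps the default conversion.
    if pyarrow_dtype == pa.int64():
        return pd.Int64Dtype()


ddf = ddt.read_deltalake(
    "path/to/delta_table",  # hypothetical table path
    pyarrow_to_pandas={"types_mapper": types_mapper},
)
# Columns stored as int64 (e.g. a hypothetical "id" column) now have dtype Int64.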
22 changes: 11 additions & 11 deletions tests/test_acceptance.py
@@ -43,7 +43,7 @@ def download_data():


def test_reader_all_primitive_types():
actual_ddf = ddt.read_delta_table(f"{DATA_DIR}/all_primitive_types/delta")
actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/all_primitive_types/delta")
expected_ddf = dd.read_parquet(
f"{DATA_DIR}/all_primitive_types/expected/latest/table_content/*parquet"
)
@@ -52,7 +52,7 @@ def test_reader_all_primitive_types():

@pytest.mark.parametrize("version,subdir", [(None, "latest"), (0, "v0"), (1, "v1")])
def test_reader_basic_append(version, subdir):
actual_ddf = ddt.read_delta_table(f"{DATA_DIR}/basic_append/delta", version=version)
actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/basic_append/delta", version=version)
expected_ddf = dd.read_parquet(
f"{DATA_DIR}/basic_append/expected/{subdir}/table_content/*parquet"
)
@@ -61,7 +61,7 @@ def test_reader_basic_append(version, subdir):

@pytest.mark.parametrize("version,subdir", [(None, "latest"), (0, "v0"), (1, "v1")])
def test_reader_basic_partitioned(version, subdir):
actual_ddf = ddt.read_delta_table(f"{DATA_DIR}/basic_partitioned/delta")
actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/basic_partitioned/delta")
expected_ddf = dd.read_parquet(
f"{DATA_DIR}/basic_partitioned/expected/latest/table_content/*parquet"
)
@@ -73,7 +73,7 @@ def test_reader_basic_partitioned(version, subdir):
"version,subdir", [(None, "latest"), (0, "v0"), (1, "v1"), (2, "v2")]
)
def test_reader_multi_partitioned(version, subdir):
actual_ddf = ddt.read_delta_table(f"{DATA_DIR}/multi_partitioned/delta")
actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/multi_partitioned/delta")
expected_ddf = dd.read_parquet(
f"{DATA_DIR}/multi_partitioned/expected/{subdir}/table_content/*parquet"
)
@@ -82,47 +82,47 @@ def test_reader_multi_partitioned(version, subdir):

@pytest.mark.xfail(reason="https://github.com/delta-io/delta-rs/issues/1533")
def test_reader_multi_partitioned_2():
actual_ddf = ddt.read_delta_table(f"{DATA_DIR}/multi_partitioned_2/delta")
actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/multi_partitioned_2/delta")
expected_ddf = dd.read_parquet(
f"{DATA_DIR}/multi_partitioned_2/expected/latest/table_content/*parquet"
)
assert_eq(actual_ddf, expected_ddf)


def test_reader_nested_types():
actual_ddf = ddt.read_delta_table(f"{DATA_DIR}/nested_types/delta")
actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/nested_types/delta")
expected_ddf = dd.read_parquet(
f"{DATA_DIR}/nested_types/expected/latest/table_content/*parquet"
)
assert_eq(actual_ddf, expected_ddf)


def test_reader_no_replay():
actual_ddf = ddt.read_delta_table(f"{DATA_DIR}/no_replay/delta")
actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/no_replay/delta")
expected_ddf = dd.read_parquet(
f"{DATA_DIR}/no_replay/expected/latest/table_content/*parquet"
)
assert_eq(actual_ddf, expected_ddf)


def test_reader_no_stats():
actual_ddf = ddt.read_delta_table(f"{DATA_DIR}/no_stats/delta")
actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/no_stats/delta")
expected_ddf = dd.read_parquet(
f"{DATA_DIR}/no_stats/expected/latest/table_content/*parquet"
)
assert_eq(actual_ddf, expected_ddf)


def test_reader_stats_as_structs():
actual_ddf = ddt.read_delta_table(f"{DATA_DIR}/stats_as_struct/delta")
actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/stats_as_struct/delta")
expected_ddf = dd.read_parquet(
f"{DATA_DIR}/stats_as_struct/expected/latest/table_content/*parquet"
)
assert_eq(actual_ddf, expected_ddf)


def test_reader_with_checkpoint():
actual_ddf = ddt.read_delta_table(f"{DATA_DIR}/with_checkpoint/delta")
actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/with_checkpoint/delta")
expected_ddf = dd.read_parquet(
f"{DATA_DIR}/with_checkpoint/expected/latest/table_content/*parquet"
)
@@ -131,7 +131,7 @@ def test_reader_with_checkpoint():

@pytest.mark.parametrize("version,subdir", [(None, "latest"), (1, "v1")])
def test_reader_with_schema_change(version, subdir):
actual_ddf = ddt.read_delta_table(f"{DATA_DIR}/with_schema_change/delta")
actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/with_schema_change/delta")
expected_ddf = dd.read_parquet(
f"{DATA_DIR}/with_schema_change/expected/{subdir}/table_content/*parquet"
)
42 changes: 36 additions & 6 deletions tests/test_core.py
@@ -5,57 +5,62 @@
import zipfile
from unittest.mock import MagicMock, patch

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from deltalake import DeltaTable

import dask_deltatable as ddt

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(ROOT_DIR, "data")


@pytest.fixture()
def simple_table(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile("tests/data/simple.zip")
deltaf = zipfile.ZipFile(f"{DATA_DIR}/simple.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/test1/"


@pytest.fixture()
def simple_table2(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile("tests/data/simple2.zip")
deltaf = zipfile.ZipFile(f"{DATA_DIR}/simple2.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/simple_table/"


@pytest.fixture()
def partition_table(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile("tests/data/partition.zip")
deltaf = zipfile.ZipFile(f"{DATA_DIR}/partition.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/test2/"


@pytest.fixture()
def empty_table1(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile("tests/data/empty1.zip")
deltaf = zipfile.ZipFile(f"{DATA_DIR}/empty1.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/empty/"


@pytest.fixture()
def empty_table2(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile("tests/data/empty2.zip")
deltaf = zipfile.ZipFile(f"{DATA_DIR}/empty2.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/empty2/"


@pytest.fixture()
def checkpoint_table(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile("tests/data/checkpoint.zip")
deltaf = zipfile.ZipFile(f"{DATA_DIR}/checkpoint.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/checkpoint/"

@@ -67,6 +72,31 @@ def test_read_delta(simple_table):
assert df.compute().shape == (200, 4)


def test_read_delta_types_mapper(simple_table):
"""Provide a custom types mapper"""

def types_mapper(pyarrow_dtype):
if pyarrow_dtype == pa.int64():
return pd.Int32Dtype()

df = ddt.read_deltalake(
simple_table, pyarrow_to_pandas={"types_mapper": types_mapper}
)
assert df.dtypes["id"] == "Int32"
assert df.dtypes["count"] == "Int32"
res = df.compute()
assert res.dtypes["id"] == "Int32"
assert res.dtypes["count"] == "Int32"


def test_read_delta_categories(simple_table):
"""Provide a list of categories"""
df = ddt.read_deltalake(simple_table, pyarrow_to_pandas={"categories": ["id"]})
assert df.dtypes["id"] == "category"
res = df.compute()
assert res.dtypes["id"] == "category"


def test_read_delta_with_different_versions(simple_table):
print(simple_table)
df = ddt.read_deltalake(simple_table, version=0)