Allow passing types_mapper and other options into to_pandas.
j-bennet committed Jul 13, 2023
1 parent 045481b commit 9df17d9
Showing 2 changed files with 61 additions and 11 deletions.
28 changes: 24 additions & 4 deletions dask_deltatable/core.py
@@ -58,15 +58,19 @@ def __init__(
path, storage_options=storage_options
)
self.schema = self.dt.schema().to_pyarrow()
meta = make_meta(self.schema.empty_table().to_pandas())

def meta(self, **kwargs):
"""Pass kwargs to `to_pandas` call when creating the metadata"""
meta = make_meta(self.schema.empty_table().to_pandas(**kwargs))
if self.columns:
meta = meta[self.columns]
self.meta = meta
return meta

def read_delta_dataset(self, f: str, **kwargs: dict[Any, Any]):
schema = kwargs.pop("schema", None) or self.schema
filter = kwargs.pop("filter", None)
filter_expression = filters_to_expression(filter) if filter else None
to_pandas_kwargs = kwargs.pop("pyarrow_to_pandas", {})
return (
pa_ds.dataset(
source=f,
@@ -76,7 +80,7 @@ def read_delta_dataset(self, f: str, **kwargs: dict[Any, Any]):
partitioning="hive",
)
.to_table(filter=filter_expression, columns=self.columns)
.to_pandas()
.to_pandas(**to_pandas_kwargs)
)

def _history_helper(self, log_file_name: str):
@@ -185,10 +189,12 @@ def read_delta_table(self, **kwargs) -> dd.core.DataFrame:
if len(pq_files) == 0:
raise RuntimeError("No Parquet files are available")

meta = self.meta(**kwargs.get("pyarrow_to_pandas", {}))

return dd.from_map(
partial(self.read_delta_dataset, **kwargs),
pq_files,
meta=self.meta,
meta=meta,
label="read-delta-table",
token=tokenize(self.fs_token, **kwargs),
)
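The important detail in this hunk is that the metadata handed to dd.from_map is now built with the same to_pandas options each partition will use; otherwise the empty meta frame and the real partitions could disagree on dtypes. A minimal sketch of that effect, assuming pyarrow and dask are installed (the schema and mapper below are invented for illustration, not taken from the repository):

    import pandas as pd
    import pyarrow as pa
    from dask.dataframe.utils import make_meta

    # Hypothetical schema standing in for self.schema
    schema = pa.schema([("id", pa.int64()), ("count", pa.int64())])

    def types_mapper(pyarrow_dtype):
        # Map 64-bit Arrow integers to pandas' nullable Int32 extension dtype;
        # returning None falls back to the default conversion.
        if pyarrow_dtype == pa.int64():
            return pd.Int32Dtype()

    # Default conversion: plain numpy int64 columns
    meta_default = make_meta(schema.empty_table().to_pandas())
    # Same kwargs the partitions will receive: nullable Int32 columns
    meta_mapped = make_meta(schema.empty_table().to_pandas(types_mapper=types_mapper))
    print(meta_default.dtypes, meta_mapped.dtypes, sep="\n")

Building meta with kwargs.get("pyarrow_to_pandas", {}) keeps the metadata in sync with what read_delta_dataset now produces per partition.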
@@ -271,6 +277,7 @@ def read_delta_table(
Some of the most commonly used parameters that can be passed here are:
1. schema
2. filter
3. pyarrow_to_pandas
schema : pyarrow.Schema
Used to maintain schema evolution in deltatable.
@@ -286,6 +293,19 @@
example:
[("x",">",400)] --> pyarrow.dataset.field("x")>400
pyarrow_to_pandas: dict
Options to pass directly to pyarrow.Table.to_pandas.
Common options include:
* categories: list[str]
List of columns to treat as pandas.Categorical
* strings_to_categorical: bool
Encode string (UTF8) and binary types to pandas.Categorical.
* types_mapper: Callable
A function mapping a pyarrow DataType to a pandas ExtensionDtype
See https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
for more.
Returns
-------
Dask.DataFrame
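Taken together, the core.py changes let callers forward arbitrary pyarrow.Table.to_pandas options through a single pyarrow_to_pandas dict. A small usage sketch (the table path is a placeholder, and mapping every column to Arrow-backed dtypes assumes a pandas version that exposes pd.ArrowDtype):

    import pandas as pd
    import dask_deltatable as ddt

    # Convert all columns to pandas' Arrow-backed extension dtypes instead of
    # the default numpy dtypes; pd.ArrowDtype is callable with a pyarrow type,
    # so it can be passed directly as the types_mapper.
    df = ddt.read_delta_table(
        "path/to/delta_table",  # placeholder path
        pyarrow_to_pandas={"types_mapper": pd.ArrowDtype},
    )
    print(df.dtypes)

The same dict accepts any other to_pandas option, e.g. categories or strings_to_categorical, as documented above.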
44 changes: 37 additions & 7 deletions tests/test_core.py
@@ -5,65 +5,70 @@
import zipfile
from unittest.mock import MagicMock, patch

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from deltalake import DeltaTable

import dask_deltatable as ddt

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(ROOT_DIR, "data")


@pytest.fixture()
def simple_table(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile("tests/data/simple.zip")
deltaf = zipfile.ZipFile(f"{DATA_DIR}/simple.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/test1/"


@pytest.fixture()
def simple_table2(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile("tests/data/simple2.zip")
deltaf = zipfile.ZipFile(f"{DATA_DIR}/simple2.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/simple_table/"


@pytest.fixture()
def partition_table(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile("tests/data/partition.zip")
deltaf = zipfile.ZipFile(f"{DATA_DIR}/partition.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/test2/"


@pytest.fixture()
def empty_table1(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile("tests/data/empty1.zip")
deltaf = zipfile.ZipFile(f"{DATA_DIR}/empty1.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/empty/"


@pytest.fixture()
def empty_table2(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile("tests/data/empty2.zip")
deltaf = zipfile.ZipFile(f"{DATA_DIR}/empty2.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/empty2/"


@pytest.fixture()
def checkpoint_table(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile("tests/data/checkpoint.zip")
deltaf = zipfile.ZipFile(f"{DATA_DIR}/checkpoint.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/checkpoint/"


@pytest.fixture()
def vacuum_table(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile("tests/data/vacuum.zip")
deltaf = zipfile.ZipFile(f"{DATA_DIR}/vacuum.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/vaccum_table"

@@ -75,6 +80,31 @@ def test_read_delta(simple_table):
assert df.compute().shape == (200, 4)


def test_read_delta_types_mapper(simple_table):
"""Provide a custom types mapper"""

def types_mapper(pyarrow_dtype):
if pyarrow_dtype == pa.int64():
return pd.Int32Dtype()

df = ddt.read_delta_table(
simple_table, pyarrow_to_pandas={"types_mapper": types_mapper}
)
assert df.dtypes["id"] == "Int32"
assert df.dtypes["count"] == "Int32"
res = df.compute()
assert res.dtypes["id"] == "Int32"
assert res.dtypes["count"] == "Int32"


def test_read_delta_categories(simple_table):
"""Provide a list of categories"""
df = ddt.read_delta_table(simple_table, pyarrow_to_pandas={"categories": ["id"]})
assert df.dtypes["id"] == "category"
res = df.compute()
assert res.dtypes["id"] == "category"


def test_read_delta_with_different_versions(simple_table):
print(simple_table)
df = ddt.read_delta_table(simple_table, version=0)
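The new tests exercise types_mapper and categories; the same mechanism composes with the other read options documented in core.py. A hypothetical example combining the filter option with a to_pandas option (the path and column name are illustrative, not part of the test data guarantees):

    import dask_deltatable as ddt

    # Push a row filter down to the Arrow dataset, then convert string columns
    # to pandas.Categorical during the Arrow-to-pandas step.
    df = ddt.read_delta_table(
        "path/to/delta_table",         # placeholder path
        filter=[("count", ">", 100)],  # same format as the docstring example
        pyarrow_to_pandas={"strings_to_categorical": True},
    )
    print(df.dtypes)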
