Fix the PickleError with RawDeltaTable (#57)
j-bennet authored Jul 21, 2023
1 parent dbeb8cc commit 620b0d7
Showing 4 changed files with 155 additions and 154 deletions.
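Background for the fix: the old `DeltaTableWrapper` kept a `deltalake.DeltaTable` (backed by the Rust `RawDeltaTable` extension object) on `self`, so the bound method handed to `dd.from_map` dragged that object into the task graph, and it cannot be pickled when Dask serializes tasks for distributed workers. A minimal sketch of the failure mode — assuming deltalake ~0.10 and a hypothetical table URI, not code from this commit:

# Minimal sketch of the failure mode this commit addresses (assumptions:
# deltalake ~0.10, where the Rust-backed RawDeltaTable has no pickle support,
# and a hypothetical Delta table at "path/to/table").
import pickle

from deltalake import DeltaTable

dt = DeltaTable("path/to/table")

try:
    # Any bound method or closure that captures `dt` hits the same wall when
    # dask serializes the task graph for distributed workers.
    pickle.dumps(dt)
except Exception as exc:  # typically a pickling error mentioning RawDeltaTable
    print(f"cannot serialize DeltaTable: {exc}")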
57 changes: 57 additions & 0 deletions conftest.py
@@ -0,0 +1,57 @@
from __future__ import annotations

import os
import zipfile

import pytest

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(ROOT_DIR, "tests", "data")


@pytest.fixture()
def simple_table(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile(f"{DATA_DIR}/simple.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/test1/"


@pytest.fixture()
def simple_table2(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile(f"{DATA_DIR}/simple2.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/simple_table/"


@pytest.fixture()
def partition_table(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile(f"{DATA_DIR}/partition.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/test2/"


@pytest.fixture()
def empty_table1(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile(f"{DATA_DIR}/empty1.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/empty/"


@pytest.fixture()
def empty_table2(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile(f"{DATA_DIR}/empty2.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/empty2/"


@pytest.fixture()
def checkpoint_table(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile(f"{DATA_DIR}/checkpoint.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/checkpoint/"
194 changes: 92 additions & 102 deletions dask_deltatable/core.py
@@ -1,8 +1,9 @@
from __future__ import annotations

import os
from collections.abc import Sequence
from functools import partial
from typing import Any
from typing import Any, cast

import dask.dataframe as dd
import pyarrow as pa
@@ -24,107 +25,96 @@
filters_to_expression = pq._filters_to_expression


class DeltaTableWrapper:
path: str
version: int | None
columns: list[str] | None
datetime: str | None
storage_options: dict[str, Any] | None

def __init__(
self,
path: str,
version: int | None,
columns: list[str] | None,
datetime: str | None = None,
storage_options: dict[str, str] | None = None,
delta_storage_options: dict[str, str] | None = None,
) -> None:
self.path: str = path
self.version: int = version
self.columns = columns
self.datetime = datetime
self.storage_options = storage_options
self.dt = DeltaTable(
table_uri=self.path,
version=self.version,
storage_options=delta_storage_options,
)
self.fs, self.fs_token, _ = get_fs_token_paths(
path, storage_options=storage_options
)
self.schema = self.dt.schema().to_pyarrow()

def meta(self, **kwargs):
"""Pass kwargs to `to_pandas` call when creating the metadata"""
meta = make_meta(self.schema.empty_table().to_pandas(**kwargs))
if self.columns:
meta = meta[self.columns]
return meta

def read_delta_dataset(self, f: str, **kwargs: dict[Any, Any]):
schema = kwargs.pop("schema", None) or self.schema
filter = kwargs.pop("filter", None)
filter_expression = filters_to_expression(filter) if filter else None
to_pandas_kwargs = kwargs.pop("pyarrow_to_pandas", {})
return (
pa_ds.dataset(
source=f,
schema=schema,
filesystem=self.fs,
format="parquet",
partitioning="hive",
)
.to_table(filter=filter_expression, columns=self.columns)
.to_pandas(**to_pandas_kwargs)
)

    def get_pq_files(self, filter: Filters = None) -> list[str]:
        """
        Get the list of parquet files after loading the
        current datetime version
        Parameters
        ----------
        filter : list[tuple[str, str, Any]] | list[list[tuple[str, str, Any]]] | None
            Filters in DNF form.
        Returns
        -------
        list[str]
            List of files matching optional filter.
        """
        __doc__ == self.dt.load_with_datetime.__doc__
        if self.datetime is not None:
            self.dt.load_with_datetime(self.datetime)
        partition_filters = get_partition_filters(
            self.dt.metadata().partition_columns, filter
        )
        if not partition_filters:
            # can't filter
            return self.dt.file_uris()
        file_uris = set()
        for filter_set in partition_filters:
            file_uris.update(self.dt.file_uris(partition_filters=filter_set))
        return sorted(list(file_uris))

    def read_delta_table(self, **kwargs) -> dd.core.DataFrame:
        """
        Reads the list of parquet files in parallel
        """
        pq_files = self.get_pq_files(filter=kwargs.get("filter", None))
        if len(pq_files) == 0:
            raise RuntimeError("No Parquet files are available")

        meta = self.meta(**kwargs.get("pyarrow_to_pandas", {}))

        return dd.from_map(
            partial(self.read_delta_dataset, **kwargs),
            pq_files,
            meta=meta,
            label="read-delta-table",
            token=tokenize(self.fs_token, **kwargs),
        )


def _get_pq_files(dt: DeltaTable, filter: Filters = None) -> list[str]:
    """
    Get the list of parquet files after loading the
    current datetime version
    Parameters
    ----------
    dt : DeltaTable
        DeltaTable instance
    filter : list[tuple[str, str, Any]] | list[list[tuple[str, str, Any]]] | None
        Filters in DNF form.
    Returns
    -------
    list[str]
        List of files matching optional filter.
    """
    partition_filters = get_partition_filters(dt.metadata().partition_columns, filter)
    if not partition_filters:
        # can't filter
        return dt.file_uris()
    file_uris = set()
    for filter_set in partition_filters:
        file_uris.update(dt.file_uris(partition_filters=filter_set))
    return sorted(list(file_uris))


def _read_delta_partition(
filename: str,
schema: pa.Schema,
fs: Any,
columns: Sequence[str] | None,
filter: Filters = None,
pyarrow_to_pandas: dict[str, Any] | None = None,
**_kwargs: dict[str, Any],
):
filter_expression = filters_to_expression(filter) if filter else None
if pyarrow_to_pandas is None:
pyarrow_to_pandas = {}
return (
pa_ds.dataset(
source=filename,
schema=schema,
filesystem=fs,
format="parquet",
partitioning="hive",
)
.to_table(filter=filter_expression, columns=columns)
.to_pandas(**pyarrow_to_pandas)
)


def _read_from_filesystem(
path: str,
version: int | None,
columns: Sequence[str] | None,
datetime: str | None = None,
storage_options: dict[str, str] | None = None,
delta_storage_options: dict[str, str] | None = None,
**kwargs: dict[str, Any],
) -> dd.core.DataFrame:
"""
Reads the list of parquet files in parallel
"""
fs, fs_token, _ = get_fs_token_paths(path, storage_options=storage_options)
dt = DeltaTable(
table_uri=path, version=version, storage_options=delta_storage_options
)
if datetime is not None:
dt.load_with_datetime(datetime)

schema = dt.schema().to_pyarrow()

filter_value = cast(Filters, kwargs.get("filter", None))
pq_files = _get_pq_files(dt, filter=filter_value)
if len(pq_files) == 0:
raise RuntimeError("No Parquet files are available")

mapper_kwargs = kwargs.get("pyarrow_to_pandas", {})
meta = make_meta(schema.empty_table().to_pandas(**mapper_kwargs))
if columns:
meta = meta[columns]

return dd.from_map(
partial(_read_delta_partition, fs=fs, columns=columns, schema=schema, **kwargs),
pq_files,
meta=meta,
label="read-delta-table",
token=tokenize(fs_token, **kwargs),
)


def _read_from_catalog(
@@ -255,13 +245,13 @@ def read_deltalake(
else:
if path is None:
raise ValueError("Please Provide Delta Table path")
dtw = DeltaTableWrapper(
resultdf = _read_from_filesystem(
path=path,
version=version,
columns=columns,
storage_options=storage_options,
datetime=datetime,
delta_storage_options=delta_storage_options,
**kwargs,
)
resultdf = dtw.read_delta_table(columns=columns, **kwargs)
return resultdf
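
What makes the new layout pickle-safe: `dd.from_map` now receives the module-level `_read_delta_partition` plus plainly serializable arguments (file URIs, a `pyarrow.Schema`, an fsspec filesystem, kwargs), while the `DeltaTable` itself never leaves `_read_from_filesystem` on the client. A rough sketch of that property, using a hypothetical schema and a local filesystem as stand-ins — not part of the commit:

# Rough sketch, not part of the commit: the payload shipped to workers is now
# picklable. The schema and filesystem below are hypothetical stand-ins.
from functools import partial

import cloudpickle
import pyarrow as pa
from fsspec.implementations.local import LocalFileSystem

from dask_deltatable.core import _read_delta_partition

task = partial(
    _read_delta_partition,
    schema=pa.schema([("id", pa.int64()), ("count", pa.int64())]),
    fs=LocalFileSystem(),
    columns=None,
)
cloudpickle.dumps(task)  # round-trips cleanly; no RawDeltaTable is captured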
52 changes: 0 additions & 52 deletions tests/test_core.py
@@ -2,7 +2,6 @@

import glob
import os
import zipfile
from unittest.mock import MagicMock, patch

import pandas as pd
@@ -13,57 +12,6 @@

import dask_deltatable as ddt

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(ROOT_DIR, "data")


@pytest.fixture()
def simple_table(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile(f"{DATA_DIR}/simple.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/test1/"


@pytest.fixture()
def simple_table2(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile(f"{DATA_DIR}/simple2.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/simple_table/"


@pytest.fixture()
def partition_table(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile(f"{DATA_DIR}/partition.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/test2/"


@pytest.fixture()
def empty_table1(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile(f"{DATA_DIR}/empty1.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/empty/"


@pytest.fixture()
def empty_table2(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile(f"{DATA_DIR}/empty2.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/empty2/"


@pytest.fixture()
def checkpoint_table(tmpdir):
output_dir = tmpdir
deltaf = zipfile.ZipFile(f"{DATA_DIR}/checkpoint.zip")
deltaf.extractall(output_dir)
return str(output_dir) + "/checkpoint/"


def test_read_delta(simple_table):
df = ddt.read_deltalake(simple_table)
6 changes: 6 additions & 0 deletions tests/test_distributed.py
@@ -79,3 +79,9 @@ def test_write_with_schema(client, tmpdir):
ddt.to_deltalake(f"{tmpdir}", ddf, schema=schema)
ds = pa_ds.dataset(str(tmpdir))
assert ds.schema == schema


def test_read(client, simple_table):
df = ddt.read_deltalake(simple_table)
assert df.columns.tolist() == ["id", "count", "temperature", "newColumn"]
assert df.compute().shape == (200, 4)
