Commit
nameexhaustion committed Jun 19, 2024
1 parent 5ec60a4 commit c6cfca4
Showing 1 changed file with 39 additions and 17 deletions.
py-polars/tests/unit/io/test_hive.py (56 changes: 39 additions & 17 deletions)
@@ -1,7 +1,8 @@
 import warnings
 from collections import OrderedDict
+from functools import partial
 from pathlib import Path
-from typing import Any
+from typing import Any, Callable
 
 import pyarrow.parquet as pq
 import pytest
@@ -275,7 +276,27 @@ def test_read_parquet_hive_schema_with_pyarrow() -> None:
         pl.read_parquet("test.parquet", hive_schema={"c": pl.Int32}, use_pyarrow=True)
 
 
-def test_hive_partition_directory_scan(tmp_path: Path) -> None:
+@pytest.mark.parametrize(
+    ("scan_func", "write_func"),
+    [
+        (pl.scan_parquet, pl.DataFrame.write_parquet),
+    ],
+)
+@pytest.mark.parametrize(
+    ["append_glob", "glob"],
+    [
+        ("**/*.bin", True),
+        ("", True),
+        ("", False),
+    ],
+)
+def test_hive_partition_directory_scan(
+    tmp_path: Path,
+    append_glob: str,
+    scan_func: Callable[[Any], pl.LazyFrame],
+    write_func: Callable[[pl.DataFrame, Path], None],
+    glob: bool,
+) -> None:
     tmp_path.mkdir(exist_ok=True)
 
     dfs = [
@@ -290,30 +311,35 @@ def test_hive_partition_directory_scan(tmp_path: Path) -> None:
         b = df.item(0, "b")
         path = tmp_path / f"a={a}/b={b}/data.bin"
         path.parent.mkdir(exist_ok=True, parents=True)
-        df.drop("a", "b").write_parquet(path)
+        write_func(df.drop("a", "b"), path)
 
     df = pl.concat(dfs)
     hive_schema = df.lazy().select("a", "b").collect_schema()
 
-    out = pl.scan_parquet(
-        tmp_path, hive_partitioning=True, hive_schema=hive_schema
+    scan = scan_func
+    scan = partial(scan_func, hive_schema=hive_schema, glob=glob)
+
+    out = scan(
+        tmp_path / append_glob,
+        hive_partitioning=True,
+        hive_schema=hive_schema,
     ).collect()
     assert_frame_equal(out, df)
 
-    out = pl.scan_parquet(
-        tmp_path, hive_partitioning=False, hive_schema=hive_schema
+    out = scan(
+        tmp_path / append_glob, hive_partitioning=False, hive_schema=hive_schema
     ).collect()
     assert_frame_equal(out, df.drop("a", "b"))
 
-    out = pl.scan_parquet(
-        tmp_path / "a=1",
+    out = scan(
+        tmp_path / "a=1" / append_glob,
         hive_partitioning=True,
         hive_schema=hive_schema,
     ).collect()
     assert_frame_equal(out, df.filter(a=1).drop("a"))
 
-    out = pl.scan_parquet(
-        tmp_path / "a=1",
+    out = scan(
+        tmp_path / "a=1" / append_glob,
         hive_partitioning=False,
         hive_schema=hive_schema,
     ).collect()
@@ -322,15 +348,11 @@ def test_hive_partition_directory_scan(tmp_path: Path) -> None:
     path = tmp_path / "a=1/b=1/data.bin"
 
     df = dfs[0]
-    out = pl.scan_parquet(
-        path, hive_partitioning=True, hive_schema=hive_schema
-    ).collect()
+    out = scan(path, hive_partitioning=True, hive_schema=hive_schema).collect()
 
     assert_frame_equal(out, df)
 
     df = dfs[0].drop("a", "b")
-    out = pl.scan_parquet(
-        path, hive_partitioning=False, hive_schema=hive_schema
-    ).collect()
+    out = scan(path, hive_partitioning=False, hive_schema=hive_schema).collect()
 
     assert_frame_equal(out, df)
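
For reference, a minimal self-contained sketch (not part of this commit) of the pattern the parametrized test exercises: write parquet files under hive-style a=<value>/b=<value>/ directories, then scan the partition root both directly and through an appended glob. The data, column names, and file layout below are hypothetical, and the sketch assumes a Polars version in which scan_parquet accepts hive_partitioning, hive_schema, and glob, as used in the test above.

import tempfile
from functools import partial
from pathlib import Path

import polars as pl

with tempfile.TemporaryDirectory() as tmp:
    tmp_path = Path(tmp)

    dfs = [
        pl.DataFrame({"a": [1], "b": [1], "x": [1.0]}),
        pl.DataFrame({"a": [2], "b": [2], "x": [2.0]}),
    ]

    # Write each frame under hive-style key=value directories; the partition
    # columns live in the directory names, not in the parquet files themselves.
    for df in dfs:
        a = df.item(0, "a")
        b = df.item(0, "b")
        path = tmp_path / f"a={a}/b={b}/data.bin"
        path.parent.mkdir(exist_ok=True, parents=True)
        df.drop("a", "b").write_parquet(path)

    hive_schema = {"a": pl.Int64, "b": pl.Int64}
    scan = partial(pl.scan_parquet, hive_schema=hive_schema, glob=True)

    # Scanning the partition root with hive_partitioning=True recovers the
    # partition columns from the directory names.
    out = scan(tmp_path, hive_partitioning=True).collect()
    assert set(out.columns) == {"a", "b", "x"}

    # An explicit glob appended to the root yields the same rows; this is one
    # of the cases the parametrization covers.
    out_glob = scan(tmp_path / "**/*.bin", hive_partitioning=True).collect()
    assert out_glob.sort("a").equals(out.sort("a"))

    # With hive_partitioning=False only the columns physically stored in the
    # files are read.
    out_plain = scan(tmp_path, hive_partitioning=False).collect()
    assert set(out_plain.columns) == {"x"}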
