Commit 5ec60a4
nameexhaustion committed Jun 19, 2024
1 parent 14dd2ca
Showing 3 changed files with 76 additions and 12 deletions.
crates/polars-lazy/src/scan/file_list_reader.rs: 2 additions & 4 deletions
@@ -54,6 +54,7 @@ fn expand_paths(
             } else if !path.ends_with("/")
                 && is_file_cloud(path.to_str().unwrap(), cloud_options)?
             {
+                expand_start_idx = 0;
                 out_paths.push(path.clone());
                 continue;
             } else if !glob {
@@ -120,15 +121,12 @@ fn expand_paths(
                     out_paths.push(path.map_err(to_compute_err)?);
                 }
             } else {
+                expand_start_idx = 0;
                 out_paths.push(path.clone());
             }
         }
     }
 
-    // Todo:
-    // This maintains existing behavior - will remove very soon.
-    expand_start_idx = 0;
-
     Ok((
         out_paths.into_iter().collect::<Arc<[_]>>(),
         expand_start_idx,
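
Taken together, the two hunks move the `expand_start_idx = 0` reset from the tail of `expand_paths` into only the branches that push a path which is itself a file. A directory input therefore keeps its expansion offset, which is what lets hive partition columns be resolved relative to the directory the user passed. A minimal user-facing sketch, assuming a hypothetical local layout `dataset/a=1/...` (paths and values are illustrative, not taken from the commit):

from pathlib import Path

import polars as pl

# Build a hypothetical hive-partitioned directory: dataset/a=1/... and dataset/a=2/...
root = Path("dataset")
for a in (1, 2):
    (root / f"a={a}").mkdir(parents=True, exist_ok=True)
    pl.DataFrame({"x": [10 * a]}).write_parquet(root / f"a={a}" / "data.parquet")

# Passing the directory itself (no glob) scans all files under it; with
# hive_partitioning=True the partition column `a` is inferred from the
# subdirectory names relative to `root`.
print(pl.scan_parquet(root, hive_partitioning=True).collect())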
crates/polars-plan/src/plans/hive.rs: 12 additions & 7 deletions
@@ -53,13 +53,18 @@ impl HivePartitions {
         }
 
         let schema = match schema {
-            Some(s) => {
-                polars_ensure!(
-                    s.len() == partitions.len(),
-                    SchemaMismatch: "path does not match the provided Hive schema"
-                );
-                s
-            },
+            Some(schema) => Arc::new(
+                partitions
+                    .iter()
+                    .map(|s| {
+                        let mut field = s.field().into_owned();
+                        if let Some(dtype) = schema.get(field.name()) {
+                            field.dtype = dtype.clone();
+                        };
+                        field
+                    })
+                    .collect::<Schema>(),
+            ),
             None => Arc::new(partitions.as_slice().into()),
         };
 
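
The rewritten `Some` arm drops the `polars_ensure!` length check that required the provided schema to have exactly one field per partition level. Instead, every field parsed from the path is kept, and its dtype is overridden when the field name appears in the user's `hive_schema`. A hedged sketch of the observable effect (directory layout and dtypes are illustrative assumptions):

from pathlib import Path

import polars as pl

# Hypothetical two-level hive layout: data/a=1/b=1/data.parquet
root = Path("data")
(root / "a=1" / "b=1").mkdir(parents=True, exist_ok=True)
pl.DataFrame({"x": [1, 2]}).write_parquet(root / "a=1" / "b=1" / "data.parquet")

# Dtypes from hive_schema override the inferred dtypes of the matching
# partition fields parsed from the path.
out = pl.scan_parquet(
    root, hive_partitioning=True, hive_schema={"a": pl.Int32, "b": pl.Int32}
).collect()
assert out.schema["a"] == pl.Int32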
py-polars/tests/unit/io/test_hive.py: 62 additions & 1 deletion
@@ -187,7 +187,7 @@ def test_hive_partitioned_err(io_files_path: Path, tmp_path: Path) -> None:
         df.write_parquet(root / "file.parquet")
 
     with pytest.raises(DuplicateError, match="invalid Hive partition schema"):
-        pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True).collect()
+        pl.scan_parquet(tmp_path, hive_partitioning=True).collect()
 
 
 @pytest.mark.write_disk()
@@ -273,3 +273,64 @@ def test_read_parquet_hive_schema_with_pyarrow() -> None:
         match="cannot use `hive_partitions` with `use_pyarrow=True`",
     ):
         pl.read_parquet("test.parquet", hive_schema={"c": pl.Int32}, use_pyarrow=True)
+
+
+def test_hive_partition_directory_scan(tmp_path: Path) -> None:
+    tmp_path.mkdir(exist_ok=True)
+
+    dfs = [
+        pl.DataFrame({'x': 5 * [1], 'a': 1, 'b': 1}),
+        pl.DataFrame({'x': 5 * [2], 'a': 1, 'b': 2}),
+        pl.DataFrame({'x': 5 * [3], 'a': 2, 'b': 1}),
+        pl.DataFrame({'x': 5 * [4], 'a': 2, 'b': 2}),
+    ]  # fmt: skip
+
+    for df in dfs:
+        a = df.item(0, "a")
+        b = df.item(0, "b")
+        path = tmp_path / f"a={a}/b={b}/data.bin"
+        path.parent.mkdir(exist_ok=True, parents=True)
+        df.drop("a", "b").write_parquet(path)
+
+    df = pl.concat(dfs)
+    hive_schema = df.lazy().select("a", "b").collect_schema()
+
+    out = pl.scan_parquet(
+        tmp_path, hive_partitioning=True, hive_schema=hive_schema
+    ).collect()
+    assert_frame_equal(out, df)
+
+    out = pl.scan_parquet(
+        tmp_path, hive_partitioning=False, hive_schema=hive_schema
+    ).collect()
+    assert_frame_equal(out, df.drop("a", "b"))
+
+    out = pl.scan_parquet(
+        tmp_path / "a=1",
+        hive_partitioning=True,
+        hive_schema=hive_schema,
+    ).collect()
+    assert_frame_equal(out, df.filter(a=1).drop("a"))
+
+    out = pl.scan_parquet(
+        tmp_path / "a=1",
+        hive_partitioning=False,
+        hive_schema=hive_schema,
+    ).collect()
+    assert_frame_equal(out, df.filter(a=1).drop("a", "b"))
+
+    path = tmp_path / "a=1/b=1/data.bin"
+
+    df = dfs[0]
+    out = pl.scan_parquet(
+        path, hive_partitioning=True, hive_schema=hive_schema
+    ).collect()
+
+    assert_frame_equal(out, df)
+
+    df = dfs[0].drop("a", "b")
+    out = pl.scan_parquet(
+        path, hive_partitioning=False, hive_schema=hive_schema
+    ).collect()
+
+    assert_frame_equal(out, df)
