Skip to content

Commit

Permalink
exclude partial hive columns from projection
Browse files (browse the repository at this point in the history)
  • Loading branch information
nameexhaustion committed Sep 9, 2024
1 parent d12b89a commit ebbc27b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 23 deletions.
41 changes: 20 additions & 21 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1068,27 +1068,26 @@ pub(crate) fn maybe_init_projection_excluding_hive(
// Update `with_columns` with a projection so that hive columns aren't loaded from the
// file
let hive_parts = hive_parts?;

let hive_schema = hive_parts.schema();

let (first_hive_name, _) = hive_schema.get_at_index(0)?;

// TODO: Optimize this
let names = match reader_schema {
Either::Left(ref v) => v
.contains(first_hive_name.as_str())
.then(|| v.iter_names_cloned().collect::<Vec<_>>()),
Either::Right(ref v) => v
.contains(first_hive_name.as_str())
.then(|| v.iter_names_cloned().collect()),
};

let names = names?;

Some(
names
.into_iter()
.filter(|x| !hive_schema.contains(x))
.collect::<Arc<[_]>>(),
)
match &reader_schema {
Either::Left(reader_schema) => hive_schema
.iter_names()
.any(|x| reader_schema.contains(x))
.then(|| {
reader_schema
.iter_names_cloned()
.filter(|x| !hive_schema.contains(x))
.collect::<Arc<[_]>>()
}),
Either::Right(reader_schema) => hive_schema
.iter_names()
.any(|x| reader_schema.contains(x))
.then(|| {
reader_schema
.iter_names_cloned()
.filter(|x| !hive_schema.contains(x))
.collect::<Arc<[_]>>()
}),
}
}
15 changes: 13 additions & 2 deletions py-polars/tests/unit/io/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,14 @@ def assert_with_projections(lf: pl.LazyFrame, df: pl.DataFrame) -> None:
)
write_func(df, partial_path)

rhs = rhs.select(
pl.col("x").cast(pl.Int32),
pl.col("b").cast(pl.Int16),
pl.col("y").cast(pl.Int32),
pl.col("a").cast(pl.Int64),
)

lf = scan_func(partial_path, hive_partitioning=True) # type: ignore[call-arg]
rhs = df
assert_frame_equal(lf.collect(projection_pushdown=projection_pushdown), rhs)
assert_with_projections(lf, rhs)

Expand All @@ -572,7 +578,12 @@ def assert_with_projections(lf: pl.LazyFrame, df: pl.DataFrame) -> None:
hive_schema={"a": pl.String, "b": pl.String},
hive_partitioning=True,
)
rhs = df.with_columns(pl.col("a", "b").cast(pl.String))
rhs = rhs.select(
pl.col("x").cast(pl.Int32),
pl.col("b").cast(pl.String),
pl.col("y").cast(pl.Int32),
pl.col("a").cast(pl.String),
)
assert_frame_equal(
lf.collect(projection_pushdown=projection_pushdown),
rhs,
Expand Down

0 comments on commit ebbc27b

Please sign in to comment.