Skip to content

Commit

Permalink
fix(python): Respect strict argument (#17990)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Aug 1, 2024
1 parent d0a291e commit ed0f313
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 5 deletions.
15 changes: 10 additions & 5 deletions py-polars/src/conversion/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,13 +274,18 @@ pub(crate) fn py_object_to_any_value<'py>(
}

fn get_list(ob: &Bound<'_, PyAny>, strict: bool) -> PyResult<AnyValue<'static>> {
fn get_list_with_constructor(ob: &Bound<'_, PyAny>) -> PyResult<AnyValue<'static>> {
fn get_list_with_constructor(
ob: &Bound<'_, PyAny>,
strict: bool,
) -> PyResult<AnyValue<'static>> {
// Use the dedicated constructor.
// This constructor is able to go via dedicated type constructors
// so it can be much faster.
let py = ob.py();
let s = SERIES.call1(py, (ob,))?;
get_list_from_series(s.bind(py), true)
let kwargs = PyDict::new_bound(py);
kwargs.set_item("strict", strict)?;
let s = SERIES.call_bound(py, (ob,), Some(&kwargs))?;
get_list_from_series(s.bind(py), strict)
}

if ob.is_empty()? {
Expand All @@ -303,7 +308,7 @@ pub(crate) fn py_object_to_any_value<'py>(

// This path is only taken if there is no question about the data type.
if dtype.is_primitive() && n_dtypes == 1 {
get_list_with_constructor(ob)
get_list_with_constructor(ob, strict)
} else {
// Push the rest.
let length = list.len()?;
Expand All @@ -325,7 +330,7 @@ pub(crate) fn py_object_to_any_value<'py>(
}
} else {
// range will take this branch
get_list_with_constructor(ob)
get_list_with_constructor(ob, strict)
}
}

Expand Down
149 changes: 149 additions & 0 deletions py-polars/tests/unit/constructors/test_structs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import polars as pl


def test_constructor_non_strict_schema_17956() -> None:
    """Check that a non-strict LazyFrame constructor honours a nested schema.

    The schema declares the innermost ``value`` field as ``List(Float64)``,
    while the input list holds Python ints with a single float in between.
    With ``strict=False`` the collected frame is expected to contain every
    value as a float.
    """
    # Build the nested schema from the innermost field outwards.
    value_field = pl.Field("value", pl.List(pl.Float64))
    numericarray_field = pl.Field("numericarray", pl.Struct([value_field]))
    parameters_field = pl.Field(
        "parameters", pl.List(pl.Struct([numericarray_field]))
    )
    completetask_field = pl.Field("completetask", pl.Struct([parameters_field]))
    schema = {"logged_event": pl.Struct([completetask_field])}

    def nest(values):
        # Wrap a flat list of values in the struct/list layers below "logged_event".
        return {"completetask": {"parameters": [{"numericarray": {"value": values}}]}}

    # 38 values in total: the lone float sits at index 34, surrounded by ints.
    raw_values = [431] * 34 + [430.5] + [431] * 3
    expected_values = [431.0] * 34 + [430.5] + [431.0] * 3

    data = {"logged_event": nest(raw_values)}
    expected = {"logged_event": [nest(expected_values)]}

    lazyframe = pl.LazyFrame([data], schema=schema, strict=False)
    result = lazyframe.collect().to_dict(as_series=False)
    assert result == expected

0 comments on commit ed0f313

Please sign in to comment.