From 0aa32ae3b5b84fa61983a29c5f1a7ec788465483 Mon Sep 17 00:00:00 2001 From: Andy Date: Mon, 9 Oct 2023 16:26:30 +0200 Subject: [PATCH] feat(rust, python): Implement `schema`, `schema_override` for `pl.read_json` with array-like input (#11492) --- crates/polars-io/src/json/mod.rs | 55 ++++++++++++++++------------ py-polars/tests/unit/io/test_json.py | 49 +++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 24 deletions(-) diff --git a/crates/polars-io/src/json/mod.rs b/crates/polars-io/src/json/mod.rs index 47d30be41246..16c03d8a5349 100644 --- a/crates/polars-io/src/json/mod.rs +++ b/crates/polars-io/src/json/mod.rs @@ -215,14 +215,13 @@ where let mut_schema = Arc::make_mut(&mut schema); overwrite_schema(mut_schema, overwrite)?; } + DataType::Struct(schema.iter_fields().collect()).to_arrow() } else { // infer - if let BorrowedValue::Array(values) = &json_value { - polars_ensure!(self.schema_overwrite.is_none() && self.schema.is_none(), ComputeError: "schema arguments not yet supported for Array json"); - + let inner_dtype = if let BorrowedValue::Array(values) = &json_value { // struct types may have missing fields so find supertype - let dtype = values + values .iter() .take(self.infer_schema_len.unwrap_or(usize::MAX)) .map(|value| { @@ -235,32 +234,40 @@ where let r = r?; try_get_supertype(&l, &r) }) - .unwrap()?; - let dtype = DataType::List(Box::new(dtype)); - dtype.to_arrow() + .unwrap()? + .to_arrow() } else { - let dtype = infer(&json_value)?; - if let Some(overwrite) = self.schema_overwrite { - let ArrowDataType::Struct(fields) = dtype else { - polars_bail!(ComputeError: "can only deserialize json objects") - }; + infer(&json_value)? + }; - let mut schema = Schema::from_iter(fields.iter()); - overwrite_schema(&mut schema, overwrite)?; + if let Some(overwrite) = self.schema_overwrite { + let ArrowDataType::Struct(fields) = inner_dtype else { + polars_bail!(ComputeError: "can only deserialize json objects") + }; - DataType::Struct( - schema - .into_iter() - .map(|(name, dt)| Field::new(&name, dt)) - .collect(), - ) - .to_arrow() - } else { - dtype - } + let mut schema = Schema::from_iter(fields.iter()); + overwrite_schema(&mut schema, overwrite)?; + + DataType::Struct( + schema + .into_iter() + .map(|(name, dt)| Field::new(&name, dt)) + .collect(), + ) + .to_arrow() + } else { + inner_dtype } }; + let dtype = if let BorrowedValue::Array(_) = &json_value { + ArrowDataType::LargeList(Box::new(arrow::datatypes::Field::new( + "item", dtype, true, + ))) + } else { + dtype + }; + let arr = polars_json::json::deserialize(&json_value, dtype)?; let arr = arr.as_any().downcast_ref::().ok_or_else( || polars_err!(ComputeError: "can only deserialize json objects"), diff --git a/py-polars/tests/unit/io/test_json.py b/py-polars/tests/unit/io/test_json.py index 2a347fa174ca..faa15208aa70 100644 --- a/py-polars/tests/unit/io/test_json.py +++ b/py-polars/tests/unit/io/test_json.py @@ -32,6 +32,55 @@ def test_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None: assert_frame_equal(df, out, categorical_as_str=True) +def test_to_from_buffer_arraywise_schema() -> None: + buf = io.StringIO( + """ + [ + {"a": 5, "b": "foo", "c": null}, + {"a": 11.4, "b": null, "c": true, "d": 8}, + {"a": -25.8, "b": "bar", "c": false} + ]""" + ) + + read_df = pl.read_json(buf, schema={"b": pl.Utf8, "e": pl.Int16}) + + assert_frame_equal( + read_df, + pl.DataFrame( + { + "b": pl.Series(["foo", None, "bar"], dtype=pl.Utf8), + "e": pl.Series([None, None, None], dtype=pl.Int16), + } + ), + ) + + +def test_to_from_buffer_arraywise_schema_override() -> None: + buf = io.StringIO( + """ + [ + {"a": 5, "b": "foo", "c": null}, + {"a": 11.4, "b": null, "c": true, "d": 8}, + {"a": -25.8, "b": "bar", "c": false} + ]""" + ) + + read_df = pl.read_json(buf, schema_overrides={"c": pl.Int64, "d": pl.Float64}) + + assert_frame_equal( + read_df, + pl.DataFrame( + { + "a": pl.Series([5, 11.4, -25.8], dtype=pl.Float64), + "b": pl.Series(["foo", None, "bar"], dtype=pl.Utf8), + "c": pl.Series([None, 1, 0], dtype=pl.Int64), + "d": pl.Series([None, 8, None], dtype=pl.Float64), + } + ), + check_column_order=False, + ) + + def test_write_json_to_string() -> None: # Tests if it runs if no arg given df = pl.DataFrame({"a": [1, 2, 3]})