Skip to content

Commit

Permalink
feat(rust, python): Implement schema, schema_override for `pl.read_json` with array-like input (#11492)
Browse files Browse the repository at this point in the history
  • Loading branch information
andysham authored Oct 9, 2023
1 parent 5eb499c commit 0aa32ae
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 24 deletions.
55 changes: 31 additions & 24 deletions crates/polars-io/src/json/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,13 @@ where
let mut_schema = Arc::make_mut(&mut schema);
overwrite_schema(mut_schema, overwrite)?;
}

DataType::Struct(schema.iter_fields().collect()).to_arrow()
} else {
// infer
if let BorrowedValue::Array(values) = &json_value {
polars_ensure!(self.schema_overwrite.is_none() && self.schema.is_none(), ComputeError: "schema arguments not yet supported for Array json");

let inner_dtype = if let BorrowedValue::Array(values) = &json_value {
// struct types may have missing fields so find supertype
let dtype = values
values
.iter()
.take(self.infer_schema_len.unwrap_or(usize::MAX))
.map(|value| {
Expand All @@ -235,32 +234,40 @@ where
let r = r?;
try_get_supertype(&l, &r)
})
.unwrap()?;
let dtype = DataType::List(Box::new(dtype));
dtype.to_arrow()
.unwrap()?
.to_arrow()
} else {
let dtype = infer(&json_value)?;
if let Some(overwrite) = self.schema_overwrite {
let ArrowDataType::Struct(fields) = dtype else {
polars_bail!(ComputeError: "can only deserialize json objects")
};
infer(&json_value)?
};

let mut schema = Schema::from_iter(fields.iter());
overwrite_schema(&mut schema, overwrite)?;
if let Some(overwrite) = self.schema_overwrite {
let ArrowDataType::Struct(fields) = inner_dtype else {
polars_bail!(ComputeError: "can only deserialize json objects")
};

DataType::Struct(
schema
.into_iter()
.map(|(name, dt)| Field::new(&name, dt))
.collect(),
)
.to_arrow()
} else {
dtype
}
let mut schema = Schema::from_iter(fields.iter());
overwrite_schema(&mut schema, overwrite)?;

DataType::Struct(
schema
.into_iter()
.map(|(name, dt)| Field::new(&name, dt))
.collect(),
)
.to_arrow()
} else {
inner_dtype
}
};

let dtype = if let BorrowedValue::Array(_) = &json_value {
ArrowDataType::LargeList(Box::new(arrow::datatypes::Field::new(
"item", dtype, true,
)))
} else {
dtype
};

let arr = polars_json::json::deserialize(&json_value, dtype)?;
let arr = arr.as_any().downcast_ref::<StructArray>().ok_or_else(
|| polars_err!(ComputeError: "can only deserialize json objects"),
Expand Down
49 changes: 49 additions & 0 deletions py-polars/tests/unit/io/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,55 @@ def test_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None:
assert_frame_equal(df, out, categorical_as_str=True)


def test_to_from_buffer_arraywise_schema() -> None:
    """Read an array-of-objects JSON buffer with an explicit ``schema``.

    The schema handed to ``pl.read_json`` names "b", which appears in the
    input records, and "e", which appears in none of them. The expected
    frame holds exactly those two columns, with "e" entirely null.
    """
    # The top-level JSON value is an array; records carry differing key sets.
    payload = """
[
{"a": 5, "b": "foo", "c": null},
{"a": 11.4, "b": null, "c": true, "d": 8},
{"a": -25.8, "b": "bar", "c": false}
]"""
    requested_schema = {"b": pl.Utf8, "e": pl.Int16}
    expected = pl.DataFrame(
        {
            "b": pl.Series(["foo", None, "bar"], dtype=pl.Utf8),
            "e": pl.Series([None, None, None], dtype=pl.Int16),
        }
    )

    result = pl.read_json(io.StringIO(payload), schema=requested_schema)

    assert_frame_equal(result, expected)


def test_to_from_buffer_arraywise_schema_override() -> None:
    """Read an array-of-objects JSON buffer with ``schema_overrides``.

    Only "c" and "d" are overridden (to Int64 and Float64). The expected
    frame still contains "a" and "b", as Float64 and Utf8 respectively.
    Column order is deliberately not part of the comparison.
    """
    # The top-level JSON value is an array; "d" is present in one record only.
    payload = """
[
{"a": 5, "b": "foo", "c": null},
{"a": 11.4, "b": null, "c": true, "d": 8},
{"a": -25.8, "b": "bar", "c": false}
]"""
    overrides = {"c": pl.Int64, "d": pl.Float64}
    expected = pl.DataFrame(
        {
            "a": pl.Series([5, 11.4, -25.8], dtype=pl.Float64),
            "b": pl.Series(["foo", None, "bar"], dtype=pl.Utf8),
            # The JSON booleans in "c" are expected as 1/0 under the Int64 override.
            "c": pl.Series([None, 1, 0], dtype=pl.Int64),
            "d": pl.Series([None, 8, None], dtype=pl.Float64),
        }
    )

    result = pl.read_json(io.StringIO(payload), schema_overrides=overrides)

    assert_frame_equal(result, expected, check_column_order=False)


def test_write_json_to_string() -> None:
# Tests if it runs if no arg given
df = pl.DataFrame({"a": [1, 2, 3]})
Expand Down

0 comments on commit 0aa32ae

Please sign in to comment.