Skip to content

Commit

Permalink
fix(python): Correctly parse special float values in from_repr (#20351
Browse files Browse the repository at this point in the history
)
  • Loading branch information
alexander-beedie authored Dec 19, 2024
1 parent c56ecdf commit 5fc9a19
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 5 deletions.
9 changes: 7 additions & 2 deletions py-polars/polars/_utils/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ def _cast_repr_strings_with_schema(
msg = f"DataFrame should contain only String repr data; found {tp!r}"
raise TypeError(msg)

special_floats = {"-inf", "+inf", "inf", "nan"}

# duration string scaling
ns_sec = 1_000_000_000
duration_scaling = {
Expand Down Expand Up @@ -366,8 +368,11 @@ def str_duration_(td: str | None) -> int | None:
integer_part = F.col(c).str.replace(r"^(.*)\D(\d*)$", "$1")
fractional_part = F.col(c).str.replace(r"^(.*)\D(\d*)$", "$2")
cast_cols[c] = (
# check for empty string and/or integer format
pl.when(F.col(c).str.contains(r"^[+-]?\d*$"))
# check for empty string, special floats, or integer format
pl.when(
F.col(c).str.contains(r"^[+-]?\d*$")
| F.col(c).str.to_lowercase().is_in(special_floats)
)
.then(pl.when(F.col(c).str.len_bytes() > 0).then(F.col(c)))
# check for scientific notation
.when(F.col(c).str.contains("[eE]"))
Expand Down
19 changes: 16 additions & 3 deletions py-polars/polars/convert/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,12 @@ def _from_dataframe_repr(m: re.Match[str]) -> DataFrame:
buf = io.BytesIO()
df.write_csv(file=buf)
buf.seek(0)
df = read_csv(buf, new_columns=df.columns, try_parse_dates=True)
df = read_csv(
buf,
new_columns=df.columns,
try_parse_dates=True,
infer_schema_length=None,
)
return df
elif schema and not data:
return df.cast(schema) # type: ignore[arg-type]
Expand All @@ -786,8 +791,16 @@ def _from_series_repr(m: re.Match[str]) -> Series:
]
if string_values == ["[", "]"]:
string_values = []
elif string_values and string_values[0].lstrip("#> ") == "[":
string_values = string_values[1:]
else:
start: int | None = None
end: int | None = None
for idx, v in enumerate(string_values):
if start is None and v.lstrip("#> ") == "[":
start = idx
if v.lstrip("#> ") == "]":
end = idx
if start is not None and end is not None:
string_values = string_values[start + 1 : end]

values = string_values[:length] if length > 0 else string_values
values = [(None if v == "null" else v) for v in values if v not in ("…", "...")]
Expand Down
27 changes: 27 additions & 0 deletions py-polars/tests/unit/interop/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,33 @@ def test_series_from_repr() -> None:
)
assert_series_equal(s, pl.Series("flt", [], dtype=pl.Float32))

s = cast(
pl.Series,
pl.from_repr(
"""
Series: 'flt' [f64]
[
null
+inf
-inf
inf
0.0
NaN
]
>>> print("stuff")
"""
),
)
inf, nan = float("inf"), float("nan")
assert_series_equal(
s,
pl.Series(
name="flt",
dtype=pl.Float64,
values=[None, inf, -inf, inf, 0.0, nan],
),
)


def test_dataframe_from_repr_custom_separators() -> None:
# repr created with custom digit-grouping
Expand Down

0 comments on commit 5fc9a19

Please sign in to comment.