Skip to content

Commit

Permalink
feat: More dtypes supports cast to list
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Sep 11, 2023
1 parent 0f37edf commit 7557911
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ xxhash-rust = { version = "0.8.6", features = ["xxh3"] }
[workspace.dependencies.arrow]
package = "arrow2"
git = "https://github.com/jorgecarleitao/arrow2"
rev = "ba6a882bc1542b0b899774b696ebea77482b5c31"
rev = "7c93e358fc400bf3c0c0219c22eefc6b38fc2d12"
# branch = ""
# version = "0.17.4"
default-features = false
Expand Down
27 changes: 20 additions & 7 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import annotations

from datetime import date, datetime, time
from typing import TYPE_CHECKING, Any

import pandas as pd
import pytest

import polars as pl
from polars.testing import assert_series_equal

if TYPE_CHECKING:
from polars import PolarsDataType


def test_dtype() -> None:
# inferred
Expand Down Expand Up @@ -460,13 +464,22 @@ def test_list_recursive_categorical_cast() -> None:
assert s.to_list() == values


def test_non_nested_cast_to_list() -> None:
df = pl.DataFrame({"a": [1, 2, 3]})

df = df.with_columns([pl.col("a").cast(pl.List(pl.Int64))])

expected = pl.Series("a", [[1], [2], [3]])
assert_series_equal(df.to_series(), expected)
@pytest.mark.parametrize(
("data", "expected_data", "dtype"),
[
([1, 2], [[1], [2]], pl.Int64),
([1.0, 2.0], [[1.0], [2.0]], pl.Float64),
(["x", "y"], [["x"], ["y"]], pl.Utf8),
([True, False], [[True], [False]], pl.Boolean),
],
)
def test_non_nested_cast_to_list(
data: list[Any], expected_data: list[Any], dtype: PolarsDataType
) -> None:
s = pl.Series(data, dtype=dtype)
casted_s = s.cast(pl.List(dtype))
expected = pl.Series(expected_data, dtype=pl.List(dtype))
assert_series_equal(casted_s, expected)


def test_list_new_from_index_logical() -> None:
Expand Down

0 comments on commit 7557911

Please sign in to comment.