Skip to content

Commit

Permalink
Merge fix from #387
Browse files Browse the repository at this point in the history
  • Loading branch information
stanmart committed Sep 13, 2024
1 parent 34c4789 commit 1d58498
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/tabmat/categorical_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,8 @@ def _extract_codes_and_categories(
indices = pd_vec.cat.codes.to_numpy()
elif namespace.__name__ == "polars":
pl_vec: pl.Series = nw.to_native(cat_vec)
categories = pl_vec.cat.get_categories().to_numpy()
indices = pl_vec.to_physical().fill_null(-1).to_numpy()
categories = pl_vec.cat.to_local().cat.get_categories().to_numpy()
indices = pl_vec.cat.to_local().to_physical().fill_null(-1).to_numpy()

return indices, categories, namespace

Expand Down
2 changes: 1 addition & 1 deletion tests/test_categorical_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,5 +271,5 @@ def test_polars_non_contiguous_codes():
_ = pl.Series(["beagle", "poodle", "labrador"], dtype=pl.Categorical)
cat_series = pl.Series(str_series, dtype=pl.Categorical)

indices, categories = _extract_codes_and_categories(cat_series)
indices, categories, _ = _extract_codes_and_categories(cat_series)
np.testing.assert_array_equal(str_series, categories[indices].tolist())

0 comments on commit 1d58498

Please sign in to comment.