diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 83c8aceb..331695b1 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,13 @@ Changelog ========= +Unreleased +---------- + +**Other changes:** + +- Improve the performance of ``from_pandas`` in the case of low-cardinality categorical variables. + 3.1.10 - 2023-06-23 ------------------- diff --git a/src/tabmat/constructor.py b/src/tabmat/constructor.py index b782f2da..f8e23c31 100644 --- a/src/tabmat/constructor.py +++ b/src/tabmat/constructor.py @@ -72,6 +72,7 @@ def from_pandas( if object_as_cat and coldata.dtype == object: coldata = coldata.astype("category") if isinstance(coldata.dtype, pd.CategoricalDtype): + cat = CategoricalMatrix(coldata, drop_first=drop_first, dtype=dtype) if len(coldata.cat.categories) < cat_threshold: ( X_dense_F, @@ -79,15 +80,7 @@ def from_pandas( dense_indices, sparse_indices, ) = _split_sparse_and_dense_parts( - pd.get_dummies( - coldata, - prefix=colname, - sparse=True, - drop_first=drop_first, - dtype=np.float64, - ) - .sparse.to_coo() - .tocsc(), + sps.csc_matrix(cat.tocsr(), dtype=dtype), threshold=sparse_threshold, ) matrices.append(X_dense_F) @@ -103,7 +96,6 @@ def from_pandas( indices.append(sparse_indices) else: - cat = CategoricalMatrix(coldata, drop_first=drop_first, dtype=dtype) matrices.append(cat) is_cat.append(True) if cat_position == "expand":