diff --git a/decoupler/utils_anndata.py b/decoupler/utils_anndata.py index c759d6a..7979682 100644 --- a/decoupler/utils_anndata.py +++ b/decoupler/utils_anndata.py @@ -160,7 +160,7 @@ def format_psbulk_inputs(sample_col, groups_col, obs): if groups_col is None: # Filter extra columns in obs - cols = obs.groupby(sample_col, observed=True).nunique().eq(1).all(0) + cols = obs.groupby(sample_col, observed=True).nunique(dropna=False).eq(1).all(0) cols = np.hstack([sample_col, cols[cols].index]) obs = obs.loc[:, cols] @@ -179,7 +179,7 @@ def format_psbulk_inputs(sample_col, groups_col, obs): groups_col = joined_cols # Filter extra columns in obs - cols = obs.groupby([sample_col, groups_col], observed=True).nunique().eq(1).all(0) + cols = obs.groupby([sample_col, groups_col], observed=True).nunique(dropna=False).eq(1).all(0) cols = np.hstack([sample_col, groups_col, cols[cols].index]) obs = obs.loc[:, cols]