From a80a4b3cd89ad3dcb1c6ed40fe5b78c9bd92d8c8 Mon Sep 17 00:00:00 2001 From: Euxhen Hasanaj Date: Fri, 15 Dec 2023 01:57:21 -0500 Subject: [PATCH] Adding option for absolute val in DE test, deal with small sets (#114) * Adding option for absolute val in DE test, deal with small sets * add test for small de group --- src/grinch/base.py | 4 +-- src/grinch/cond_filter.py | 11 +++++++ src/grinch/processors/base_processor.py | 40 ++++++++++++++++--------- src/grinch/processors/de.py | 18 ++++++++++- tests/test_de.py | 24 +++++++++++++++ 5 files changed, 80 insertions(+), 17 deletions(-) diff --git a/src/grinch/base.py b/src/grinch/base.py index 8b798cf..2177a6d 100644 --- a/src/grinch/base.py +++ b/src/grinch/base.py @@ -48,7 +48,7 @@ class StorageMixin: storage : Dict[str, Any] A dict mapping a key to a representation. """ - __columns__ = ['obs', 'var', 'obsm', 'varm', 'obsp', 'varp', 'uns'] + __columns__ = ['obs', 'var', 'obsm', 'varm', 'obsp', 'varp', 'uns', 'layers'] @property def prefix(self) -> str: @@ -114,7 +114,7 @@ def __insert_prefix_after_col(key: str, prefix: str): Example ------- If key='uns.ttest' and prefix='group-0.', then this returns - 'uns.group-0.ttest'. Note the dot '.' after 'group-0'. + 'uns.group-0.ttest'. Note the dot '.' in prefix. """ first_key, store_keys = key.split('.', maxsplit=1) return f'{first_key}.{prefix}{store_keys}' diff --git a/src/grinch/cond_filter.py b/src/grinch/cond_filter.py index c1def90..0639be1 100644 --- a/src/grinch/cond_filter.py +++ b/src/grinch/cond_filter.py @@ -40,6 +40,9 @@ class Filter(BaseModel, Generic[T]): A percent fraction betwen 0 and 1. Will round up to the nearest item. + absolute : bool + If True, will consider the absolute value of `key`. + Examples -------- >>> f1 = Filter(gt=3) @@ -63,6 +66,10 @@ class Filter(BaseModel, Generic[T]): >>> r = f & f3 # Can also stack StackedFilter and Filter >>> r([3, 4, 5, 6, 7], as_mask=True) array([False, False, False, False, False]) + + >>> fabs = Filter(ge=2, absolute=True) + >>> fabs([-5, -6, -1, 0, 1, 2], as_mask=False) + array([0, 1, 5]) """ __conditions__ = ['ge', 'le', 'gt', 'lt', 'equal', 'not_equal', @@ -92,6 +99,8 @@ class Filter(BaseModel, Generic[T]): top_ratio: PercentFraction | None = None # top fraction of items bot_ratio: PercentFraction | None = None # bottom fraction of items + absolute: bool = False + @model_validator(mode='before') def at_most_one_not_None(cls, data): """Ensure that at most one condition is set. If no conditions are @@ -246,6 +255,8 @@ def __call__(self, obj, as_mask=True): obj = self._get_member(obj, self.key) arr: np.ndarray[T, Any] = column_or_1d(obj) + if self.absolute: + arr = np.abs(arr) if any_not_None(self.ge, self.gt, self.le, self.lt): return self._take_cutoff(arr, as_mask) diff --git a/src/grinch/processors/base_processor.py b/src/grinch/processors/base_processor.py index 77f4c9a..41be599 100644 --- a/src/grinch/processors/base_processor.py +++ b/src/grinch/processors/base_processor.py @@ -77,36 +77,39 @@ class Config(BaseConfigurable.Config): ---------- attrs_key : str, default=None The key to store processors attributes in (post fit). Curly - brackets will be formatted. By default use `self.write_key` - followed by an underscore. + brackets will be formatted. E.g., if a Predictor has a key + called `labels_key`, one can set `attrs_key` to + `uns.{labels_key}_`. The value (tail) of `labels_key` will + automatically be parsed into `attrs_key`. kwargs : dict, default={} - Any Processor parameters that should be passed to the inner - processor object (or related methods). + Any additional processor parameters should be specified under a + `kwargs` dict. Important parameters may be explicitly set in + the Config for convenience. Class attributes ---------------- __extra_processor_params__ : List[str] - Holds kwargs that will be passed to the underlying - processor/processor function, but are not marked as - ProcessorParams and are also not passed via kwargs. + Kwargs used by the processor, but are not ProcessorParam's or + have a different name than the one specified in the Config. + E.g., `random_state` should be here, as we use `seed`. The only + use for this list is to remove such keys from `kwargs` if any + is present. """ if TYPE_CHECKING: create: Callable[..., 'BaseProcessor'] attrs_key: WriteKey | None = None - kwargs: Dict[str, ProcessorParam] = Field(default_factory=dict) # Processor kwargs + kwargs: Dict[str, ProcessorParam] = Field(default_factory=dict) - # Kwargs used by the processor, but are not ProcessorParam's __extra_processor_params__: List[str] = [] def model_post_init(self, __context): """Safely formats attrs key using any field that is a str.""" super().model_post_init(__context) - if self.attrs_key is None: return - + # Use only the tail of other keys to format attrs_key. field_dict = { k: v.rsplit('.', 1)[-1] for k, v in self.model_dump().items() if isinstance(v, str) @@ -122,8 +125,8 @@ def remove_explicit_args(cls, val): cls.__extra_processor_params__): if val.pop(explicit_key, None) is not None: logger.warning( - f"Popping '{explicit_key}' from kwargs. " - "This key has been set explicitly." + f"Popping '{explicit_key=}' from kwargs. " + "This key has been set in the Config." ) return val @@ -183,6 +186,16 @@ def __call__( **kwargs, ): """Calls the processor with adata. + + Order of operations is as follows: + 1) Clear self.storage + 2) Run this __call__ + a) Index into obs or var if any. + b) Run preprocessing + c) Run processing + d) Run postprocessing + e) Store processor attributes + 3) Return storage if kwargs['return_storage'] or write otherwise """ if all_not_None(obs_indices, var_indices): adata = adata[obs_indices, var_indices] @@ -195,7 +208,6 @@ def __call__( self._pre_process(adata) self._process(adata) self._post_process(adata) - self.store_attrs() def _pre_process(self, adata: AnnData) -> None: diff --git a/src/grinch/processors/de.py b/src/grinch/processors/de.py index 33bfcbe..f1e24d3 100644 --- a/src/grinch/processors/de.py +++ b/src/grinch/processors/de.py @@ -11,7 +11,7 @@ import pandas as pd from anndata import AnnData from diptest import diptest -from pydantic import Field, PositiveFloat, field_validator +from pydantic import Field, PositiveFloat, PositiveInt, field_validator from scipy.stats import ks_2samp, ranksums from sklearn.utils import ( check_array, @@ -80,6 +80,10 @@ class Config(BaseProcessor.Config): These will be replaced with appropriate values (1 for p-values). + min_points_per_group : int, default=None + If not None, will skip computing values for groups with fewer + than this many points. + control_group : str, default=None The label to use in a 'one_vs_one' test type. Must be present in the array specified by `group_key`. @@ -105,6 +109,7 @@ class Config(BaseProcessor.Config): base: PositiveFloat | Literal['e'] | None = Field('e') correction: str = 'fdr_bh' replace_nan: bool = True + min_points_per_group: PositiveInt | None = None # If any of the following is not None, will perform a one_vs_one # test. E.g., if control samples are given in `control_key`, then for @@ -272,6 +277,9 @@ def _single_test( """Perform a single t-Test. """ n1, m1, v1 = pmv.compute([label], ddof=1) # Stats for label + if self.cfg.min_points_per_group is not None: + if n1 < self.cfg.min_points_per_group: + return pd.DataFrame() # Empty # If no control group, will compute from all but label (one-vs-all) n2, m2, v2 = control_stats or pmv.compute([label], ddof=1, exclude=True) @@ -325,6 +333,10 @@ def _test( def _single_test(self, pmv: PartMeanVar, label, *, x, y, m2=None) -> pd.DataFrame: """Perform a single rank sum test.""" + if self.cfg.min_points_per_group is not None: + if x.shape[0] < self.cfg.min_points_per_group: + return pd.DataFrame() + statistic, pvals = ranksums(x, y, alternative=self.cfg.alternative) pvals, qvals = self.get_pqvals(pvals) m1 = pmv.compute([label], ddof=1)[1] # take label @@ -383,6 +395,10 @@ def _test( def _single_test(self, pmv: PartMeanVar, label, *, x, y, m2) -> pd.DataFrame: """Perform a single ks test""" + if self.cfg.min_points_per_group is not None: + if x.shape[0] < self.cfg.min_points_per_group: + return pd.DataFrame() + part_ks_2samp = partial( ks_2samp, alternative=self.cfg.alternative, diff --git a/tests/test_de.py b/tests/test_de.py index db57339..2f1a699 100644 --- a/tests/test_de.py +++ b/tests/test_de.py @@ -79,6 +79,30 @@ def test_tests(X, test, key): assert log2fc[4] < 2 +@pytest.mark.parametrize("test,key", tests) +def test_small_group(test, key): + X = np.array([ + [1, 2, 3, 100, 150], + [60, 46, 34, 0, 0], + [50, 49, 34, 0, 0], + [60, 46, 38, 0, 0], + ]) + cfg = OmegaConf.create( + { + "_target_": f"src.grinch.{test}.Config", + "min_points_per_group": 2, + "group_key": "obs.label", + } + ) + cfg = instantiate(cfg) + test = cfg.create() + adata = AnnData(X) + adata.obs['label'] = [0, 1, 1, 1] + test(adata) + # empty dataframe + assert len(adata.uns[key]['label-0']) == 0 + + X = np.array([ [1, 5, 4, 45, 62], [5, 2, 4, 44, 75],