Adding option for absolute val in DE test, deal with small sets (#114)
* Adding option for absolute val in DE test, deal with small sets

* add test for small de group
euxhenh authored Dec 15, 2023
1 parent 1706dd5 commit a80a4b3
Showing 5 changed files with 80 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/grinch/base.py
@@ -48,7 +48,7 @@ class StorageMixin:
storage : Dict[str, Any]
A dict mapping a key to a representation.
"""
__columns__ = ['obs', 'var', 'obsm', 'varm', 'obsp', 'varp', 'uns']
__columns__ = ['obs', 'var', 'obsm', 'varm', 'obsp', 'varp', 'uns', 'layers']

@property
def prefix(self) -> str:
@@ -114,7 +114,7 @@ def __insert_prefix_after_col(key: str, prefix: str):
Example
-------
If key='uns.ttest' and prefix='group-0.', then this returns
'uns.group-0.ttest'. Note the dot '.' after 'group-0'.
'uns.group-0.ttest'. Note the dot '.' in prefix.
"""
first_key, store_keys = key.split('.', maxsplit=1)
return f'{first_key}.{prefix}{store_keys}'
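For reference, the prefix insertion above can be reproduced as a standalone sketch (illustrative only, not the private method itself):

def insert_prefix_after_col(key: str, prefix: str) -> str:
    # Split off the leading column token ('uns', 'obs', ...), then
    # rejoin with the prefix placed directly after it.
    first_key, store_keys = key.split('.', maxsplit=1)
    return f'{first_key}.{prefix}{store_keys}'

insert_prefix_after_col('uns.ttest', 'group-0.')  # -> 'uns.group-0.ttest'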
11 changes: 11 additions & 0 deletions src/grinch/cond_filter.py
@@ -40,6 +40,9 @@ class Filter(BaseModel, Generic[T]):
A percent fraction between 0 and 1. Will round up to the nearest
item.
absolute : bool
If True, will consider the absolute value of `key`.
Examples
--------
>>> f1 = Filter(gt=3)
@@ -63,6 +66,10 @@ class Filter(BaseModel, Generic[T]):
>>> r = f & f3 # Can also stack StackedFilter and Filter
>>> r([3, 4, 5, 6, 7], as_mask=True)
array([False, False, False, False, False])
>>> fabs = Filter(ge=2, absolute=True)
>>> fabs([-5, -6, -1, 0, 1, 2], as_mask=False)
array([0, 1, 5])
"""
__conditions__ = ['ge', 'le', 'gt', 'lt',
'equal', 'not_equal',
@@ -92,6 +99,8 @@ class Filter(BaseModel, Generic[T]):
top_ratio: PercentFraction | None = None # top fraction of items
bot_ratio: PercentFraction | None = None # bottom fraction of items

absolute: bool = False

@model_validator(mode='before')
def at_most_one_not_None(cls, data):
"""Ensure that at most one condition is set. If no conditions are
@@ -246,6 +255,8 @@ def __call__(self, obj, as_mask=True):
obj = self._get_member(obj, self.key)

arr: np.ndarray[T, Any] = column_or_1d(obj)
if self.absolute:
arr = np.abs(arr)

if any_not_None(self.ge, self.gt, self.le, self.lt):
return self._take_cutoff(arr, as_mask)
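The new `absolute` path can be summarized with a standalone sketch (the helper name is illustrative, not the library API); it reproduces the `fabs` doctest above:

import numpy as np

def apply_ge_filter(values, ge, absolute=False, as_mask=True):
    # Optionally take absolute values first, then apply the >= cutoff.
    arr = np.asarray(values)
    if absolute:
        arr = np.abs(arr)
    mask = arr >= ge
    return mask if as_mask else np.flatnonzero(mask)

apply_ge_filter([-5, -6, -1, 0, 1, 2], ge=2, absolute=True, as_mask=False)
# -> array([0, 1, 5])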
40 changes: 26 additions & 14 deletions src/grinch/processors/base_processor.py
@@ -77,36 +77,39 @@ class Config(BaseConfigurable.Config):
----------
attrs_key : str, default=None
The key to store processor attributes in (post fit). Curly
brackets will be formatted. By default use `self.write_key`
followed by an underscore.
brackets will be formatted. E.g., if a Predictor has a key
called `labels_key`, one can set `attrs_key` to
`uns.{labels_key}_`. The value (tail) of `labels_key` will
automatically be parsed into `attrs_key`.
kwargs : dict, default={}
Any Processor parameters that should be passed to the inner
processor object (or related methods).
Any additional processor parameters should be specified under a
`kwargs` dict. Important parameters may be explicitly set in
the Config for convenience.
Class attributes
----------------
__extra_processor_params__ : List[str]
Holds kwargs that will be passed to the underlying
processor/processor function, but are not marked as
ProcessorParams and are also not passed via kwargs.
Kwargs used by the processor, but are not ProcessorParam's or
have a different name than the one specified in the Config.
E.g., `random_state` should be here, as we use `seed`. The only
use for this list is to remove such keys from `kwargs` if any
is present.
"""
if TYPE_CHECKING:
create: Callable[..., 'BaseProcessor']

attrs_key: WriteKey | None = None
kwargs: Dict[str, ProcessorParam] = Field(default_factory=dict) # Processor kwargs
kwargs: Dict[str, ProcessorParam] = Field(default_factory=dict)

# Kwargs used by the processor, but are not ProcessorParam's
__extra_processor_params__: List[str] = []

def model_post_init(self, __context):
"""Safely formats attrs key using any field that is a str."""
super().model_post_init(__context)

if self.attrs_key is None:
return

# Use only the tail of other keys to format attrs_key.
field_dict = {
k: v.rsplit('.', 1)[-1] for k, v in self.model_dump().items()
if isinstance(v, str)
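To make the tail-only formatting concrete, here is an illustrative snippet (the field names are hypothetical):

# Suppose a Config declares labels_key='obs.kmeans_labels' and
# attrs_key='uns.{labels_key}_'. Only the tail after the last '.' is
# used to format attrs_key:
field_dict = {'labels_key': 'obs.kmeans_labels'.rsplit('.', 1)[-1]}
print(field_dict)                                 # {'labels_key': 'kmeans_labels'}
print('uns.{labels_key}_'.format(**field_dict))   # uns.kmeans_labels_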
@@ -122,8 +125,8 @@ def remove_explicit_args(cls, val):
cls.__extra_processor_params__):
if val.pop(explicit_key, None) is not None:
logger.warning(
f"Popping '{explicit_key}' from kwargs. "
"This key has been set explicitly."
f"Popping '{explicit_key=}' from kwargs. "
"This key has been set in the Config."
)
return val

@@ -183,6 +186,16 @@ def __call__(
**kwargs,
):
"""Calls the processor with adata.
Order of operations is as follows:
1) Clear self.storage
2) Run this __call__
a) Index into obs or var if any.
b) Run preprocessing
c) Run processing
d) Run postprocessing
e) Store processor attributes
3) Return storage if kwargs['return_storage'] or write otherwise
"""
if all_not_None(obs_indices, var_indices):
adata = adata[obs_indices, var_indices]
@@ -195,7 +208,6 @@ def __call__(
self._pre_process(adata)
self._process(adata)
self._post_process(adata)

self.store_attrs()

def _pre_process(self, adata: AnnData) -> None:
18 changes: 17 additions & 1 deletion src/grinch/processors/de.py
@@ -11,7 +11,7 @@
import pandas as pd
from anndata import AnnData
from diptest import diptest
from pydantic import Field, PositiveFloat, field_validator
from pydantic import Field, PositiveFloat, PositiveInt, field_validator
from scipy.stats import ks_2samp, ranksums
from sklearn.utils import (
check_array,
@@ -80,6 +80,10 @@ class Config(BaseProcessor.Config):
These will be replaced with appropriate values (1 for
p-values).
min_points_per_group : int, default=None
If not None, will skip computing values for groups with fewer
than this many points.
control_group : str, default=None
The label to use in a 'one_vs_one' test type. Must be present
in the array specified by `group_key`.
@@ -105,6 +109,7 @@ class Config(BaseProcessor.Config):
base: PositiveFloat | Literal['e'] | None = Field('e')
correction: str = 'fdr_bh'
replace_nan: bool = True
min_points_per_group: PositiveInt | None = None

# If any of the following is not None, will perform a one_vs_one
# test. E.g., if control samples are given in `control_key`, then for
@@ -272,6 +277,9 @@ def _single_test(
"""Perform a single t-Test.
"""
n1, m1, v1 = pmv.compute([label], ddof=1) # Stats for label
if self.cfg.min_points_per_group is not None:
if n1 < self.cfg.min_points_per_group:
return pd.DataFrame() # Empty
# If no control group, will compute from all but label (one-vs-all)
n2, m2, v2 = control_stats or pmv.compute([label], ddof=1, exclude=True)
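For context, the (n, m, v) triplets returned by `pmv.compute` appear to be group size, mean, and variance; a two-sample t statistic can be formed directly from them. A minimal sketch, assuming Welch's unequal-variance form (the exact variant used by the processor is not shown in this hunk):

import numpy as np
from scipy import stats

def t_from_stats(n1, m1, v1, n2, m2, v2):
    # Welch's t statistic and two-sided p-value from per-group sizes,
    # means, and variances.
    se2 = v1 / n1 + v2 / n2
    t = (m1 - m2) / np.sqrt(se2)
    df = se2 ** 2 / ((v1 / n1) ** 2 / (n1 - 1) + (v2 / n2) ** 2 / (n2 - 1))
    pval = 2 * stats.t.sf(np.abs(t), df)
    return t, pval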

@@ -325,6 +333,10 @@ def _test(

def _single_test(self, pmv: PartMeanVar, label, *, x, y, m2=None) -> pd.DataFrame:
"""Perform a single rank sum test."""
if self.cfg.min_points_per_group is not None:
if x.shape[0] < self.cfg.min_points_per_group:
return pd.DataFrame()

statistic, pvals = ranksums(x, y, alternative=self.cfg.alternative)
pvals, qvals = self.get_pqvals(pvals)
m1 = pmv.compute([label], ddof=1)[1] # take label
@@ -383,6 +395,10 @@ def _test(

def _single_test(self, pmv: PartMeanVar, label, *, x, y, m2) -> pd.DataFrame:
"""Perform a single ks test"""
if self.cfg.min_points_per_group is not None:
if x.shape[0] < self.cfg.min_points_per_group:
return pd.DataFrame()

part_ks_2samp = partial(
ks_2samp,
alternative=self.cfg.alternative,
24 changes: 24 additions & 0 deletions tests/test_de.py
@@ -79,6 +79,30 @@ def test_tests(X, test, key):
assert log2fc[4] < 2


@pytest.mark.parametrize("test,key", tests)
def test_small_group(test, key):
X = np.array([
[1, 2, 3, 100, 150],
[60, 46, 34, 0, 0],
[50, 49, 34, 0, 0],
[60, 46, 38, 0, 0],
])
cfg = OmegaConf.create(
{
"_target_": f"src.grinch.{test}.Config",
"min_points_per_group": 2,
"group_key": "obs.label",
}
)
cfg = instantiate(cfg)
test = cfg.create()
adata = AnnData(X)
adata.obs['label'] = [0, 1, 1, 1]
test(adata)
# empty dataframe
assert len(adata.uns[key]['label-0']) == 0


X = np.array([
[1, 5, 4, 45, 62],
[5, 2, 4, 44, 75],
