
Update Mark HVGS (#133)
* Separated summary stat from mark_hvgs. Checked type issues.

* Added check for lowess_frac between 0 and 1
Gautam8387 authored Dec 17, 2024
1 parent fd40926 commit 4e93e14
Showing 1 changed file with 66 additions and 23 deletions.
89 changes: 66 additions & 23 deletions scarf/assay.py
@@ -9,17 +9,18 @@
method for feature selection.
"""

from typing import Tuple, List, Generator, Optional, Union
from typing import Generator, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import zarr
from dask.array.core import Array as daskArrayType
from dask.array.core import from_zarr
from scipy.sparse import csr_matrix, vstack
from zarr import hierarchy as z_hierarchy

from .metadata import MetaData
from .utils import show_dask_progress, controlled_compute, logger
from .utils import controlled_compute, logger, show_dask_progress

zarrGroup = z_hierarchy.Group

@@ -279,12 +280,18 @@ def _verify_keys(self, cell_key: str, feat_key: str) -> None:
feat_key: Name of the key (column) from feature attribute table
Returns: None
Note on type checking /GA:
1. ds.cells.get_dtype(cell_key) == bool returns True because numpy's dtype equality coerces the right-hand side, so dtype('bool') compares equal to Python's bool.
2. isinstance(ds.cells.get_dtype(cell_key), bool) returns False because dtype('bool') is a numpy.dtype instance, not the native Python bool type.
3. Reason: isinstance checks the object's class (and its bases), and numpy.dtype is neither bool nor a subclass of it.
"""
if cell_key not in self.cells.columns or self.cells.get_dtype(cell_key) != bool:
if cell_key not in self.cells.columns or self.cells.get_dtype(cell_key) != bool: # noqa: E721
raise ValueError(
f"ERROR: Either {cell_key} does not exist or is not bool type"
)
if feat_key not in self.feats.columns or self.feats.get_dtype(feat_key) != bool:
if feat_key not in self.feats.columns or self.feats.get_dtype(feat_key) != bool: # noqa: E721
raise ValueError(
f"ERROR: Either {feat_key} does not exist or is not bool type"
)
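
For reference, a minimal standalone snippet (plain numpy, not part of assay.py) that reproduces the two behaviours described in the note above:

import numpy as np

dt = np.dtype("bool")            # the kind of object the dtype lookup is assumed to return here

print(dt == bool)                # True: numpy coerces `bool` to dtype('bool') before comparing
print(isinstance(dt, bool))      # False: dt is a numpy.dtype instance, not a Python bool
print(isinstance(dt, np.dtype))  # True: its actual class is numpy.dtype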
@@ -526,9 +533,10 @@ def iter_normed_feature_wise(
columns=feat_idx[chunk],
)
else:
yield controlled_compute(data[:, chunk], self.nthreads).T, feat_idx[
chunk
]
yield (
controlled_compute(data[:, chunk], self.nthreads).T,
feat_idx[chunk],
)

def save_normed_for_query(
self, feat_key: Optional[str], batch_size: int, overwrite: bool = True
@@ -549,6 +557,7 @@ def save_normed_for_query(
None
"""
from joblib import Parallel, delayed

from .writers import create_zarr_obj_array

def write_wrapper(idx: str, v: np.ndarray) -> None:
@@ -563,7 +572,8 @@ def write_wrapper(idx: str, v: np.ndarray) -> None:
None, feat_key, batch_size, "Saving features", False
):
Parallel(n_jobs=self.nthreads)(
delayed(write_wrapper)(inds[i], mat[i]) for i in range(len(inds)) # type: ignore
delayed(write_wrapper)(inds[i], mat[i])
for i in range(len(inds)) # type: ignore
)
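
The call above is the standard joblib Parallel/delayed idiom: build one delayed call per feature and let joblib fan the calls out over self.nthreads workers. A tiny self-contained sketch of the same idiom (illustrative names only, not scarf code):

import numpy as np
from joblib import Parallel, delayed

def row_sum(idx: int, row: np.ndarray) -> float:
    # stand-in for write_wrapper: do some work on one row/feature
    return float(row.sum())

mat = np.arange(12).reshape(4, 3)
results = Parallel(n_jobs=2)(delayed(row_sum)(i, mat[i]) for i in range(len(mat)))
print(results)  # [3.0, 12.0, 21.0, 30.0]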

def save_aggregated_ordering(
@@ -888,6 +898,51 @@ def set_feature_stats(self, cell_key: str) -> None:
self.feats.unmount_location(identifier)
return None

def set_summary_stats(
self, cell_key: str = None, n_bins: int = 200, lowess_frac: float = 0.1
) -> Tuple[str, str]:
"""Calculates summary statistics for the features of the assay using only cells that are marked True by the 'cell_key' parameter.
Args:
cell_key: Name of the key (column) from cell attribute table.
n_bins: Number of bins to divide the data into.
lowess_frac: Between 0 and 1. The fraction of the data used when estimating the fit between mean and
    variance. This is the same as `frac` in statsmodels.nonparametric.smoothers_lowess.lowess.
Returns:
A tuple of two strings.
identifier: The text that will be prepended to column names when summary statistics are loaded onto the feature attributes table.
c_var_col: The name of the column in the feature attribute table that contains the corrected variance values.
"""

def col_renamer(x):
return f"{identifier}_{x}"

if cell_key is None:
cell_key = "I"

# check lowess_frac is between 0 and 1
if not 0 <= lowess_frac <= 1:
raise ValueError("lowess_frac must be between 0 and 1")

self.set_feature_stats(cell_key)
identifier = self._load_stats_loc(cell_key)
c_var_col = f"c_var__{n_bins}__{lowess_frac}"
if col_renamer(c_var_col) in self.feats.columns:
logger.info("Using existing corrected dispersion values")
else:
slots = ["normed_tot", "avg", "nz_mean", "sigmas", "normed_n"]
for i in slots:
i = col_renamer(i)
if i not in self.feats.columns:
raise KeyError(f"ERROR: {i} not found in feature metadata")
c_var = self.feats.remove_trend(
col_renamer("avg"), col_renamer("sigmas"), n_bins, lowess_frac
)
self.feats.insert(c_var_col, c_var, overwrite=True, location=identifier)

return identifier, c_var_col
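
For readers unfamiliar with the lowess step: a rough, self-contained sketch of the kind of mean-variance trend correction this method delegates to self.feats.remove_trend (whose binning internals are not shown in this diff), assuming statsmodels' lowess:

import numpy as np
from statsmodels.nonparametric.smoothers_lowess import lowess

def corrected_variance(avg: np.ndarray, sigmas: np.ndarray, lowess_frac: float = 0.1) -> np.ndarray:
    # Mirror the guard added in this commit
    if not 0 <= lowess_frac <= 1:
        raise ValueError("lowess_frac must be between 0 and 1")
    # Fit the variance-vs-mean trend; return_sorted=False keeps the fit aligned with the input order
    trend = lowess(sigmas, avg, frac=lowess_frac, return_sorted=False)
    # What remains after removing the fitted trend is the "corrected" variance
    return sigmas - trend

With the refactor, callers such as mark_hvgs obtain the corrected-variance column through identifier, c_var_col = self.set_summary_stats(cell_key, n_bins, lowess_frac), as the hunk below shows.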

# maybe we should return plot here? If one wants to modify it. /raz
def mark_hvgs(
self,
@@ -950,21 +1005,9 @@ def mark_hvgs(
def col_renamer(x):
return f"{identifier}_{x}"

self.set_feature_stats(cell_key)
identifier = self._load_stats_loc(cell_key)
c_var_col = f"c_var__{n_bins}__{lowess_frac}"
if col_renamer(c_var_col) in self.feats.columns:
logger.info("Using existing corrected dispersion values")
else:
slots = ["normed_tot", "avg", "nz_mean", "sigmas", "normed_n"]
for i in slots:
i = col_renamer(i)
if i not in self.feats.columns:
raise KeyError(f"ERROR: {i} not found in feature metadata")
c_var = self.feats.remove_trend(
col_renamer("avg"), col_renamer("sigmas"), n_bins, lowess_frac
)
self.feats.insert(c_var_col, c_var, overwrite=True, location=identifier)
logger.info("Calculating summary statistics")
identifier, c_var_col = self.set_summary_stats(cell_key, n_bins, lowess_frac)
logger.info("Calculating HVGs")

if max_mean != np.inf:
max_mean = 2**max_mean
