Optimise PC regression #408

Open
wants to merge 19 commits into main
Changes from 11 commits
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
@@ -60,7 +60,7 @@ jobs:
- name: Test with pytest
if: ${{ matrix.os != 'macos-latest'}}
run: |
pytest --cov=scib --cov-report=xml -vv --ignore=tests/integration/ --ignore=tests/metrics/rpy2 -vv
pytest --cov=scib --cov-report=xml -vv --ignore=tests/integration/ --ignore=tests/metrics/rpy2 -vv --durations 0 --durations-min=1.0
mv coverage.xml "$(echo 'coverage_metrics_${{ matrix.os }}_${{ matrix.python }}.xml' | sed 's/[^a-z0-9\.\/]/_/g')"

- name: Upload coverage to GitHub Actions
@@ -98,7 +98,7 @@ jobs:

- name: Test with pytest
run: |
pytest --cov=scib --cov-report=xml -vv --tb=native -k rpy2
pytest --cov=scib --cov-report=xml -vv --tb=native -k rpy2 --durations 0 --durations-min=1.0
mv coverage.xml "$(echo 'coverage_rpy2_${{ matrix.os }}_${{ matrix.python }}.xml' | sed 's/[^a-z0-9\.\/]/_/g')"

- name: Upload coverage to GitHub Actions
@@ -129,7 +129,7 @@ jobs:

- name: Test with pytest
run: |
pytest --cov=scib --cov-report=xml -vv --tb=native -k integration
pytest --cov=scib --cov-report=xml -vv --tb=native -k integration --durations 0 --durations-min=1.0
mv coverage.xml "$(echo 'coverage_integration_${{ matrix.os }}_${{ matrix.python }}.xml' | sed 's/[^a-z0-9\.\/]/_/g')"

- name: Upload coverage to GitHub Actions
1 change: 1 addition & 0 deletions pyproject.toml
@@ -9,3 +9,4 @@ build-backend = "setuptools.build_meta"
log_cli = 'True'
log_cli_level = 'INFO'
addopts = '-p no:warnings'
durations = 0
114 changes: 95 additions & 19 deletions scib/metrics/pcr.py
@@ -1,14 +1,25 @@
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import pandas as pd
import scanpy as sc
from scipy import sparse
from sklearn.linear_model import LinearRegression
from tqdm import tqdm

from ..utils import check_adata, check_batch


def pcr_comparison(
adata_pre, adata_post, covariate, embed=None, n_comps=50, scale=True, verbose=False
adata_pre,
adata_post,
covariate,
embed=None,
n_comps=50,
linreg_method="sklearn",
recompute_pca=False,
scale=True,
verbose=False,
n_threads=1,
):
"""Principal component regression score

@@ -25,6 +36,8 @@ def pcr_comparison(
If None, use the full expression matrix (``adata_post.X``), otherwise use the embedding
provided in ``adata_post.obsm[embed]``.
:param n_comps: Number of principal components to compute
:param recompute_pca: Whether to recompute PCA with default settings
:param linreg_method: Method for linear regression, either 'sklearn' or 'numpy'
:param scale: If True, scale score between 0 and 1 (default)
:param verbose:
:return:
@@ -52,17 +65,21 @@ def pcr_comparison(
pcr_before = pcr(
adata_pre,
covariate=covariate,
recompute_pca=True,
recompute_pca=recompute_pca,
n_comps=n_comps,
linreg_method=linreg_method,
n_threads=n_threads,
verbose=verbose,
)

pcr_after = pcr(
adata_post,
covariate=covariate,
embed=embed,
recompute_pca=True,
recompute_pca=recompute_pca,
n_comps=n_comps,
linreg_method=linreg_method,
n_threads=n_threads,
verbose=verbose,
)

@@ -79,7 +96,16 @@ def pc_regression(
return pcr_after - pcr_before


def pcr(adata, covariate, embed=None, n_comps=50, recompute_pca=True, verbose=False):
def pcr(
adata,
covariate,
embed=None,
n_comps=50,
recompute_pca=False,
linreg_method="sklearn",
verbose=False,
n_threads=1,
):
"""Principal component regression for anndata object

Wraps :func:`~scib.metrics.pc_regression` while checking whether to:
@@ -127,25 +153,48 @@ def pcr(adata, covariate, embed=None, n_comps=50, recompute_pca=True, verbose=Fa
assert embed in adata.obsm
if verbose:
print(f"Compute PCR on embedding n_comps: {n_comps}")
return pc_regression(adata.obsm[embed], covariate_values, n_comps=n_comps)
return pc_regression(
adata.obsm[embed],
covariate_values,
n_comps=n_comps,
linreg_method=linreg_method,
n_threads=n_threads,
)

# use existing PCA computation
elif (recompute_pca is False) and ("X_pca" in adata.obsm) and ("pca" in adata.uns):
if verbose:
print("using existing PCA")
return pc_regression(
adata.obsm["X_pca"], covariate_values, pca_var=adata.uns["pca"]["variance"]
adata.obsm["X_pca"],
covariate_values,
pca_var=adata.uns["pca"]["variance"],
linreg_method=linreg_method,
n_threads=n_threads,
)

# recompute PCA
else:
if verbose:
print(f"compute PCA n_comps: {n_comps}")
return pc_regression(adata.X, covariate_values, n_comps=n_comps)
return pc_regression(
adata.X,
covariate_values,
n_comps=n_comps,
linreg_method=linreg_method,
n_threads=n_threads,
)


def pc_regression(
data, covariate, pca_var=None, n_comps=50, svd_solver="arpack", verbose=False
data,
covariate,
pca_var=None,
n_comps=50,
svd_solver="arpack",
linreg_method="sklearn",
verbose=False,
n_threads=1,
):
"""Principal component regression

@@ -172,14 +221,20 @@ def pc_regression(
:return:
Variance contribution of regression
"""

if isinstance(data, (np.ndarray, sparse.csr_matrix, sparse.csc_matrix)):
matrix = data
else:
raise TypeError(
f"invalid type: {data.__class__} is not a numpy array or sparse matrix"
)

if linreg_method == "sklearn":
linreg_method = linreg_sklearn
elif linreg_method == "numpy":
linreg_method = linreg_np
else:
raise ValueError(f"invalid linreg_method: {linreg_method}")

# perform PCA if no variance contributions are given
if pca_var is None:

@@ -193,7 +248,7 @@ def pc_regression(
matrix = matrix.toarray()

if verbose:
print("compute PCA")
print("compute PCA...")
X_pca, _, _, pca_var = sc.tl.pca(
matrix,
n_comps=n_comps,
@@ -216,18 +271,39 @@ def pc_regression(
else:
if verbose:
print("one-hot encode categorical values")
covariate = pd.get_dummies(covariate)
covariate = pd.get_dummies(covariate).to_numpy()

# fit linear model for n_comps PCs
r2 = []
for i in range(n_comps):
pc = X_pca[:, [i]]
lm = LinearRegression()
lm.fit(covariate, pc)
r2_score = np.maximum(0, lm.score(covariate, pc))
r2.append(r2_score)
if verbose:
print(f"Use {n_threads} threads for regression...")
if n_threads == 1:
r2 = []
for i in tqdm(range(n_comps), total=n_comps):
r2_score = linreg_method(X=covariate, y=X_pca[:, [i]])
r2.append(np.maximum(0, r2_score))
else:
with ThreadPoolExecutor(max_workers=n_threads) as executor:
run_r2 = executor.map(linreg_method, [covariate] * n_comps, X_pca.T)
r2 = list(tqdm(run_r2, total=n_comps))

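# variance-weighted mean of the per-PC R2 values: Var_i is the percentage of total
# variance explained by PC i, so R2Var = sum_i(r2_i * Var_i) / 100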
Var = pca_var / sum(pca_var) * 100
R2Var = sum(r2 * Var) / 100

return R2Var


def linreg_sklearn(X, y):
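# R2 via sklearn's LinearRegression (which fits an intercept by default), clipped to 0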
from sklearn.linear_model import LinearRegression

lm = LinearRegression()
lm.fit(X, y)
r2_score = lm.score(X, y)
return np.maximum(0, r2_score)


def linreg_np(X, y):
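# ordinary least squares via np.linalg.lstsq (no intercept term); R2 = 1 - RSS / TSS, clipped to 0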
coefficients, residuals, _, _ = np.linalg.lstsq(X, y, rcond=None)
tss = np.sum((y - y.mean()) ** 2)
r2_score = 1 - (residuals[0] / tss)
return np.maximum(0, r2_score)
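
For reference, a minimal usage sketch of the extended pcr signature introduced in this diff; the
toy AnnData, covariate values and thread count below are illustrative and not taken from the PR:

import anndata as ad
import numpy as np
import scanpy as sc
import scib

# toy data standing in for a real dataset
rng = np.random.default_rng(0)
adata = ad.AnnData(X=rng.normal(size=(200, 60)).astype(np.float32))
adata.obs["batch"] = rng.choice(["a", "b"], size=200)

# precompute PCA so the new default recompute_pca=False can reuse X_pca and uns["pca"]["variance"]
sc.pp.pca(adata)

score = scib.me.pcr(
    adata,
    covariate="batch",
    linreg_method="numpy",  # or "sklearn" (the default)
    n_threads=4,            # >1 dispatches the per-PC regressions to a thread pool
    verbose=True,
)
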
65 changes: 51 additions & 14 deletions tests/metrics/test_pcr_metrics.py
@@ -1,38 +1,75 @@
import pytest
from scipy.sparse import csr_matrix

import scib
from tests.common import LOGGER, add_embed, assert_near_exact


def test_pc_regression(adata):
score = scib.me.pc_regression(adata.X, adata.obs["batch"])
LOGGER.info(f"using embedding: {score}")
assert_near_exact(score, 0, diff=1e-4)
@pytest.mark.parametrize("sparse", [False, True])
def test_pc_regression(adata, sparse):
if sparse:
adata.X = csr_matrix(adata.X)
score = scib.me.pc_regression(
adata.X,
covariate=adata.obs["batch"],
n_comps=adata.n_vars,
)
LOGGER.info(score)
assert_near_exact(score, 0, diff=1e-3)


@pytest.mark.parametrize("linreg_method", ["numpy", "sklearn"])
def test_pcr_timing(adata_pca, linreg_method):
import timeit

import anndata as ad
import scanpy as sc

# scale up anndata
adata = ad.concat([adata_pca] * 100)
print(f"compute PCA on {adata.n_obs} cells...")
sc.pp.pca(adata)

def test_pc_regression_sparse(adata):
x = csr_matrix(adata.X)
score = scib.me.pc_regression(x, adata.obs["batch"], n_comps=x.shape[1])
LOGGER.info(f"using embedding: {score}")
assert_near_exact(score, 0, diff=1e-4)
runs = 10
timing = timeit.timeit(
lambda: scib.me.pcr(
adata,
covariate="celltype",
linreg_method=linreg_method,
verbose=False,
n_threads=10,
),
number=runs,
)
LOGGER.info(f"timeit: {timing:.2f}s for {runs} runs")

# test pcr value
score = scib.me.pcr(
adata_pca,
covariate="celltype",
linreg_method=linreg_method,
verbose=True,
n_threads=1,
)
LOGGER.info(score)
assert_near_exact(score, 0.33401529220865844, diff=1e-3)


def test_pcr_batch(adata):
# no PCA precomputed
def test_pcr_comparison_batch(adata):
score = scib.me.pcr_comparison(
adata, adata, covariate="batch", n_comps=50, scale=True
)
LOGGER.info(f"no PCA precomputed: {score}")
assert_near_exact(score, 0, diff=1e-6)


def test_pcr_batch_precomputed(adata_pca):
def test_pcr_comparison_batch_precomputed(adata_pca):
score = scib.me.pcr_comparison(adata_pca, adata_pca, covariate="batch", scale=True)
LOGGER.info(f"precomputed PCA: {score}")
assert_near_exact(score, 0, diff=1e-6)


def test_pcr_batch_embedding(adata):
def test_pcr_comparison_batch_embedding(adata):
# use different embedding
score = scib.me.pcr_comparison(
adata_pre=adata,
@@ -42,5 +79,5 @@ def test_pcr_batch_embedding(adata):
n_comps=50,
scale=True,
)
LOGGER.info(f"using embedding: {score}")
LOGGER.info(f"using X_emb: {score}")
assert_near_exact(score, 0, diff=1e-6)
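
As a side note, a standalone sketch (not part of this PR) of why the numpy least-squares path can
stand in for sklearn here: for a one-hot covariate design the dummy columns span the intercept, so
the residuals from np.linalg.lstsq yield the same R2 as LinearRegression().score():

import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
X = np.eye(3)[rng.integers(0, 3, size=100)]  # one-hot covariate with 3 levels
y = rng.normal(size=(100, 1))                # stand-in for the scores of a single PC

r2_sklearn = LinearRegression().fit(X, y).score(X, y)

_, residuals, _, _ = np.linalg.lstsq(X, y, rcond=None)
r2_numpy = 1 - residuals[0] / np.sum((y - y.mean()) ** 2)

assert np.isclose(r2_sklearn, r2_numpy)

For a continuous covariate the two methods can differ, since the numpy helper fits no intercept.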