Skip to content

Commit

Permalink
upgrade to v1.2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
mrland99 committed Apr 9, 2022
1 parent d166c11 commit 62fbc4c
Show file tree
Hide file tree
Showing 14 changed files with 399 additions and 149 deletions.
130 changes: 83 additions & 47 deletions build/lib/paste/PASTE.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,53 @@
from typing import List, Tuple, Optional
import numpy as np
import anndata
from anndata import AnnData
import ot
from sklearn.decomposition import NMF
from .helper import kl_divergence, intersect, kl_divergence_backend, to_dense_array, extract_data_matrix
import time
from .helper import intersect, kl_divergence_backend, to_dense_array, extract_data_matrix

def pairwise_align(sliceA, sliceB, alpha = 0.1, dissimilarity='kl', use_rep = None, G_init = None, a_distribution = None, b_distribution = None, norm = False, numItermax = 200, backend=ot.backend.NumpyBackend(), use_gpu = False, return_obj = False, verbose = False, gpu_verbose = True, **kwargs):
def pairwise_align(
sliceA: AnnData,
sliceB: AnnData,
alpha: float = 0.1,
dissimilarity: str ='kl',
use_rep: Optional[str] = None,
G_init = None,
a_distribution = None,
b_distribution = None,
norm: bool = False,
numItermax: int = 200,
backend = ot.backend.NumpyBackend(),
use_gpu: bool = False,
return_obj: bool = False,
verbose: bool = False,
gpu_verbose: bool = True,
**kwargs) -> Tuple[np.ndarray, Optional[int]]:
"""
Calculates and returns optimal alignment of two slices.
param: sliceA - AnnData object of spatial slice
param: sliceB - AnnData object of spatial slice
param: alpha - Alignment tuning parameter. Note: 0 ≤ alpha ≤ 1
param: dissimilarity - Expression dissimilarity measure: 'kl' or 'euclidean'
param: use_rep - If none, uses slice.X to calculate dissimilarity between spots, otherwise uses the representation given by slice.obsm[use_rep]
param: G_init - initial mapping to be used in FGW-OT, otherwise default is uniform mapping
param: a_distribution - distribution of sliceA spots (1-d numpy array), otherwise default is uniform
param: b_distribution - distribution of sliceB spots (1-d numpy array), otherwise default is uniform
param: numItermax - max number of iterations during FGW-OT
param: norm - scales spatial distances such that neighboring spots are at distance 1 if True, otherwise spatial distances remain unchanged
param: backend - type of backend to run calculations. For list of backends available on system: ot.backend.get_backend_list()
param: use_gpu - Whether to run on gpu or cpu. Currently we only have gpu support for Pytorch.
param: return_obj - returns objective function output of FGW-OT if True, nothing if False
param: verbose - FGW-OT is verbose if True, nothing if False
param: gpu_verbose - Print whether gpu is being used to user, nothing if False
Args:
sliceA: Slice A to align.
sliceB: Slice B to align.
alpha: Alignment tuning parameter. Note: 0 <= alpha <= 1.
dissimilarity: Expression dissimilarity measure: ``'kl'`` or ``'euclidean'``.
use_rep: If ``None``, uses ``slice.X`` to calculate dissimilarity between spots, otherwise uses the representation given by ``slice.obsm[use_rep]``.
G_init (array-like, optional): Initial mapping to be used in FGW-OT, otherwise default is uniform mapping.
a_distribution (array-like, optional): Distribution of sliceA spots, otherwise default is uniform.
b_distribution (array-like, optional): Distribution of sliceB spots, otherwise default is uniform.
numItermax: Max number of iterations during FGW-OT.
norm: If ``True``, scales spatial distances such that neighboring spots are at distance 1. Otherwise, spatial distances remain unchanged.
backend: Type of backend to run calculations. For list of backends available on system: ``ot.backend.get_backend_list()``.
use_gpu: If ``True``, use gpu. Otherwise, use cpu. Currently we only have gpu support for Pytorch.
return_obj: If ``True``, additionally returns objective function output of FGW-OT.
verbose: If ``True``, FGW-OT is verbose.
gpu_verbose: If ``True``, print to the user whether the gpu is being used.
return: pi - alignment of spots
return: log['fgw_dist'] - objective function output of FGW-OT
Returns:
- Alignment of spots.
If ``return_obj = True``, additionally returns:
- Objective function output of FGW-OT.
"""

# Determine if gpu or cpu is being used
Expand Down Expand Up @@ -131,31 +151,47 @@ def pairwise_align(sliceA, sliceB, alpha = 0.1, dissimilarity='kl', use_rep = No
return pi


def center_align(A, slices, lmbda = None, alpha = 0.1, n_components = 15, threshold = 0.001, max_iter = 10, dissimilarity='kl', use_rep = None, norm = False, random_seed = None, pis_init = None, distributions=None, backend = ot.backend.NumpyBackend(), use_gpu = False, verbose = False, gpu_verbose = True):
def center_align(
A: AnnData,
slices: List[AnnData],
lmbda = None,
alpha: float = 0.1,
n_components: int = 15,
threshold: float = 0.001,
max_iter: int = 10,
dissimilarity: str ='kl',
norm: bool = False,
random_seed: Optional[int] = None,
pis_init: Optional[List[np.ndarray]] = None,
distributions = None,
backend = ot.backend.NumpyBackend(),
use_gpu: bool = False,
verbose: bool = False,
gpu_verbose: bool = True) -> Tuple[AnnData, List[np.ndarray]]:
"""
Computes center alignment of slices.
param: A - Initialization of starting AnnData Spatial Object; Make sure to include gene expression AND spatial info
param: slices - List of slices (AnnData objects) used to calculate center alignment
param: lmbda - List of probability weights assigned to each slice; default is uniform weights
param: n_components - Number of components in NMF decomposition
param: threshold - Threshold for convergence of W and H
param: max_iter - maximum number of iterations for solving for center slice
param: dissimilarity - Expression dissimilarity measure: 'kl' or 'euclidean'
param: use_rep - If none, uses slice.X to calculate dissimilarity between spots, otherwise uses the representation given by slice.obsm[use_rep]
param: norm - scales spatial distances such that neighboring spots are at distance 1 if True, otherwise spatial distances remain unchanged
param: random_seed - set random seed for reproducibility
param: pis_init - initial list of mappings between 'A' and 'slices' to solver, otherwise will calculate default mappings
param: distributions - distributions of spots for each slice (list of 1-d numpy array), otherwise default is uniform
param: backend - type of backend to run calculations. For list of backends available on system: ot.backend.get_backend_list()
param: use_gpu - Whether to run on gpu or cpu. Currently we only have gpu support for Pytorch.
param: verbose - FGW-OT is verbose if True, nothing if False
param: gpu_verbose - Print whether gpu is being used to user, nothing if False
Args:
A: Slice to use as the initialization for center alignment; Make sure to include gene expression and spatial information.
slices: List of slices to use in the center alignment.
lmbda (array-like, optional): List of probability weights assigned to each slice; If ``None``, use uniform weights.
alpha: Alignment tuning parameter. Note: 0 <= alpha <= 1.
n_components: Number of components in NMF decomposition.
threshold: Threshold for convergence of W and H during NMF decomposition.
max_iter: Maximum number of iterations for our center alignment algorithm.
dissimilarity: Expression dissimilarity measure: ``'kl'`` or ``'euclidean'``.
norm: If ``True``, scales spatial distances such that neighboring spots are at distance 1. Otherwise, spatial distances remain unchanged.
random_seed: Set random seed for reproducibility.
pis_init: Initial list of mappings between 'A' and 'slices' passed to the solver. If ``None``, the initial mappings are calculated automatically.
distributions (List[array-like], optional): Distributions of spots for each slice. Otherwise, default is uniform.
backend: Type of backend to run calculations. For list of backends available on system: ``ot.backend.get_backend_list()``.
use_gpu: If ``True``, use gpu. Otherwise, use cpu. Currently we only have gpu support for Pytorch.
verbose: If ``True``, FGW-OT is verbose.
gpu_verbose: If ``True``, print to the user whether the gpu is being used.
return: center_slice - inferred center slice (AnnData object) with full and low dimensional representations (W, H) of
the gene expression matrix
return: pi - List of pairwise alignment mappings of the center slice (rows) to each input slice (columns)
Returns:
- Inferred center slice with full and low dimensional representations (W, H) of the gene expression matrix.
- List of pairwise alignment mappings of the center slice (rows) to each input slice (columns).
"""

# Determine if gpu or cpu is being used
Expand Down Expand Up @@ -213,10 +249,10 @@ def center_align(A, slices, lmbda = None, alpha = 0.1, n_components = 15, thresh
center_coordinates = A.obsm['spatial']

if not isinstance(center_coordinates, np.ndarray):
print("Warning: A.obsm['spatial'] is not of type numpy array .")
print("Warning: A.obsm['spatial'] is not of type numpy array.")

# Initialize center_slice
center_slice = anndata.AnnData(np.dot(W,H))
center_slice = AnnData(np.dot(W,H))
center_slice.var.index = common_genes
center_slice.obs.index = A.obs.index
center_slice.obsm['spatial'] = center_coordinates
Expand Down Expand Up @@ -246,7 +282,7 @@ def center_align(A, slices, lmbda = None, alpha = 0.1, n_components = 15, thresh
#--------------------------- HELPER METHODS -----------------------------------

def center_ot(W, H, slices, center_coordinates, common_genes, alpha, backend, use_gpu, dissimilarity = 'kl', norm = False, G_inits = None, distributions=None, verbose = False):
center_slice = anndata.AnnData(np.dot(W,H))
center_slice = AnnData(np.dot(W,H))
center_slice.var.index = common_genes
center_slice.obsm['spatial'] = center_coordinates

Expand Down Expand Up @@ -313,4 +349,4 @@ def df(G):
return res, log

else:
return ot.gromov.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
return ot.gromov.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
2 changes: 1 addition & 1 deletion build/lib/paste/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .PASTE import pairwise_align, center_align
from .helper import kl_divergence, kl_divergence_backend, intersect, match_spots_using_spatial_heuristic, filter_for_common_genes
from .helper import match_spots_using_spatial_heuristic
from .visualization import plot_slice, stack_slices_pairwise, stack_slices_center
89 changes: 48 additions & 41 deletions build/lib/paste/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,44 @@
import scipy
import ot

def filter_for_common_genes(slices):
    """
    Restricts every slice to the genes that are shared by all slices.

    The list is modified in place: each entry is replaced by its subset ``slice[:, common_genes]``.

    Args:
        slices: Non-empty list of slices (AnnData objects).

    Raises:
        AssertionError: If ``slices`` is empty.
    """
    assert len(slices) > 0, "Cannot have empty list."

    # Fold the gene index of every slice into one running intersection.
    common_genes = slices[0].var.index
    for s in slices:
        common_genes = intersect(common_genes, s.var.index)
    for i, s in enumerate(slices):
        slices[i] = s[:, common_genes]
    print('Filtered all slices for common genes. There are ' + str(len(common_genes)) + ' common genes.')


def match_spots_using_spatial_heuristic(
    X,
    Y,
    use_ot: bool = True) -> np.ndarray:
    """
    Calculates and returns a mapping of spots using a spatial heuristic.

    Both coordinate sets are centered and rescaled with ``norm_and_center_coordinates()`` before the
    pairwise distance matrix is computed.

    Args:
        X (array-like): Coordinates for spots X, one row per spot.
        Y (array-like): Coordinates for spots Y, one row per spot.
        use_ot: If ``True``, use optimal transport ``ot.emd()`` to calculate mapping. Otherwise, use Scipy's ``min_weight_full_bipartite_matching()`` algorithm.

    Returns:
        Mapping of spots using a spatial heuristic.
    """
    # Accept any array-like (e.g. nested lists); the normalization helper relies on ndarray methods.
    X, Y = np.asarray(X), np.asarray(Y)
    n1, n2 = len(X), len(Y)
    X, Y = norm_and_center_coordinates(X), norm_and_center_coordinates(Y)
    dist = scipy.spatial.distance_matrix(X, Y)

    if use_ot:
        # Uniform weight on every spot of both slices.
        return ot.emd(np.ones(n1)/n1, np.ones(n2)/n2, dist)

    # NOTE(review): building a CSR matrix from the dense distances drops exact-zero entries, which the
    # matching may then treat as missing edges -- confirm against the SciPy documentation.
    row_ind, col_ind = scipy.sparse.csgraph.min_weight_full_bipartite_matching(scipy.sparse.csr_matrix(dist))
    pi = np.zeros((n1, n2))
    pi[row_ind, col_ind] = 1/max(n1, n2)
    # Spots of the larger slice that were left unmatched get a uniform 1/(n1*n2) weight.
    if n1 < n2:
        unmatched_cols = ~np.isin(np.arange(n2), col_ind)
        pi[:, unmatched_cols] = 1/(n1*n2)
    elif n2 < n1:
        unmatched_rows = ~np.isin(np.arange(n1), row_ind)
        pi[unmatched_rows, :] = 1/(n1*n2)
    return pi

def kl_divergence(X, Y):
"""
Returns pairwise KL divergence (over all pairs of samples) of two matrices X and Y.
param: X - np array with dim (n_samples by n_features)
param: Y - np array with dim (m_samples by n_features)
Args:
X: np array with dim (n_samples by n_features)
Y: np array with dim (m_samples by n_features)
return: D - np array with dim (n_samples by m_samples). Pairwise KL divergence matrix.
Returns:
D: np array with dim (n_samples by m_samples). Pairwise KL divergence matrix.
"""
assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."

Expand All @@ -40,10 +57,12 @@ def kl_divergence_backend(X, Y):
Takes advantage of POT backend to speed up computation.
param: X - np array with dim (n_samples by n_features)
param: Y - np array with dim (m_samples by n_features)
Args:
X: np array with dim (n_samples by n_features)
Y: np array with dim (m_samples by n_features)
return: D - np array with dim (n_samples by m_samples). Pairwise KL divergence matrix.
Returns:
D: np array with dim (n_samples by m_samples). Pairwise KL divergence matrix.
"""
assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."

Expand All @@ -61,10 +80,14 @@ def kl_divergence_backend(X, Y):

def intersect(lst1, lst2):
"""
param: lst1 - list
param: lst2 - list
Gets and returns intersection of two lists.
Args:
lst1: List
lst2: List
return: list of common elements
Returns:
lst3: List of common elements.
"""

temp = set(lst2)
Expand All @@ -73,33 +96,17 @@ def intersect(lst1, lst2):

def norm_and_center_coordinates(X):
    """
    Normalizes and centers coordinates at the origin.

    The column-wise mean is subtracted and the result is divided by the smallest pairwise distance
    between the rows of ``X``, so that the two closest points end up at distance 1.

    Args:
        X (array-like): Coordinates, one row per point. At least two rows are required.

    Returns:
        X_new: Updated coordinates. If two rows of ``X`` coincide, the smallest distance is 0 and the
        division yields non-finite values.

    Raises:
        ValueError: If ``X`` has fewer than two rows (there is no pairwise distance to take the minimum of).
    """
    # Accept any array-like (e.g. nested lists), not only numpy arrays.
    X = np.asarray(X)
    return (X - X.mean(axis=0)) / scipy.spatial.distance.pdist(X).min()


def match_spots_using_spatial_heuristic(X,Y,use_ot=True):
    """
    Calculates and returns a mapping of spots using a spatial heuristic.

    Both coordinate sets are centered and rescaled with ``norm_and_center_coordinates()`` before the
    pairwise distance matrix is computed.

    param: X - array-like of coordinates, one row per spot
    param: Y - array-like of coordinates, one row per spot
    param: use_ot - if True, the mapping is computed with ot.emd(); otherwise with Scipy's min_weight_full_bipartite_matching()

    return: pi - mapping of spots using spatial heuristic
    """
    # Accept any array-like (e.g. nested lists); the normalization helper relies on ndarray methods.
    X, Y = np.asarray(X), np.asarray(Y)
    n1, n2 = len(X), len(Y)
    X, Y = norm_and_center_coordinates(X), norm_and_center_coordinates(Y)
    dist = scipy.spatial.distance_matrix(X, Y)

    if use_ot:
        # Uniform weight on every spot of both slices.
        return ot.emd(np.ones(n1)/n1, np.ones(n2)/n2, dist)

    # NOTE(review): building a CSR matrix from the dense distances drops exact-zero entries, which the
    # matching may then treat as missing edges -- confirm against the SciPy documentation.
    row_ind, col_ind = scipy.sparse.csgraph.min_weight_full_bipartite_matching(scipy.sparse.csr_matrix(dist))
    pi = np.zeros((n1, n2))
    pi[row_ind, col_ind] = 1/max(n1, n2)
    # Spots of the larger slice that were left unmatched get a uniform 1/(n1*n2) weight.
    if n1 < n2:
        unmatched_cols = ~np.isin(np.arange(n2), col_ind)
        pi[:, unmatched_cols] = 1/(n1*n2)
    elif n2 < n1:
        unmatched_rows = ~np.isin(np.arange(n1), row_ind)
        pi[unmatched_rows, :] = 1/(n1*n2)
    return pi

def to_dense_array(X):
    """
    Converts a sparse matrix into a dense numpy array.

    Args:
        X: SciPy sparse matrix, or any array-like.

    Returns:
        ``X.toarray()`` if ``X`` is a SciPy sparse container, otherwise ``np.array(X)``.
    """
    # ``issparse`` avoids reaching through the private ``scipy.sparse.csr`` module path, which SciPy
    # has deprecated, while still recognizing every ``spmatrix`` subclass.
    if scipy.sparse.issparse(X):
        return X.toarray()
    return np.array(X)

Expand Down
Loading

0 comments on commit 62fbc4c

Please sign in to comment.