
Commit

update Spateo alignment
YifanLu2000 committed Jul 7, 2024
1 parent b5b2acf commit e9aaa78
Showing 4 changed files with 154 additions and 16 deletions.
56 changes: 51 additions & 5 deletions spateo/alignment/methods/morpho.py
@@ -36,6 +36,7 @@
coarse_rigid_alignment,
empty_cache,
get_optimal_R,
nx_torch,
)


@@ -80,6 +81,7 @@ def get_P(
SpatialDistMat: Union[np.ndarray, torch.Tensor],
samples_s: Optional[List[float]] = None,
outlier_variance: float = None,
label_dist: Optional[Union[np.ndarray, torch.Tensor]] = None,
) -> Tuple[Any, Any, Any]:
"""Calculating the generating probability matrix P.
@@ -119,6 +121,7 @@ def get_P(
exp_SpatialMat,
(_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
)
spatial_term1 = spatial_term1 * label_dist if label_dist is not None else spatial_term1
spatial_outlier = _power(nx)((2 * _pi(nx) * sigma2), _data(nx, D / 2, XnAHat)) * (1 - gamma) / (gamma * outlier_s)
spatial_term2 = spatial_outlier + nx.einsum("ij->j", spatial_term1)
spatial_P = spatial_term1 / _unsqueeze(nx)(spatial_term2, 0)
@@ -128,6 +131,7 @@
_mul(nx)(nx.exp(-SpatialDistMat / (2 * sigma2)), nx.exp(-GeneDistMat / (2 * beta2))),
(_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
)
term1 = term1 * label_dist if label_dist is not None else term1
P = term1 / (_unsqueeze(nx)(nx.einsum("ij->j", term1), 0) + 1e-8)
P = nx.einsum("j,ij->ij", spatial_inlier, P)

@@ -155,6 +159,7 @@ def get_P_chunk(
outlier_variance: float = None,
chunk_size: int = 1000,
dissimilarity: str = "kl",
label_dist: Optional[Union[np.ndarray, torch.Tensor]] = None,
) -> Union[np.ndarray, torch.Tensor]:
"""Calculating the generating probability matrix P.
@@ -183,9 +188,9 @@
# chunk
X_Bs = _chunk(nx, X_B, chunk_num, dim=0)
XnBs = _chunk(nx, XnB, chunk_num, dim=0)

label_dists = _chunk(nx, label_dist, chunk_num, dim=1) if label_dist is not None else [None] * len(XnBs)
Ps = []
for x_Bs, xnBs in zip(X_Bs, XnBs):
for x_Bs, xnBs, l_d in zip(X_Bs, XnBs, label_dists):
SpatialDistMat = cal_dist(XnAHat, xnBs)
GeneDistMat = calc_exp_dissimilarity(X_A=X_A, X_B=x_Bs, dissimilarity=dissimilarity)
if outlier_variance is None:
@@ -197,6 +202,7 @@
exp_SpatialMat,
(_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
)
spatial_term1 = spatial_term1 * l_d if l_d is not None else spatial_term1
spatial_outlier = (
_power(nx)((2 * _pi(nx) * sigma2), _data(nx, D / 2, XnAHat)) * (1 - gamma) / (gamma * outlier_s)
)
@@ -206,6 +212,7 @@
_mul(nx)(nx.exp(-SpatialDistMat / (2 * sigma2)), nx.exp(-GeneDistMat / (2 * beta2))),
(_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
)
term1 = term1 * l_d if l_d is not None else term1
P = term1 / (_unsqueeze(nx)(nx.einsum("ij->j", term1), 0) + 1e-8)
P = nx.einsum("j,ij->ij", spatial_inlier, P)
Ps.append(P)
@@ -239,6 +246,12 @@ def BA_align(
SVI_mode: bool = True,
batch_size: int = 1000,
partial_robust_level: float = 25,
use_rep: Optional[str] = None,
use_label_prior: bool = False,
label_key: Optional[str] = "cluster",
label_transfer_prior: Optional[dict] = None,
beta2: Optional[float] = None,
beta2_end: Optional[float] = None,
) -> Tuple[Optional[Tuple[AnnData, AnnData]], np.ndarray, np.ndarray]:
"""core function for spateo pairwise alignment
@@ -293,8 +306,26 @@
dtype=dtype,
device=device,
verbose=verbose,
use_rep=use_rep,
)

if use_label_prior:
catB = sampleA.obs[label_key].cat.categories.tolist()
catA = sampleB.obs[label_key].cat.categories.tolist()
label_transfer = np.zeros((len(catA), len(catB)), dtype=np.float32)

for j, ca in enumerate(catA):
for k, cb in enumerate(catB):
label_transfer[j, k] = label_transfer_prior[ca][cb]
label_transfer = nx.from_numpy(label_transfer, type_as=type_as)
labelA = nx.from_numpy(np.array(sampleB.obs[label_key].cat.codes.values, dtype=np.int64))
labelB = nx.from_numpy(np.array(sampleA.obs[label_key].cat.codes.values, dtype=np.int64))
if nx_torch(nx):
labelA = labelA.to(type_as.device)
labelB = labelB.to(type_as.device)
else:
label_transfer, labelA, labelB = None, None, None

coordsA, coordsB = spatial_coords[1], spatial_coords[0]
X_A, X_B = exp_matrices[1], exp_matrices[0]
del spatial_coords, exp_matrices
@@ -400,20 +431,29 @@ def BA_align(
R = _identity(nx, D, type_as)
minGeneDistMat = nx.min(GeneDistMat, 1)
# Automatically determine the value of beta2
beta2 = minGeneDistMat[nx.argsort(minGeneDistMat)[int(GeneDistMat.shape[0] * 0.05)]] / 5
beta2_end = nx.max(minGeneDistMat) / 5
beta2 = (
minGeneDistMat[nx.argsort(minGeneDistMat)[int(GeneDistMat.shape[0] * 0.05)]] / 5
if beta2 is None
else _data(nx, data=beta2, type_as=type_as)
)
beta2_end = nx.max(minGeneDistMat) / 5 if beta2_end is None else _data(nx, data=beta2_end, type_as=type_as)
del minGeneDistMat
if sub_sample:
del sub_X_A, sub_X_B, GeneDistMat
# The value of beta2 becomes progressively larger
beta2 = nx.maximum(beta2, _data(nx, 1e-2, type_as))
beta2_decrease = _power(nx)(beta2_end / beta2, 1 / (50))

print("beta2: {} --> {}".format(beta2, beta2_end))
# Use smaller spatial variance to reduce tails
outlier_variance = 1
max_outlier_variance = partial_robust_level # 20
outlier_variance_decrease = _power(nx)(_data(nx, max_outlier_variance, type_as), 1 / (max_iter / 2))

if use_label_prior:
label_dist = label_transfer[labelA, :][:, labelB]
else:
label_dist = None

if SVI_mode:
SVI_deacy = _data(nx, 10.0, type_as)
# Select a random subset of data
@@ -428,6 +468,7 @@
else:
randGeneDistMat = GeneDistMat[:, randIdx] # NA x batch_size
SpatialDistMat = SpatialDistMat[:, randIdx] # NA x batch_size
randlabel_dist = label_dist[:, randIdx] if use_label_prior else None
Sp, Sp_spatial, Sp_sigma2 = 0, 0, 0
SigmaInv = nx.zeros((K, K), type_as=type_as) # K x K
PXB_term = nx.zeros((NA, D), type_as=type_as) # NA x D
@@ -462,6 +503,7 @@ def BA_align(
GeneDistMat=randGeneDistMat,
SpatialDistMat=SpatialDistMat,
outlier_variance=outlier_variance,
label_dist=randlabel_dist,
)
else:
P, spatial_P, sigma2_P = get_P(
@@ -475,6 +517,7 @@
GeneDistMat=GeneDistMat,
SpatialDistMat=SpatialDistMat,
outlier_variance=outlier_variance,
label_dist=label_dist,
)

if iter > 5:
@@ -624,6 +667,7 @@ def BA_align(
else:
randGeneDistMat = GeneDistMat[:, randIdx] # NA x batch_size
SpatialDistMat = cal_dist(XAHat, randcoordsB)
randlabel_dist = label_dist[:, randIdx] if use_label_prior else None
empty_cache(device=device)

# full data
@@ -639,6 +683,8 @@
gamma=gamma,
Sigma=SigmaDiag,
outlier_variance=outlier_variance,
label_dist=label_dist,
dissimilarity=dissimilarity,
)
# Get optimal Rigid transformation
optimal_RnA, optimal_R, optimal_t = get_optimal_R(
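
For context on the new use_label_prior path above, here is a minimal NumPy sketch (toy cluster names, codes, and shapes, for illustration only) of how a nested label_transfer_prior dict is flattened into a category-level matrix and then expanded into the per-cell label_dist that multiplies the assignment terms in get_P / get_P_chunk:

import numpy as np

# Hypothetical cluster names and prior weights; the real values come from the caller.
label_transfer_prior = {"A": {"A": 1.0, "B": 0.1}, "B": {"A": 0.1, "B": 1.0}}
catA, catB = ["A", "B"], ["A", "B"]

# Flatten the nested dict into a (len(catA), len(catB)) matrix, as in BA_align.
label_transfer = np.zeros((len(catA), len(catB)), dtype=np.float32)
for j, ca in enumerate(catA):
    for k, cb in enumerate(catB):
        label_transfer[j, k] = label_transfer_prior[ca][cb]

# Integer label codes for 4 cells in one slice and 3 cells in the other.
labelA = np.array([0, 0, 1, 1])
labelB = np.array([0, 1, 1])

# Expand to a cell-by-cell prior, mirroring label_dist = label_transfer[labelA, :][:, labelB].
label_dist = label_transfer[labelA, :][:, labelB]  # shape (4, 3)

# A toy assignment term is down-weighted wherever the two labels disagree.
term1 = np.ones((4, 3))
print(term1 * label_dist)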
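The beta2 handling above now accepts user-supplied beta2 / beta2_end while keeping the automatic estimate as the default; the ratio beta2_end / beta2 raised to 1/50 gives the per-step growth factor behind the "becomes progressively larger" comment. A short sketch with hypothetical values (the loop that actually applies the factor lies outside the hunks shown here):

import numpy as np

beta2, beta2_end = 0.05, 2.0                        # hypothetical start / end values
beta2 = max(beta2, 1e-2)                            # floor, as in nx.maximum(beta2, 1e-2)
beta2_decrease = (beta2_end / beta2) ** (1 / 50)    # per-step growth factor

# Applying the factor 50 times walks beta2 from its start value up to beta2_end.
schedule = beta2 * beta2_decrease ** np.arange(51)
print(f"beta2: {schedule[0]:.3f} --> {schedule[-1]:.3f}")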
2 changes: 1 addition & 1 deletion spateo/alignment/methods/morpho_sparse.py
@@ -547,7 +547,7 @@ def BA_align_sparse(
Sigma=SigmaDiag,
outlier_variance=outlier_variance,
label_transfer_prior=label_transfer_prior,
top_k=32,
top_k=512,
dissimilarity=dissimilarity,
batch_capacity=batch_capacity,
)
59 changes: 53 additions & 6 deletions spateo/alignment/methods/morpho_sparse_utils.py
@@ -29,7 +29,9 @@ def calc_distance(
"square_euc",
"square_euclidean",
"kl",
], "``metric`` value is wrong. Available ``metric`` are: ``'euc'``, ``'euclidean'``, ``'square_euc'``, ``'square_euclidean'``, and ``'kl'``."
"cos",
"cosine",
], "``metric`` value is wrong. Available ``metric`` are: ``'euc'``, ``'euclidean'``, ``'square_euc'``, ``'square_euclidean'``, ``'kl'``, ``'cos'``, and ``'cosine'``."

if use_sparse:
assert sparse_method in [
@@ -327,18 +329,63 @@ def _SparseTensor(nx, row, col, value, sparse_sizes):
return coo_array((value, (row, col)), shape=sparse_sizes)


def _cosine_distance_backend(
X: Union[np.ndarray, torch.Tensor],
Y: Union[np.ndarray, torch.Tensor],
eps: float = 1e-8,
) -> Union[np.ndarray, torch.Tensor]:
"""
Compute the pairwise cosine similarity between all pairs of samples in matrices X and Y.
Parameters
----------
X : np.ndarray or torch.Tensor
Matrix with shape (N, D), where each row represents a sample.
Y : np.ndarray or torch.Tensor
Matrix with shape (M, D), where each row represents a sample.
eps : float, optional
A small value to avoid division by zero. Default is 1e-8.
Returns
-------
np.ndarray or torch.Tensor
Pairwise cosine similarity matrix with shape (N, M).
Raises
------
AssertionError
If the number of features in X and Y do not match.
"""

assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."

# Get the appropriate backend (either NumPy or PyTorch)
nx = ot.backend.get_backend(X, Y)

# Normalize rows to unit vectors
X_norm = nx.sqrt(nx.sum(X**2, axis=1, keepdims=True))
Y_norm = nx.sqrt(nx.sum(Y**2, axis=1, keepdims=True))
X = X / nx.maximum(X_norm, eps)
Y = Y / nx.maximum(Y_norm, eps)

# Compute cosine similarity
D = nx.dot(X, Y.T)

return D


def _cos_similarity(
mat1: Union[np.ndarray, torch.Tensor],
mat2: Union[np.ndarray, torch.Tensor],
):
nx = ot.backend.get_backend(mat1, mat2)
if nx_torch(nx):
torch_cos = torch.nn.CosineSimilarity(dim=1)
mat1_unsqueeze = mat1.unsqueeze(-1)
mat2_unsqueeze = mat2.unsqueeze(-1).transpose(0, 2)
distMat = -torch_cos(mat1_unsqueeze, mat2_unsqueeze) * 0.5 + 0.5
# torch_cos = torch.nn.CosineSimilarity(dim=1)
# mat1_unsqueeze = mat1.unsqueeze(-1)
# mat2_unsqueeze = mat2.unsqueeze(-1).transpose(0, 2)
distMat = -_cosine_distance_backend(mat1, mat2) * 0.5 + 0.5
else:
distMat = (ot.dist(mat1, mat2, metric="cosine")) * 0.5
distMat = (-ot.dist(mat1, mat2, metric="cosine") + 1) * 0.5 + 0.5
return distMat


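
As a plain-NumPy reference for the new _cosine_distance_backend (which dispatches through the POT backend in the diff above), a sketch of the same row-normalized dot product, followed by the affine map that the torch branch of _cos_similarity uses to turn similarities in [-1, 1] into a dissimilarity in [0, 1]; the random inputs are illustrative only:

import numpy as np

def cosine_similarity_np(X, Y, eps=1e-8):
    # Row-normalize, then take the pairwise dot product: an (N, M) matrix of cosines.
    X = X / np.maximum(np.linalg.norm(X, axis=1, keepdims=True), eps)
    Y = Y / np.maximum(np.linalg.norm(Y, axis=1, keepdims=True), eps)
    return X @ Y.T

rng = np.random.default_rng(0)
mat1, mat2 = rng.normal(size=(5, 8)), rng.normal(size=(7, 8))
sim = cosine_similarity_np(mat1, mat2)

# Map similarity 1 -> 0 and similarity -1 -> 1, as in distMat = -sim * 0.5 + 0.5 above.
distMat = -sim * 0.5 + 0.5
print(distMat.shape, float(distMat.min()), float(distMat.max()))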
53 changes: 49 additions & 4 deletions spateo/alignment/methods/utils.py
@@ -712,16 +712,61 @@ def get_optimal_R(
###############################


def _cosine_distance_backend(
X: Union[np.ndarray, torch.Tensor],
Y: Union[np.ndarray, torch.Tensor],
eps: float = 1e-8,
) -> Union[np.ndarray, torch.Tensor]:
"""
Compute the pairwise cosine similarity between all pairs of samples in matrices X and Y.
Parameters
----------
X : np.ndarray or torch.Tensor
Matrix with shape (N, D), where each row represents a sample.
Y : np.ndarray or torch.Tensor
Matrix with shape (M, D), where each row represents a sample.
eps : float, optional
A small value to avoid division by zero. Default is 1e-8.
Returns
-------
np.ndarray or torch.Tensor
Pairwise cosine similarity matrix with shape (N, M).
Raises
------
AssertionError
If the number of features in X and Y do not match.
"""

assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."

# Get the appropriate backend (either NumPy or PyTorch)
nx = ot.backend.get_backend(X, Y)

# Normalize rows to unit vectors
X_norm = nx.sqrt(nx.sum(X**2, axis=1, keepdims=True))
Y_norm = nx.sqrt(nx.sum(Y**2, axis=1, keepdims=True))
X = X / nx.maximum(X_norm, eps)
Y = Y / nx.maximum(Y_norm, eps)

# Compute cosine similarity
D = nx.dot(X, Y.T)

return D


def _cos_similarity(
mat1: Union[np.ndarray, torch.Tensor],
mat2: Union[np.ndarray, torch.Tensor],
):
nx = ot.backend.get_backend(mat1, mat2)
if nx_torch(nx):
torch_cos = torch.nn.CosineSimilarity(dim=1)
mat1_unsqueeze = mat1.unsqueeze(-1)
mat2_unsqueeze = mat2.unsqueeze(-1).transpose(0, 2)
distMat = torch_cos(mat1_unsqueeze, mat2_unsqueeze) * 0.5 + 0.5
# torch_cos = torch.nn.CosineSimilarity(dim=1)
# mat1_unsqueeze = mat1.unsqueeze(-1)
# mat2_unsqueeze = mat2.unsqueeze(-1).transpose(0, 2)
distMat = -_cosine_distance_backend(mat1, mat2) * 0.5 + 0.5
else:
distMat = (-ot.dist(mat1, mat2, metric="cosine") + 1) * 0.5 + 0.5
return distMat
