From e9aaa785fae76184ed7c90be9dbc3e2806ff5953 Mon Sep 17 00:00:00 2001 From: YifanLu2000 Date: Sun, 7 Jul 2024 04:34:10 +0000 Subject: [PATCH] update Spateo alignment --- spateo/alignment/methods/morpho.py | 56 ++++++++++++++++-- spateo/alignment/methods/morpho_sparse.py | 2 +- .../alignment/methods/morpho_sparse_utils.py | 59 +++++++++++++++++-- spateo/alignment/methods/utils.py | 53 +++++++++++++++-- 4 files changed, 154 insertions(+), 16 deletions(-) diff --git a/spateo/alignment/methods/morpho.py b/spateo/alignment/methods/morpho.py index d3c82283..3c352cd3 100755 --- a/spateo/alignment/methods/morpho.py +++ b/spateo/alignment/methods/morpho.py @@ -36,6 +36,7 @@ coarse_rigid_alignment, empty_cache, get_optimal_R, + nx_torch, ) @@ -80,6 +81,7 @@ def get_P( SpatialDistMat: Union[np.ndarray, torch.Tensor], samples_s: Optional[List[float]] = None, outlier_variance: float = None, + label_dist: Optional[Union[np.ndarray, torch.Tensor]] = None, ) -> Tuple[Any, Any, Any]: """Calculating the generating probability matrix P. @@ -119,6 +121,7 @@ def get_P( exp_SpatialMat, (_mul(nx)(alpha, nx.exp(-Sigma / sigma2))), ) + spatial_term1 = spatial_term1 * label_dist if label_dist is not None else spatial_term1 spatial_outlier = _power(nx)((2 * _pi(nx) * sigma2), _data(nx, D / 2, XnAHat)) * (1 - gamma) / (gamma * outlier_s) spatial_term2 = spatial_outlier + nx.einsum("ij->j", spatial_term1) spatial_P = spatial_term1 / _unsqueeze(nx)(spatial_term2, 0) @@ -128,6 +131,7 @@ def get_P( _mul(nx)(nx.exp(-SpatialDistMat / (2 * sigma2)), nx.exp(-GeneDistMat / (2 * beta2))), (_mul(nx)(alpha, nx.exp(-Sigma / sigma2))), ) + term1 = term1 * label_dist if label_dist is not None else term1 P = term1 / (_unsqueeze(nx)(nx.einsum("ij->j", term1), 0) + 1e-8) P = nx.einsum("j,ij->ij", spatial_inlier, P) @@ -155,6 +159,7 @@ def get_P_chunk( outlier_variance: float = None, chunk_size: int = 1000, dissimilarity: str = "kl", + label_dist: Optional[Union[np.ndarray, torch.Tensor]] = None, ) -> Union[np.ndarray, torch.Tensor]: """Calculating the generating probability matrix P. @@ -183,9 +188,9 @@ def get_P_chunk( # chunk X_Bs = _chunk(nx, X_B, chunk_num, dim=0) XnBs = _chunk(nx, XnB, chunk_num, dim=0) - + label_dists = _chunk(nx, label_dist, chunk_num, dim=1) if label_dist is not None else [None] * len(XnBs) Ps = [] - for x_Bs, xnBs in zip(X_Bs, XnBs): + for x_Bs, xnBs, l_d in zip(X_Bs, XnBs, label_dists): SpatialDistMat = cal_dist(XnAHat, xnBs) GeneDistMat = calc_exp_dissimilarity(X_A=X_A, X_B=x_Bs, dissimilarity=dissimilarity) if outlier_variance is None: @@ -197,6 +202,7 @@ def get_P_chunk( exp_SpatialMat, (_mul(nx)(alpha, nx.exp(-Sigma / sigma2))), ) + spatial_term1 = spatial_term1 * l_d if l_d is not None else spatial_term1 spatial_outlier = ( _power(nx)((2 * _pi(nx) * sigma2), _data(nx, D / 2, XnAHat)) * (1 - gamma) / (gamma * outlier_s) ) @@ -206,6 +212,7 @@ def get_P_chunk( _mul(nx)(nx.exp(-SpatialDistMat / (2 * sigma2)), nx.exp(-GeneDistMat / (2 * beta2))), (_mul(nx)(alpha, nx.exp(-Sigma / sigma2))), ) + term1 = term1 * l_d if l_d is not None else term1 P = term1 / (_unsqueeze(nx)(nx.einsum("ij->j", term1), 0) + 1e-8) P = nx.einsum("j,ij->ij", spatial_inlier, P) Ps.append(P) @@ -239,6 +246,12 @@ def BA_align( SVI_mode: bool = True, batch_size: int = 1000, partial_robust_level: float = 25, + use_rep: Optional[str] = None, + use_label_prior: bool = False, + label_key: Optional[str] = "cluster", + label_transfer_prior: Optional[dict] = None, + beta2: Optional[float] = None, + beta2_end: Optional[float] = None, ) -> Tuple[Optional[Tuple[AnnData, AnnData]], np.ndarray, np.ndarray]: """core function for spateo pairwise alignment @@ -293,8 +306,26 @@ def BA_align( dtype=dtype, device=device, verbose=verbose, + use_rep=use_rep, ) + if use_label_prior: + catB = sampleA.obs[label_key].cat.categories.tolist() + catA = sampleB.obs[label_key].cat.categories.tolist() + label_transfer = np.zeros((len(catA), len(catB)), dtype=np.float32) + + for j, ca in enumerate(catA): + for k, cb in enumerate(catB): + label_transfer[j, k] = label_transfer_prior[ca][cb] + label_transfer = nx.from_numpy(label_transfer, type_as=type_as) + labelA = nx.from_numpy(np.array(sampleB.obs[label_key].cat.codes.values, dtype=np.int64)) + labelB = nx.from_numpy(np.array(sampleA.obs[label_key].cat.codes.values, dtype=np.int64)) + if nx_torch(nx): + labelA = labelA.to(type_as.device) + labelB = labelB.to(type_as.device) + else: + label_transfer, labelA, labelB = None, None, None + coordsA, coordsB = spatial_coords[1], spatial_coords[0] X_A, X_B = exp_matrices[1], exp_matrices[0] del spatial_coords, exp_matrices @@ -400,20 +431,29 @@ def BA_align( R = _identity(nx, D, type_as) minGeneDistMat = nx.min(GeneDistMat, 1) # Automatically determine the value of beta2 - beta2 = minGeneDistMat[nx.argsort(minGeneDistMat)[int(GeneDistMat.shape[0] * 0.05)]] / 5 - beta2_end = nx.max(minGeneDistMat) / 5 + beta2 = ( + minGeneDistMat[nx.argsort(minGeneDistMat)[int(GeneDistMat.shape[0] * 0.05)]] / 5 + if beta2 is None + else _data(nx, data=beta2, type_as=type_as) + ) + beta2_end = nx.max(minGeneDistMat) / 5 if beta2_end is None else _data(nx, data=beta2_end, type_as=type_as) del minGeneDistMat if sub_sample: del sub_X_A, sub_X_B, GeneDistMat # The value of beta2 becomes progressively larger beta2 = nx.maximum(beta2, _data(nx, 1e-2, type_as)) beta2_decrease = _power(nx)(beta2_end / beta2, 1 / (50)) - + print("beta2: {} --> {}".format(beta2, beta2_end)) # Use smaller spatial variance to reduce tails outlier_variance = 1 max_outlier_variance = partial_robust_level # 20 outlier_variance_decrease = _power(nx)(_data(nx, max_outlier_variance, type_as), 1 / (max_iter / 2)) + if use_label_prior: + label_dist = label_transfer[labelA, :][:, labelB] + else: + label_dist = None + if SVI_mode: SVI_deacy = _data(nx, 10.0, type_as) # Select a random subset of data @@ -428,6 +468,7 @@ def BA_align( else: randGeneDistMat = GeneDistMat[:, randIdx] # NA x batch_size SpatialDistMat = SpatialDistMat[:, randIdx] # NA x batch_size + randlabel_dist = label_dist[:, randIdx] if use_label_prior else None Sp, Sp_spatial, Sp_sigma2 = 0, 0, 0 SigmaInv = nx.zeros((K, K), type_as=type_as) # K x K PXB_term = nx.zeros((NA, D), type_as=type_as) # NA x D @@ -462,6 +503,7 @@ def BA_align( GeneDistMat=randGeneDistMat, SpatialDistMat=SpatialDistMat, outlier_variance=outlier_variance, + label_dist=randlabel_dist, ) else: P, spatial_P, sigma2_P = get_P( @@ -475,6 +517,7 @@ def BA_align( GeneDistMat=GeneDistMat, SpatialDistMat=SpatialDistMat, outlier_variance=outlier_variance, + label_dist=label_dist, ) if iter > 5: @@ -624,6 +667,7 @@ def BA_align( else: randGeneDistMat = GeneDistMat[:, randIdx] # NA x batch_size SpatialDistMat = cal_dist(XAHat, randcoordsB) + randlabel_dist = label_dist[:, randIdx] if use_label_prior else None empty_cache(device=device) # full data @@ -639,6 +683,8 @@ def BA_align( gamma=gamma, Sigma=SigmaDiag, outlier_variance=outlier_variance, + label_dist=label_dist, + dissimilarity=dissimilarity, ) # Get optimal Rigid transformation optimal_RnA, optimal_R, optimal_t = get_optimal_R( diff --git a/spateo/alignment/methods/morpho_sparse.py b/spateo/alignment/methods/morpho_sparse.py index b5454df8..22467ee6 100644 --- a/spateo/alignment/methods/morpho_sparse.py +++ b/spateo/alignment/methods/morpho_sparse.py @@ -547,7 +547,7 @@ def BA_align_sparse( Sigma=SigmaDiag, outlier_variance=outlier_variance, label_transfer_prior=label_transfer_prior, - top_k=32, + top_k=512, dissimilarity=dissimilarity, batch_capacity=batch_capacity, ) diff --git a/spateo/alignment/methods/morpho_sparse_utils.py b/spateo/alignment/methods/morpho_sparse_utils.py index cea8a2e3..3625e220 100644 --- a/spateo/alignment/methods/morpho_sparse_utils.py +++ b/spateo/alignment/methods/morpho_sparse_utils.py @@ -29,7 +29,9 @@ def calc_distance( "square_euc", "square_euclidean", "kl", - ], "``metric`` value is wrong. Available ``metric`` are: ``'euc'``, ``'euclidean'``, ``'square_euc'``, ``'square_euclidean'``, and ``'kl'``." + "cos", + "cosine", + ], "``metric`` value is wrong. Available ``metric`` are: ``'euc'``, ``'euclidean'``, ``'square_euc'``, ``'square_euclidean'``, ``'kl'``, ``'cos'``, and ``'cosine'``." if use_sparse: assert sparse_method in [ @@ -327,18 +329,63 @@ def _SparseTensor(nx, row, col, value, sparse_sizes): return coo_array((value, (row, col)), shape=sparse_sizes) +def _cosine_distance_backend( + X: Union[np.ndarray, torch.Tensor], + Y: Union[np.ndarray, torch.Tensor], + eps: float = 1e-8, +) -> Union[np.ndarray, torch.Tensor]: + """ + Compute the pairwise cosine similarity between all pairs of samples in matrices X and Y. + + Parameters + ---------- + X : np.ndarray or torch.Tensor + Matrix with shape (N, D), where each row represents a sample. + Y : np.ndarray or torch.Tensor + Matrix with shape (M, D), where each row represents a sample. + eps : float, optional + A small value to avoid division by zero. Default is 1e-8. + + Returns + ------- + np.ndarray or torch.Tensor + Pairwise cosine similarity matrix with shape (N, M). + + Raises + ------ + AssertionError + If the number of features in X and Y do not match. + """ + + assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features." + + # Get the appropriate backend (either NumPy or PyTorch) + nx = ot.backend.get_backend(X, Y) + + # Normalize rows to unit vectors + X_norm = nx.sqrt(nx.sum(X**2, axis=1, keepdims=True)) + Y_norm = nx.sqrt(nx.sum(Y**2, axis=1, keepdims=True)) + X = X / nx.maximum(X_norm, eps) + Y = Y / nx.maximum(Y_norm, eps) + + # Compute cosine similarity + D = nx.dot(X, Y.T) + + return D + + def _cos_similarity( mat1: Union[np.ndarray, torch.Tensor], mat2: Union[np.ndarray, torch.Tensor], ): nx = ot.backend.get_backend(mat1, mat2) if nx_torch(nx): - torch_cos = torch.nn.CosineSimilarity(dim=1) - mat1_unsqueeze = mat1.unsqueeze(-1) - mat2_unsqueeze = mat2.unsqueeze(-1).transpose(0, 2) - distMat = -torch_cos(mat1_unsqueeze, mat2_unsqueeze) * 0.5 + 0.5 + # torch_cos = torch.nn.CosineSimilarity(dim=1) + # mat1_unsqueeze = mat1.unsqueeze(-1) + # mat2_unsqueeze = mat2.unsqueeze(-1).transpose(0, 2) + distMat = -_cosine_distance_backend(mat1, mat2) * 0.5 + 0.5 else: - distMat = (ot.dist(mat1, mat2, metric="cosine")) * 0.5 + distMat = (-ot.dist(mat1, mat2, metric="cosine") + 1) * 0.5 + 0.5 return distMat diff --git a/spateo/alignment/methods/utils.py b/spateo/alignment/methods/utils.py index 1a0ccfbe..ca3816cd 100755 --- a/spateo/alignment/methods/utils.py +++ b/spateo/alignment/methods/utils.py @@ -712,16 +712,61 @@ def get_optimal_R( ############################### +def _cosine_distance_backend( + X: Union[np.ndarray, torch.Tensor], + Y: Union[np.ndarray, torch.Tensor], + eps: float = 1e-8, +) -> Union[np.ndarray, torch.Tensor]: + """ + Compute the pairwise cosine similarity between all pairs of samples in matrices X and Y. + + Parameters + ---------- + X : np.ndarray or torch.Tensor + Matrix with shape (N, D), where each row represents a sample. + Y : np.ndarray or torch.Tensor + Matrix with shape (M, D), where each row represents a sample. + eps : float, optional + A small value to avoid division by zero. Default is 1e-8. + + Returns + ------- + np.ndarray or torch.Tensor + Pairwise cosine similarity matrix with shape (N, M). + + Raises + ------ + AssertionError + If the number of features in X and Y do not match. + """ + + assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features." + + # Get the appropriate backend (either NumPy or PyTorch) + nx = ot.backend.get_backend(X, Y) + + # Normalize rows to unit vectors + X_norm = nx.sqrt(nx.sum(X**2, axis=1, keepdims=True)) + Y_norm = nx.sqrt(nx.sum(Y**2, axis=1, keepdims=True)) + X = X / nx.maximum(X_norm, eps) + Y = Y / nx.maximum(Y_norm, eps) + + # Compute cosine similarity + D = nx.dot(X, Y.T) + + return D + + def _cos_similarity( mat1: Union[np.ndarray, torch.Tensor], mat2: Union[np.ndarray, torch.Tensor], ): nx = ot.backend.get_backend(mat1, mat2) if nx_torch(nx): - torch_cos = torch.nn.CosineSimilarity(dim=1) - mat1_unsqueeze = mat1.unsqueeze(-1) - mat2_unsqueeze = mat2.unsqueeze(-1).transpose(0, 2) - distMat = torch_cos(mat1_unsqueeze, mat2_unsqueeze) * 0.5 + 0.5 + # torch_cos = torch.nn.CosineSimilarity(dim=1) + # mat1_unsqueeze = mat1.unsqueeze(-1) + # mat2_unsqueeze = mat2.unsqueeze(-1).transpose(0, 2) + distMat = -_cosine_distance_backend(mat1, mat2) * 0.5 + 0.5 else: distMat = (-ot.dist(mat1, mat2, metric="cosine") + 1) * 0.5 + 0.5 return distMat