Skip to content

Commit

Permalink
Merge branch 'main' into update
Browse files Browse the repository at this point in the history
  • Loading branch information
Sichao25 authored Jul 23, 2024
2 parents 1d4dec2 + 5b4eb90 commit 2521a7c
Show file tree
Hide file tree
Showing 7 changed files with 416 additions and 50 deletions.
77 changes: 69 additions & 8 deletions spateo/alignment/methods/morpho.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
coarse_rigid_alignment,
empty_cache,
get_optimal_R,
nx_torch,
)


Expand Down Expand Up @@ -80,6 +81,7 @@ def get_P(
SpatialDistMat: Union[np.ndarray, torch.Tensor],
samples_s: Optional[List[float]] = None,
outlier_variance: float = None,
label_dist: Optional[Union[np.ndarray, torch.Tensor]] = None,
) -> Tuple[Any, Any, Any]:
"""Calculating the generating probability matrix P.
Expand Down Expand Up @@ -119,6 +121,7 @@ def get_P(
exp_SpatialMat,
(_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
)
spatial_term1 = spatial_term1 * label_dist if label_dist is not None else spatial_term1
spatial_outlier = _power(nx)((2 * _pi(nx) * sigma2), _data(nx, D / 2, XnAHat)) * (1 - gamma) / (gamma * outlier_s)
spatial_term2 = spatial_outlier + nx.einsum("ij->j", spatial_term1)
spatial_P = spatial_term1 / _unsqueeze(nx)(spatial_term2, 0)
Expand All @@ -128,6 +131,7 @@ def get_P(
_mul(nx)(nx.exp(-SpatialDistMat / (2 * sigma2)), nx.exp(-GeneDistMat / (2 * beta2))),
(_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
)
term1 = term1 * label_dist if label_dist is not None else term1
P = term1 / (_unsqueeze(nx)(nx.einsum("ij->j", term1), 0) + 1e-8)
P = nx.einsum("j,ij->ij", spatial_inlier, P)

Expand Down Expand Up @@ -155,6 +159,7 @@ def get_P_chunk(
outlier_variance: float = None,
chunk_size: int = 1000,
dissimilarity: str = "kl",
label_dist: Optional[Union[np.ndarray, torch.Tensor]] = None,
) -> Union[np.ndarray, torch.Tensor]:
"""Calculating the generating probability matrix P.
Expand Down Expand Up @@ -183,9 +188,9 @@ def get_P_chunk(
# chunk
X_Bs = _chunk(nx, X_B, chunk_num, dim=0)
XnBs = _chunk(nx, XnB, chunk_num, dim=0)

label_dists = _chunk(nx, label_dist, chunk_num, dim=1) if label_dist is not None else [None] * len(XnBs)
Ps = []
for x_Bs, xnBs in zip(X_Bs, XnBs):
for x_Bs, xnBs, l_d in zip(X_Bs, XnBs, label_dists):
SpatialDistMat = cal_dist(XnAHat, xnBs)
GeneDistMat = calc_exp_dissimilarity(X_A=X_A, X_B=x_Bs, dissimilarity=dissimilarity)
if outlier_variance is None:
Expand All @@ -197,6 +202,7 @@ def get_P_chunk(
exp_SpatialMat,
(_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
)
spatial_term1 = spatial_term1 * l_d if l_d is not None else spatial_term1
spatial_outlier = (
_power(nx)((2 * _pi(nx) * sigma2), _data(nx, D / 2, XnAHat)) * (1 - gamma) / (gamma * outlier_s)
)
Expand All @@ -206,6 +212,7 @@ def get_P_chunk(
_mul(nx)(nx.exp(-SpatialDistMat / (2 * sigma2)), nx.exp(-GeneDistMat / (2 * beta2))),
(_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
)
term1 = term1 * l_d if l_d is not None else term1
P = term1 / (_unsqueeze(nx)(nx.einsum("ij->j", term1), 0) + 1e-8)
P = nx.einsum("j,ij->ij", spatial_inlier, P)
Ps.append(P)
Expand Down Expand Up @@ -239,8 +246,15 @@ def BA_align(
SVI_mode: bool = True,
batch_size: int = 1000,
partial_robust_level: float = 25,
use_rep: Optional[str] = None,
use_label_prior: bool = False,
label_key: Optional[str] = "cluster",
label_transfer_prior: Optional[dict] = None,
beta2: Optional[float] = None,
beta2_end: Optional[float] = None,
sigma2_end: Optional[float] = None,
) -> Tuple[Optional[Tuple[AnnData, AnnData]], np.ndarray, np.ndarray]:
"""_summary_
"""core function for spateo pairwise alignment
Args:
sampleA: Sample A that acts as reference.
Expand Down Expand Up @@ -293,8 +307,26 @@ def BA_align(
dtype=dtype,
device=device,
verbose=verbose,
use_rep=use_rep,
)

if use_label_prior:
catB = sampleA.obs[label_key].cat.categories.tolist()
catA = sampleB.obs[label_key].cat.categories.tolist()
label_transfer = np.zeros((len(catA), len(catB)), dtype=np.float32)

for j, ca in enumerate(catA):
for k, cb in enumerate(catB):
label_transfer[j, k] = label_transfer_prior[ca][cb]
label_transfer = nx.from_numpy(label_transfer, type_as=type_as)
labelA = nx.from_numpy(np.array(sampleB.obs[label_key].cat.codes.values, dtype=np.int64))
labelB = nx.from_numpy(np.array(sampleA.obs[label_key].cat.codes.values, dtype=np.int64))
if nx_torch(nx):
labelA = labelA.to(type_as.device)
labelB = labelB.to(type_as.device)
else:
label_transfer, labelA, labelB = None, None, None

coordsA, coordsB = spatial_coords[1], spatial_coords[0]
X_A, X_B = exp_matrices[1], exp_matrices[0]
del spatial_coords, exp_matrices
Expand Down Expand Up @@ -350,6 +382,17 @@ def BA_align(
inlier_A = _data(nx, inlier_A, type_as)
inlier_B = _data(nx, inlier_B, type_as)
inlier_P = _data(nx, inlier_P, type_as)
else:
init_R = np.eye(D)
init_t = np.zeros((D,))
inlier_A = np.zeros((4, D))
inlier_B = np.zeros((4, D))
inlier_P = np.ones((4, 1))
init_R = _data(nx, init_R, type_as)
init_t = _data(nx, init_t, type_as)
inlier_A = _data(nx, inlier_A, type_as)
inlier_B = _data(nx, inlier_B, type_as)
inlier_P = _data(nx, inlier_P, type_as)
coarse_alignment = coordsA

# Random select control points
Expand Down Expand Up @@ -389,20 +432,29 @@ def BA_align(
R = _identity(nx, D, type_as)
minGeneDistMat = nx.min(GeneDistMat, 1)
# Automatically determine the value of beta2
beta2 = minGeneDistMat[nx.argsort(minGeneDistMat)[int(GeneDistMat.shape[0] * 0.05)]] / 5
beta2_end = nx.max(minGeneDistMat) / 5
beta2 = (
minGeneDistMat[nx.argsort(minGeneDistMat)[int(GeneDistMat.shape[0] * 0.05)]] / 5
if beta2 is None
else _data(nx, data=beta2, type_as=type_as)
)
beta2_end = nx.max(minGeneDistMat) / 5 if beta2_end is None else _data(nx, data=beta2_end, type_as=type_as)
del minGeneDistMat
if sub_sample:
del sub_X_A, sub_X_B, GeneDistMat
# The value of beta2 becomes progressively larger
beta2 = nx.maximum(beta2, _data(nx, 1e-2, type_as))
beta2_decrease = _power(nx)(beta2_end / beta2, 1 / (50))

print("beta2: {} --> {}".format(beta2, beta2_end))
# Use smaller spatial variance to reduce tails
outlier_variance = 1
max_outlier_variance = partial_robust_level # 20
outlier_variance_decrease = _power(nx)(_data(nx, max_outlier_variance, type_as), 1 / (max_iter / 2))

if use_label_prior:
label_dist = label_transfer[labelA, :][:, labelB]
else:
label_dist = None

if SVI_mode:
SVI_deacy = _data(nx, 10.0, type_as)
# Select a random subset of data
Expand All @@ -417,6 +469,7 @@ def BA_align(
else:
randGeneDistMat = GeneDistMat[:, randIdx] # NA x batch_size
SpatialDistMat = SpatialDistMat[:, randIdx] # NA x batch_size
randlabel_dist = label_dist[:, randIdx] if use_label_prior else None
Sp, Sp_spatial, Sp_sigma2 = 0, 0, 0
SigmaInv = nx.zeros((K, K), type_as=type_as) # K x K
PXB_term = nx.zeros((NA, D), type_as=type_as) # NA x D
Expand Down Expand Up @@ -451,6 +504,7 @@ def BA_align(
GeneDistMat=randGeneDistMat,
SpatialDistMat=SpatialDistMat,
outlier_variance=outlier_variance,
label_dist=randlabel_dist,
)
else:
P, spatial_P, sigma2_P = get_P(
Expand All @@ -464,6 +518,7 @@ def BA_align(
GeneDistMat=GeneDistMat,
SpatialDistMat=SpatialDistMat,
outlier_variance=outlier_variance,
label_dist=label_dist,
)

if iter > 5:
Expand Down Expand Up @@ -529,7 +584,10 @@ def BA_align(
)

# Update R()
lambdaReg = 1e0 * Sp / nx.sum(inlier_P)
if nn_init:
lambdaReg = partial_robust_level * 1e0 * Sp / nx.sum(inlier_P)
else:
lambdaReg = 0
if SVI_mode:
PXA, PVA, PXB = (
_dot(nx)(K_NA, coordsA)[None, :],
Expand Down Expand Up @@ -610,6 +668,7 @@ def BA_align(
else:
randGeneDistMat = GeneDistMat[:, randIdx] # NA x batch_size
SpatialDistMat = cal_dist(XAHat, randcoordsB)
randlabel_dist = label_dist[:, randIdx] if use_label_prior else None
empty_cache(device=device)

# full data
Expand All @@ -619,12 +678,14 @@ def BA_align(
XnB=coordsB,
X_A=X_A,
X_B=X_B,
sigma2=sigma2,
sigma2=sigma2 if sigma2_end is None else _data(nx, sigma2_end, type_as),
beta2=beta2,
alpha=alpha,
gamma=gamma,
Sigma=SigmaDiag,
outlier_variance=outlier_variance,
label_dist=label_dist,
dissimilarity=dissimilarity,
)
# Get optimal Rigid transformation
optimal_RnA, optimal_R, optimal_t = get_optimal_R(
Expand Down
13 changes: 12 additions & 1 deletion spateo/alignment/methods/morpho_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def get_P_sparse(
labelB: Optional[pd.Series] = None,
label_transfer_prior: Optional[dict] = None,
top_k: int = 1024,
dissimilarity: str = "kl",
):
assert XnAHat.shape[1] == XnB.shape[1], "XnAHat and XnB do not have the same number of features."
assert XnAHat.shape[0] == alpha.shape[0], "XnAHat and alpha do not have the same length."
Expand Down Expand Up @@ -118,6 +119,7 @@ def get_P_sparse(
col_mul=(_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
batch_capacity=batch_capacity,
top_k=top_k,
dissimilarity=dissimilarity,
)

K_NA = P.sum(1).to_dense()
Expand Down Expand Up @@ -151,6 +153,7 @@ def BA_align_sparse(
iter_key_added: Optional[str] = "iter_spatial",
vecfld_key_added: Optional[str] = "VecFld_morpho",
layer: str = "X",
use_rep: Optional[str] = None,
dissimilarity: str = "kl",
max_iter: int = 200,
lambdaVF: Union[int, float] = 1e2,
Expand All @@ -173,6 +176,7 @@ def BA_align_sparse(
batch_size: int = 1024,
use_sparse: bool = True,
pre_compute_dist: bool = False,
batch_capacity: int = 1,
) -> Tuple[Optional[Tuple[AnnData, AnnData]], np.ndarray, np.ndarray]:
empty_cache(device=device)
# Preprocessing and extract the spatial and expression information
Expand All @@ -196,6 +200,7 @@ def BA_align_sparse(
dtype=dtype,
device=device,
verbose=verbose,
use_rep=use_rep,
)
coordsA, coordsB = spatial_coords[1], spatial_coords[0]
X_A, X_B = exp_matrices[1], exp_matrices[0]
Expand Down Expand Up @@ -351,6 +356,8 @@ def BA_align_sparse(
Sigma=SigmaDiag,
outlier_variance=outlier_variance,
label_transfer_prior=label_transfer_prior,
dissimilarity=dissimilarity,
batch_capacity=batch_capacity,
)
else:
P, assignment_results = get_P_sparse(
Expand All @@ -367,6 +374,8 @@ def BA_align_sparse(
Sigma=SigmaDiag,
outlier_variance=outlier_variance,
label_transfer_prior=label_transfer_prior,
dissimilarity=dissimilarity,
batch_capacity=batch_capacity,
)

# update temperature
Expand Down Expand Up @@ -538,7 +547,9 @@ def BA_align_sparse(
Sigma=SigmaDiag,
outlier_variance=outlier_variance,
label_transfer_prior=label_transfer_prior,
top_k=32,
top_k=512,
dissimilarity=dissimilarity,
batch_capacity=batch_capacity,
)
# Get optimal Rigid transformation
optimal_RnA, optimal_R, optimal_t = get_optimal_R_sparse(
Expand Down
Loading

0 comments on commit 2521a7c

Please sign in to comment.