Skip to content

Commit

Permalink
support sparse matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
RonZeira committed Jul 31, 2021
1 parent cf2abe5 commit fd05408
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 9 deletions.
21 changes: 21 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
paste_output/
.DS_Store
AD.ipynb
BC.ipynb
cortex.ipynb
DH.ipynb
sample_data/adpolb*
sample_data/dh*
sample_data/DLPFC
sample_data/151507*
sample_data/Developmental*
sample_data/expr*
sample_data/RA-and-SP*
.ipynb_checkpoints/
!Tutorial-checkpoint.ipynp
ST_sim.ipynb
src/__pycache__/visualization.cpython-37.pyc
__pycache__/paste.cpython-37.pyc
src/helper_ron.py
src/PASTE_ron.py
cortex_ron.ipynb
2 changes: 1 addition & 1 deletion Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"source": [
"import math\n",
"import time\n",
"import pandas as pds\n",
"import pandas as pd\n",
"import numpy as np\n",
"import scanpy as sc\n",
"import seaborn as sns\n",
Expand Down
Binary file modified src/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
15 changes: 8 additions & 7 deletions src/paste/PASTE.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import ot
from sklearn.decomposition import NMF
from scipy.spatial import distance_matrix
import scipy
from numpy import linalg as LA
from .helper import kl_divergence, intersect
from .helper import kl_divergence, intersect, to_dense_array

def pairwise_align(sliceA, sliceB, alpha = 0.1, G_init = None, a_distribution = None, b_distribution = None, norm = False, numItermax = 200, return_obj = False, verbose = False, **kwargs):
"""
Expand Down Expand Up @@ -33,8 +34,8 @@ def pairwise_align(sliceA, sliceB, alpha = 0.1, G_init = None, a_distribution =

D_A = distance_matrix(sliceA.obsm['spatial'], sliceA.obsm['spatial'])
D_B = distance_matrix(sliceB.obsm['spatial'], sliceB.obsm['spatial'])
s_A = sliceA.X + 0.01
s_B = sliceB.X + 0.01
s_A = to_dense_array(sliceA.X) + 0.01
s_B = to_dense_array(sliceB.X) + 0.01
M = kl_divergence(s_A, s_B)

if a_distribution is None:
Expand All @@ -48,8 +49,8 @@ def pairwise_align(sliceA, sliceB, alpha = 0.1, G_init = None, a_distribution =
b = b_distribution

if norm:
D1 /= D1[D1>0].min().min()
D2 /= D2[D2>0].min().min()
D_A /= D_A[D_A>0].min().min()
D_B /= D_B[D_B>0].min().min()

if G_init is None:
pi, logw = ot.gromov.fused_gromov_wasserstein(M, D_A, D_B, a, b, loss_fun='square_loss', alpha= alpha, log=True, numItermax=numItermax,verbose=verbose)
Expand Down Expand Up @@ -99,7 +100,7 @@ def center_align(A, slices, lmbda, alpha = 0.1, n_components = 15, threshold = 0
W = model.fit_transform(A.X)
else:
pis = pis_init
W = model.fit_transform(A.shape[0]*sum([lmbda[i]*np.dot(pis[i], slices[i].X) for i in range(len(slices))]))
W = model.fit_transform(A.shape[0]*sum([lmbda[i]*np.dot(pis[i], to_dense_array(slices[i].X)) for i in range(len(slices))]))
H = model.components_
center_coordinates = A.obsm['spatial']

Expand Down Expand Up @@ -146,7 +147,7 @@ def center_ot(W, H, slices, center_coordinates, common_genes, alpha, norm = Fals
def center_NMF(W, H, slices, pis, lmbda, n_components, random_seed, verbose = False):
print('Solving Center Mapping NMF Problem:')
n = W.shape[0]
B = n*sum([lmbda[i]*np.dot(pis[i], slices[i].X) for i in range(len(slices))])
B = n*sum([lmbda[i]*np.dot(pis[i], to_dense_array(slices[i].X)) for i in range(len(slices))])
model = NMF(n_components=n_components, solver = 'mu', beta_loss = 'kullback-leibler', init='random', random_state = random_seed, verbose = verbose)
W_new = model.fit_transform(B)
H_new = model.components_
Expand Down
Binary file modified src/paste/__pycache__/PASTE.cpython-38.pyc
Binary file not shown.
Binary file modified src/paste/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
Binary file modified src/paste/__pycache__/helper.cpython-38.pyc
Binary file not shown.
Binary file modified src/paste/__pycache__/visualization.cpython-38.pyc
Binary file not shown.
4 changes: 3 additions & 1 deletion src/paste/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,6 @@ def match_spots_using_spatial_heuristic(X,Y,use_ot=True):
pi[row_ind, col_ind] = 1/max(n1,n2)
if n1<n2: pi[:, [(j not in col_ind) for j in range(n2)]] = 1/(n1*n2)
elif n2<n1: pi[[(i not in row_ind) for i in range(n1)], :] = 1/(n1*n2)
return pi
return pi

to_dense_array = lambda X: np.array(X.todense()) if isinstance(X,scipy.sparse.csr.spmatrix) else X

0 comments on commit fd05408

Please sign in to comment.