Skip to content

Commit

Permalink
Merge pull request #189 from Sichao25/refactor
Browse files Browse the repository at this point in the history
Reformat
  • Loading branch information
Xiaojieqiu authored Jun 30, 2023
2 parents cb9ec31 + 332db9d commit d9c3b67
Show file tree
Hide file tree
Showing 25 changed files with 0 additions and 40 deletions.
1 change: 0 additions & 1 deletion spateo/alignment/deformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def grid_deformation(
dtype: str = "float64",
device: str = "cpu",
):

# Check the number of lines
grid_num = np.asarray([20, 20]) if grid_num is None else grid_num

Expand Down
1 change: 0 additions & 1 deletion spateo/alignment/methods/paste.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ def paste_pairwise_align(


def center_NMF(n_components, random_seed, dissimilarity="kl"):

if dissimilarity.lower() == "euclidean" or dissimilarity.lower() == "euc":
model = NMF(n_components=n_components, init="random", random_state=random_seed)
else:
Expand Down
1 change: 0 additions & 1 deletion spateo/alignment/methods/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,6 @@ def PCA_recover(
V_new_basis: Union[np.ndarray, torch.Tensor],
mean_data_mat: Union[np.ndarray, torch.Tensor],
) -> Union[np.ndarray, torch.Tensor]:

nx = ot.backend.get_backend(projected_data)
return nx.einsum("ij,jk->ik", projected_data, V_new_basis.t()) + mean_data_mat

Expand Down
1 change: 0 additions & 1 deletion spateo/alignment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def get_optimal_mapping_relationship(
X_max_index = np.argwhere((pi.T == pi.T.max(axis=0)).T)
Y_max_index = np.argwhere(pi == pi.max(axis=0))
if not keep_all:

values, counts = np.unique(X_max_index[:, 0], return_counts=True)
x_index_unique, x_index_repeat = values[counts == 1], values[counts != 1]
X_max_index_unique = X_max_index[np.isin(X_max_index[:, 0], x_index_unique)]
Expand Down
1 change: 0 additions & 1 deletion spateo/plotting/static/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def multi_slices(
save_kwargs: Optional[dict] = None,
**kwargs,
):

# Check slices object.
if isinstance(slices, list):
adatas = [s.copy() for s in slices]
Expand Down
1 change: 0 additions & 1 deletion spateo/plotting/static/dotplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,6 @@ def __init__(
norm: Optional[Normalize] = None,
**kwargs,
):

# Default plotting parameters:
config_spateo_rcParams()
set_pub_style()
Expand Down
3 changes: 0 additions & 3 deletions spateo/plotting/static/interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,7 @@ def plot_connections(

for label_1 in range(spatial_connections.shape[0]):
for label_2 in range(spatial_connections.shape[1]):

if label_1 <= label_2:

for triangle in [left_triangle, right_triangle]:
center = np.array((label_1, label_2))[np.newaxis, :]
scale_factor = spatial_connections[label_1, label_2] / spatial_connections_max
Expand Down Expand Up @@ -307,7 +305,6 @@ def plot_connections(

for label_1 in range(expr_connections.shape[0]):
for label_2 in range(expr_connections.shape[1]):

if label_1 <= label_2:
for triangle in [left_triangle, right_triangle]:
center = np.array((label_1, label_2))[np.newaxis, :]
Expand Down
1 change: 0 additions & 1 deletion spateo/plotting/static/three_d_plot/three_dims_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def wrap_to_plotter(
vertical=True,
)
if not (legend_kwargs is None):

lg_kwargs.update((k, legend_kwargs[k]) for k in lg_kwargs.keys() & legend_kwargs.keys())

add_legend(
Expand Down
2 changes: 0 additions & 2 deletions spateo/plotting/static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,6 @@ def interactive(
colors = matplotlib.colors.rgb2hex(plt.get_cmap(cmap)(0.5))

if points.shape[0] <= width * height // 10:

if hover_data is not None:
tooltip_dict = {}
for col_name in hover_data:
Expand Down Expand Up @@ -1708,7 +1707,6 @@ def save_return_show_fig_utils(
return_all: bool,
return_all_list: Union[List, Tuple, None],
) -> Optional[Tuple]:

from ...configuration import reset_rcParams
from ...tools.utils import update_dict

Expand Down
1 change: 0 additions & 1 deletion spateo/preprocessing/auxseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def _threshold_gradient_image(self):
self.edges = self.edges.astype(float)

def _compute_graph(self):

try:
from dijkstar import Graph
except ImportError:
Expand Down
1 change: 0 additions & 1 deletion spateo/segmentation/simulation_evaluation/allocate_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def shift_cells(cells, labels, max_iter, seed, shift_length=10):
center_shifts = np.random.randint(-shift_length, shift_length + 1, 2 * max_iter + 2).reshape(-1, 2)

while deal_list:

c += 1
one = deal_list.pop(0)
labels_tmp = labels.copy()
Expand Down
1 change: 0 additions & 1 deletion spateo/svg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ def cal_euclidean_distance(
min_dis_cutoff: float = np.inf,
max_dis_cutoff: float = np.inf,
) -> AnnData:

dyn.tl.neighbors(
adata,
X_data=adata.obsm[layer],
Expand Down
1 change: 0 additions & 1 deletion spateo/tdr/models/utilities/label_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def add_model_labels(
labels = np.asarray(labels).flatten()

if not np.issubdtype(labels.dtype, np.number):

cu_arr = np.sort(np.unique(labels), axis=0).astype(object)
raw_labels_hex = labels.copy().astype(object)
raw_labels_alpha = labels.copy().astype(object)
Expand Down
2 changes: 0 additions & 2 deletions spateo/tdr/models/utilities/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def _scale_model_by_distance(
distance: Union[int, float, list, tuple] = 1,
scale_center: Union[list, tuple] = None,
) -> DataSet:

# Check the distance.
distance = distance if isinstance(distance, (tuple, list)) else [distance] * 3
if len(distance) != 3:
Expand All @@ -117,7 +116,6 @@ def _scale_model_by_scale_factor(
scale_factor: Union[int, float, list, tuple] = 1,
scale_center: Union[list, tuple] = None,
) -> DataSet:

# Check the scaling factor.
scale_factor = scale_factor if isinstance(scale_factor, (tuple, list)) else [scale_factor] * 3
if len(scale_factor) != 3:
Expand Down
4 changes: 0 additions & 4 deletions spateo/tdr/widgets/deep_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ def train(
# The optimizers for Neural Nets

if self.input_network_dim < self.data_dim:

self.optimizer = optim.Adam(
list(self.A.parameters()) + list(self.h.parameters()) + list(self.B.parameters()),
lr=data_lr,
Expand Down Expand Up @@ -212,13 +211,11 @@ def train(

# LET'S TRAIN!!
for iter in range(max_iter):

###############################
### MAIN FLOW PASS ###
###############################

if self.data_sampler is not None:

# Set the gradients to zero
self.h.zero_grad()
if self.input_network_dim < self.data_dim:
Expand Down Expand Up @@ -262,7 +259,6 @@ def train(
#########################

if self.input_network_dim < self.data_dim:

# Set the gradients to zero
self.A.zero_grad(), self.B.zero_grad()

Expand Down
7 changes: 0 additions & 7 deletions spateo/tdr/widgets/interpolation_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def __init__(
hidden_layers=1,
activation_function=torch.nn.functional.leaky_relu,
):

super(A, self).__init__() # Call to the super-class is necessary

self.f = activation_function
Expand All @@ -36,7 +35,6 @@ def __init__(
# torch.nn.init.normal_(self.layer3.weight, std=.02)

def forward(self, inp):

out = self.f(self.layer1(inp), negative_slope=0.2)
out = self.f(self.hidden_layers(out), negative_slope=0.2)
out = self.outlayer(out)
Expand All @@ -57,7 +55,6 @@ def __init__(
hidden_layers=3,
activation_function=torch.nn.functional.leaky_relu,
):

super(B, self).__init__() # Call to the super-class is necessary

self.f = activation_function
Expand All @@ -77,7 +74,6 @@ def __init__(
# torch.nn.init.normal_(self.layer3.weight, std=.02)

def forward(self, inp):

out = self.f(self.layer1(inp), negative_slope=0.2)
out = self.f(self.hidden_layers(out), negative_slope=0.2)
out = self.outlayer(out)
Expand Down Expand Up @@ -183,7 +179,6 @@ def __init__(
)

def forward(self, inp):

out = (
self.f(self.first_omega_0 * self.layer1(inp))
if self.sirens
Expand All @@ -197,7 +192,6 @@ def forward(self, inp):

class MainFlow(torch.nn.Module):
def __init__(self, h, A=None, B=None, enforce_positivity=False):

super(MainFlow, self).__init__()

self.A = A
Expand All @@ -206,7 +200,6 @@ def __init__(self, h, A=None, B=None, enforce_positivity=False):
self.enforce_positivity = enforce_positivity

def forward(self, t, x, freeze=None):

x_low = self.A(x) if self.A is not None else x
e_low = self.h.forward(x_low)
e_hat = self.B(e_low) if self.B is not None else e_low
Expand Down
1 change: 0 additions & 1 deletion spateo/tdr/widgets/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def find_model_outline_planes(model) -> dict:


def find_intersection(model, vec, center, plane):

# Normalize the vector
normal = vec / np.linalg.norm(vec)
normal_x, normal_y, normal_z = normal
Expand Down
1 change: 0 additions & 1 deletion spateo/tdr/widgets/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def DDRTree(
# main loop
objs = []
for iter in range(maxIter):

# Kruskal method to find optimal B
distsqMU = csr_matrix(sqdist(Y, Y)).toarray()
stree = minimum_spanning_tree(np.tril(distsqMU)).toarray()
Expand Down
1 change: 0 additions & 1 deletion spateo/tools/ST_regression/generalized_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,6 @@ def __init__(
theta: float = 1.0,
verbose: bool = True,
):

self.logger = lm.get_main_logger()
allowable_dists = ["gaussian", "poisson", "softplus", "neg-binomial", "gamma"]
if distr not in allowable_dists:
Expand Down
1 change: 0 additions & 1 deletion spateo/tools/architype.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def archetypes(
num_clusters: int = 5,
layer: Union[str, None] = None,
) -> np.ndarray:

"""Identify archetypes from the anndata object.
Args:
Expand Down
1 change: 0 additions & 1 deletion spateo/tools/cluster/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def compute_pca_components(

# Whether to save the image of PCA curve and inflection point.
if save_curve_img is not None:

kl.plot_knee()
plt.tight_layout()
plt.savefig(save_curve_img, dpi=100)
Expand Down
2 changes: 0 additions & 2 deletions spateo/tools/cluster_lasso.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def vi_plot(
sub_adata: subset of adata.
"""
if group and group_color:

df = pd.DataFrame()
df["group_ID"] = self.adata.obs_names
df["labels"] = self.adata.obs[group].values
Expand Down Expand Up @@ -96,7 +95,6 @@ def vi_plot(
)

def selection_fn(trace, points, selector):

t.data[0].cells.values = [
df.loc[points.point_inds][col] for col in ["group_ID", "labels", "spatial_0", "spatial_1"]
]
Expand Down
2 changes: 0 additions & 2 deletions spateo/tools/coarse_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def procrustes(
T = np.dot(V, U.T)

if reflection != "best":

# does the current solution use a reflection?
have_reflection = np.linalg.det(T) < 0

Expand All @@ -95,7 +94,6 @@ def procrustes(
traceTA = s.sum()

if scaling:

# optimum scaling of Y
b = traceTA * normX / normY

Expand Down
1 change: 0 additions & 1 deletion spateo/tools/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def row_normalize(
data = graph.data

for start_ptr, end_ptr in zip(graph.indptr[:-1], graph.indptr[1:]):

row_sum = data[start_ptr:end_ptr].sum()

if row_sum != 0:
Expand Down
1 change: 0 additions & 1 deletion spateo/tools/live_wire.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def _threshold_gradient_image(self):
self.edges = self.edges.astype(float)

def _compute_graph(self):

try:
from dijkstar import Graph
except ImportError:
Expand Down

0 comments on commit d9c3b67

Please sign in to comment.