Commit: SparseDIST update

RRobert92 committed Aug 1, 2023
1 parent 8e86f5e commit 585a8ef

Showing 8 changed files with 115 additions and 81 deletions.
4 changes: 2 additions & 2 deletions tardis_pytorch/dist_pytorch/datasets/dataloader.py
@@ -643,8 +643,8 @@ def __getitem__(self, i: int):
                 cls=cls_idx,
             )
 
-        if self.sparse:
-            graph_idx = [g.to_sparse() for g in graph_idx]
+        # if self.sparse:
+        #     graph_idx = [g.to_sparse() for g in graph_idx]
 
         if self.benchmark:
            # Output file_name, raw_coord, edge_f, node_f, graph, node_idx, node_class
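For reference, `Tensor.to_sparse()` converts a dense graph tensor to COO format. A toy illustration of the conversion this hunk disables (dummy data, not the repository's loader; presumably redundant now that the V4 embedding builds its own sparse indexing):

```python
# Toy illustration of torch.Tensor.to_sparse(), the call commented out above.
import torch

g = torch.eye(4)          # dense 4 x 4 toy adjacency graph
g_sp = g.to_sparse()      # COO tensor: indices (2, nnz) and values (nnz,)
print(g_sp.indices())     # tensor([[0, 1, 2, 3], [0, 1, 2, 3]])
print(g_sp.values())      # tensor([1., 1., 1., 1.])
```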
18 changes: 10 additions & 8 deletions tardis_pytorch/dist_pytorch/sparse_dist.py
@@ -11,7 +11,7 @@
 import torch.nn as nn
 
 import numpy as np
-from tardis_pytorch.dist_pytorch.sparse_model.embedding import SparseEdgeEmbeddingV3
+from tardis_pytorch.dist_pytorch.sparse_model.embedding import SparseEdgeEmbeddingV4
 from tardis_pytorch.dist_pytorch.sparse_model.layers import SparseDistStack
 
 
@@ -33,8 +33,10 @@ def __init__(
         n_out=1,
         edge_dim=128,
         num_layers=6,
+        knn=8,
         coord_embed_sigma=1.0,
         predict=False,
+        device='cpu'
     ):
         """
         Initializes the SparseDIST.
@@ -51,11 +53,12 @@ def __init__(
         self.n_out = n_out
         self.edge_dim = edge_dim
         self.num_layers = num_layers
+        self.knn = knn
         self.edge_sigma = coord_embed_sigma
         self.predict = predict
 
-        self.coord_embed = SparseEdgeEmbeddingV3(
-            n_out=self.edge_dim, sigma=self.edge_sigma
+        self.coord_embed = SparseEdgeEmbeddingV4(
+            n_out=self.edge_dim, sigma=self.edge_sigma, knn=self.knn, _device=device,
         )
 
         self.layers = SparseDistStack(
@@ -80,7 +83,7 @@ def embed_input(self, coords: torch.tensor) -> torch.tensor:
 
         return x, idx
 
-    def forward(self, coords: torch.tensor) -> torch.tensor:
+    def forward(self, coords: torch.tensor, idx=None) -> torch.tensor:
         """
         Forward pass for the SparseDIST.
@@ -94,12 +97,11 @@ def forward(self, coords: torch.tensor) -> torch.tensor:
         edge, idx = self.embed_input(coords=coords)  # List[Indices, Values, Shape]
 
         # Encode throughout the transformer layers
-        edge, idx = self.layers(
-            edge_features=edge, indices=idx
-        )  # List[Indices, Values, Shape]
+        edge = self.layers(edge_features=edge, indices=idx)  # List[Indices, Values, Shape]
 
         # Predict the graph edges
-        edge = self.decoder(edge + edge[:, np.concatenate(idx[2][1]), :])
+        # edge = self.decoder(edge + edge[:, np.concatenate(idx[2][1]), :])
+        edge = self.decoder(edge)
 
         if self.predict:
             edge = torch.sigmoid(edge)
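Taken together, the new forward path is: V4 embedding, then the transformer stack, then a decoder over the edge features, with an optional sigmoid. Below is a hypothetical driver for that path. The constructor arguments come from this diff, but the sigma pair, the point count, and the two-value return are assumptions (V4 applies `torch.linspace(sigma[0], sigma[1], n_out)`, and the trainer hunks later in this commit unpack `edge, indices`):

```python
# Hypothetical usage sketch for the updated SparseDIST; names and defaults are
# taken from this diff, but the sigma range and input size are assumptions.
import torch
from tardis_pytorch.dist_pytorch.sparse_dist import SparseDIST

model = SparseDIST(
    n_out=1,
    edge_dim=128,
    num_layers=6,
    knn=8,                          # new in this commit
    coord_embed_sigma=(0.2, 10.0),  # assumed (min, max) range for the kernel bank
    predict=True,                   # apply sigmoid to the edge logits
    device="cpu",                   # new in this commit
)

coords = torch.rand(512, 3)      # N x 3 point cloud
edge, indices = model(coords)    # edge scores plus the sparse index bookkeeping
```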
74 changes: 74 additions & 0 deletions tardis_pytorch/dist_pytorch/sparse_model/embedding.py
@@ -14,6 +14,9 @@
 import numpy as np
 from typing import Union
 
+from scipy.spatial import KDTree
+import numpy as np
+
 
 class SparseEdgeEmbedding(nn.Module):
     """
@@ -317,3 +320,74 @@ def forward(self, input_coord: torch.tensor) -> Union[torch.tensor, list]:
             k_dist_range[:, id_] = torch.exp(-(_dist**2) / (i**2 * 2))
 
         return k_dist_range.unsqueeze(0).to(device), [indices, row_indices, col_indices, _shape]
+
+
+class SparseEdgeEmbeddingV4(nn.Module):
+    """
+    Module for Sparse Edge Embedding.
+
+    This class is responsible for computing a sparse adjacency matrix
+    with edge weights computed using a Gaussian kernel function over
+    the distances between input coordinates.
+    """
+
+    def __init__(self, n_out: int, sigma: list, knn: int, _device):
+        """
+        Initializes the SparseEdgeEmbedding.
+
+        Args:
+            n_out (int): The number of output channels.
+            sigma (list): The range of sigma values for the Gaussian kernel.
+            knn (int): The number of nearest neighbors retrieved per point.
+            _device: The device on which the output tensor is returned.
+        """
+        super().__init__()
+        self._range = torch.linspace(sigma[0], sigma[1], n_out)
+
+        self.knn = knn
+        self.n_out = n_out
+        self.sigma = sigma
+        self._device = _device
+
+    def forward(self, input_coord: np.ndarray) -> Union[torch.tensor, list]:
+        with torch.no_grad():
+            # Get all ij elements from rows and columns
+            input_coord = input_coord.cpu().detach().numpy()
+            tree = KDTree(input_coord)
+            distances, indices = tree.query(input_coord, self.knn)
+
+            n = len(input_coord)
+            M = distances.flatten()
+
+            all_ij_id = np.array((np.repeat(np.arange(n), self.knn), indices.flatten())).T
+
+            # Row-wise M[ij] index
+            row_idx = np.repeat(
+                np.arange(0, len(M)).reshape(n, self.knn), self.knn, axis=0
+            ).reshape(len(M), self.knn) + 1
+            row_idx = np.vstack((np.repeat(0, self.knn), row_idx))
+
+            # Column-wise M[ij] index
+            col_idx = np.array(
+                [
+                    np.pad(c, (0, self.knn - len(c)))
+                    if len(c) <= self.knn
+                    else c[np.argsort(M[c - 1])[: self.knn]]
+                    for c in [
+                        np.where(all_ij_id[:, 1] == i)[0] + 1
+                        for i in range(n)
+                    ]
+                ]
+            )
+            col_idx = np.repeat(col_idx, self.knn, axis=0)
+            col_idx = np.vstack((np.repeat(0, self.knn), col_idx))
+
+            M = torch.from_numpy(np.pad(M, (1, 0)))
+
+            # Prepare tensor for storing the range of distances
+            k_dist_range = torch.zeros((len(M), len(self._range)))
+
+            # Apply Gaussian kernel function to the top-k distances
+            for id_, i in enumerate(self._range):
+                k_dist_range[:, id_] = torch.exp(-(M**2) / (i**2 * 2))
+            k_dist_range[0, :] = 0
+
+            # Replace any NaN values with zero
+            isnan = torch.isnan(k_dist_range)
+            k_dist_range = torch.where(
+                isnan, torch.zeros_like(k_dist_range), k_dist_range
+            )
+
+        return k_dist_range.to(self._device), [
+            row_idx.astype(np.int32),
+            col_idx.astype(np.int32),
+            (n, n),
+            all_ij_id,
+        ]
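Stripped of the index bookkeeping, V4 does two things: query a KDTree for each point's `knn` nearest neighbors, then expand every neighbor distance through a bank of `n_out` Gaussian kernels, exp(-d^2 / (2 * sigma^2)), with sigma swept over `torch.linspace(sigma[0], sigma[1], n_out)`. A standalone sketch of just that kernel expansion, with illustrative names and without the padding row or the row/column index tables:

```python
# Standalone sketch of the V4 idea: k-NN distances expanded through a Gaussian
# kernel bank. Function name and defaults are illustrative, not the committed API.
import numpy as np
import torch
from scipy.spatial import KDTree

def gaussian_knn_features(coords: np.ndarray, knn: int = 8,
                          sigma=(0.2, 10.0), n_out: int = 128):
    tree = KDTree(coords)
    dist, nbr = tree.query(coords, knn)           # (N, knn) distances / neighbor ids
    d = torch.from_numpy(dist.flatten()).float()  # M = N * knn edge distances
    sigmas = torch.linspace(sigma[0], sigma[1], n_out)
    # One Gaussian channel per sigma: exp(-d^2 / (2 * sigma^2))
    feats = torch.exp(-(d[:, None] ** 2) / (2 * sigmas[None, :] ** 2))
    pairs = np.stack([np.repeat(np.arange(len(coords)), knn), nbr.flatten()], axis=1)
    return feats, pairs                           # (M, n_out) features, (M, 2) edge list

feats, pairs = gaussian_knn_features(np.random.rand(100, 3))
print(feats.shape, pairs.shape)  # torch.Size([800, 128]) (800, 2)
```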
19 changes: 10 additions & 9 deletions tardis_pytorch/dist_pytorch/sparse_model/layers.py
@@ -27,7 +27,7 @@ class SparseDistStack(nn.Module):
     is the output of the SparseDistStack.
     """
 
-    def __init__(self, pairs_dim: int, num_layers=1, ff_factor=4, knn=12):
+    def __init__(self, pairs_dim: int, num_layers=1, ff_factor=4):
         """
         Initializes the SparseDistStack.
@@ -73,9 +73,9 @@ def forward(self, edge_features: torch.tensor, indices: list) -> Union[torch.tensor, list]:
             torch.sparse_coo_tensor: A sparse coordinate tensor representing the output from the final layer in the stack.
         """
         for layer in self.layers:
-            edge_features, _ = layer(h_pairs=edge_features, indices=indices)
+            edge_features = layer(h_pairs=edge_features, indices=indices)
 
-        return edge_features, indices
+        return edge_features
 
 
 class SparseDistLayer(nn.Module):
@@ -140,11 +140,11 @@ def update_edges(self, h_pairs: torch.tensor, indices: list) -> Union[torch.tensor, list]:
         # ToDo Convert node features to edge shape
 
         # Update edge features
-        row, _ = self.row_update(x=h_pairs, indices=indices)
-        col, _ = self.col_update(x=h_pairs, indices=indices)
-        h_pairs = (h_pairs + row + col)
+        row = self.row_update(x=h_pairs, indices=indices)
+        # col = self.col_update(x=h_pairs, indices=indices)
+        h_pairs = h_pairs + row  # + col
 
-        return h_pairs + self.pair_ffn(x=h_pairs), indices
+        return h_pairs + self.pair_ffn(x=h_pairs)
 
     def forward(self, h_pairs: torch.tensor, indices: list) -> Union[torch.tensor, list]:
         """
@@ -160,5 +160,6 @@
         # ToDo Update node features and convert to edge shape
 
         # Update edge features
-        h_pairs, idx = self.update_edges(h_pairs=h_pairs, indices=indices)
-        return h_pairs, idx
+        h_pairs = self.update_edges(h_pairs=h_pairs, indices=indices)
+
+        return h_pairs
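After this change a layer reduces to a plain residual pipeline: `h_pairs + row_update(h_pairs)`, then `+ pair_ffn(...)`, with the column branch disabled and the per-layer `indices` pass-through removed. A toy sketch of that control flow, with `nn.Linear` standing in for the committed modules:

```python
# Toy sketch of the simplified residual flow left in SparseDistLayer; Linear
# layers stand in for the committed row_update / pair_ffn modules.
import torch
import torch.nn as nn

class ToyEdgeLayer(nn.Module):
    def __init__(self, dim: int = 128, ff_factor: int = 4):
        super().__init__()
        self.row_update = nn.Linear(dim, dim)  # stand-in for the triangular row update
        self.pair_ffn = nn.Sequential(
            nn.Linear(dim, dim * ff_factor), nn.GELU(), nn.Linear(dim * ff_factor, dim)
        )

    def forward(self, h_pairs: torch.Tensor) -> torch.Tensor:
        h_pairs = h_pairs + self.row_update(h_pairs)  # column branch disabled here
        return h_pairs + self.pair_ffn(h_pairs)       # residual feed-forward

h = torch.rand(801, 128)  # (M + 1) padded edges x edge_dim
for layer in nn.ModuleList([ToyEdgeLayer() for _ in range(6)]):
    h = layer(h)
print(h.shape)  # torch.Size([801, 128])
```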
55 changes: 6 additions & 49 deletions tardis_pytorch/dist_pytorch/sparse_model/modules.py
@@ -69,7 +69,7 @@ def _reset_parameters(self):
         nn.init.constant_(self.linear_o.weight, 0.0)
         nn.init.constant_(self.linear_o.bias, 0.0)
 
-    def forward(self, x: torch.tensor, indices: list) -> Union[torch.tensor, list]:
+    def forward(self, x: torch.tensor, indices: list) -> torch.tensor:
         """
         Forward pass for SparseTriangularUpdate.
@@ -87,57 +87,14 @@ def forward(self, x: torch.tensor, indices: list) -> Union[torch.tensor, list]:
         b = torch.sigmoid(self.gate_b(x)) * self.linear_b(x)
 
         # Apply triangular multiplication update
-        k = torch.zeros_like(a)
+        # k = torch.zeros_like(a)
 
         if self.axis == 1:  # Row-wise update
-            i_element = [indices[1][1][i].copy() + [i] for i in range(len(indices[1][1]))]
-            i_th = [list(compress(i_el, [True if k in indices[1][1][j] else False for k in i_el]))
-                    for i_el in i_element for j in i_el]
-            j_th = [list(compress(indices[1][1][j], [True if k in i_el else False for k in indices[1][1][j]]))
-                    for i_el in i_element for j in i_el]
-            i_element = [i for s in i_element for i in s]
-
-            shape_ = [len(x) for x in j_th]
-            cumulative_sizes = np.cumsum([0] + shape_)
-
-            df_ij = a[:, torch.from_numpy(np.concatenate(i_th)), :] * b[:, torch.from_numpy(np.concatenate(j_th)), :]
-            df_ij = [torch.sum(df_ij[:, cumulative_sizes[0]:cumulative_sizes[1]], dim=1) for i in range(len(shape_))]
-
-            k[:, i_element, :] = torch.cat(df_ij)
+            k = torch.einsum('ik,ijk->ik', a, b[indices[0]])
         else:
-            i_element = [indices[2][1][i].copy() + [i] for i in range(len(indices[2][1]))]
-            i_th = [list(compress(i_el, [True if k in indices[1][1][j] else False for k in i_el]))
-                    for i_el in i_element for j in i_el]
-            j_th = [list(compress(indices[2][1][j], [True if k in i_el else False for k in indices[2][1][j]]))
-                    for i_el in i_element for j in i_el]
-            i_element = [i for s in i_element for i in s]
-
-            shape_ = [len(x) for x in j_th]
-            cumulative_sizes = np.cumsum([0] + shape_)
-
-            df_ij = a[:, torch.from_numpy(np.concatenate(i_th)), :] * b[:, torch.from_numpy(np.concatenate(j_th)), :]
-            df_ij = [torch.sum(df_ij[:, cumulative_sizes[0]:cumulative_sizes[1]], dim=1) for i in range(len(shape_))]
-
-            k[:, i_element, :] = torch.cat(df_ij)
-
-        # if self.axis == 1:  # Row-wise
-        #     idx[0] = idx[0].reshape(org_shape[1], self.k, 2)
-        #     a = a.reshape(1, org_shape[1], self.k, ch)
-        #     idx[0] = idx[0].reshape(org_shape[1], self.k, 2)
-        #     b = b.reshape(1, org_shape[1], self.k, ch)
-
-        #     k = a.repeat_interleave(self.k, dim=1) * b[0, idx[0][..., 1].flatten().long(), :].unsqueeze(0)
-        #     k = torch.sum(k, dim=2)
-        #     k = k.reshape(1, M_len, 32)
-
-        #     idx = [idx[0].reshape(org_shape[1] * self.k, 2), idx[1]]
-        # else:  # Column-wise
-        #     for _id in range(mm_len):
-        #         k[0, i_id, :] = sparse_mm(a[1][:, :mm_len, :][:, i_id, :],
-        #                                   b[1][:, :mm_len, :][:, i_id, :],
-        #                                   a[0][:mm_len, :][i_id, :])
-
-        return torch.sigmoid(self.gate_o(x)) * self.linear_o(self.norm_o(k)), indices
+            k = torch.einsum('ik,ijk->ik', a, b[indices[1]])
 
+        return torch.sigmoid(self.gate_o(x)) * self.linear_o(self.norm_o(a))
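The einsum replaces roughly forty lines of Python-level index juggling. With `a` of shape `(M+1, C)` and `b[indices[0]]` of shape `(M+1, knn, C)` (edges gathered through the V4 row index table), `'ik,ijk->ik'` computes, for each edge i and channel k, the sum over the knn gathered edges of `a[i, k] * b[idx[i, j], k]`. A dummy-shape sketch of that aggregation:

```python
# Dummy-shape sketch of the einsum aggregation; M, knn, C and the random index
# table are placeholders for the V4 row/column tables.
import torch

M, knn, C = 800, 8, 32
a = torch.rand(M + 1, C)                     # gated "left" edge features
b = torch.rand(M + 1, C)                     # gated "right" edge features
idx = torch.randint(0, M + 1, (M + 1, knn))  # neighbor edge ids per edge

k = torch.einsum('ik,ijk->ik', a, b[idx])    # (M+1, C): summed over the knn axis
assert torch.allclose(k, (a.unsqueeze(1) * b[idx]).sum(dim=1))
```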


def sparse_to_dense(x: list, numpy=False) -> np.ndarray:
Expand Down
5 changes: 3 additions & 2 deletions tardis_pytorch/dist_pytorch/train.py
@@ -112,6 +112,7 @@ def train_dist(
             knn=model_structure["num_knn"],
             coord_embed_sigma=model_structure["coord_embed_sigma"],
             predict=True,
+            device=device
         )
     elif model_structure["dist_type"] == "semantic":
         model = CDIST(
@@ -173,10 +174,10 @@ def train_dist(
 
     """Build training optimizer"""
     if lr_scheduler:
-        optimizer = optim.Adam(params=model.parameters(), betas=(0.9, 0.98), eps=1e-9)
+        optimizer = optim.Adam(params=model.parameters(), betas=(0.9, 0.999), eps=1e-9)
     else:
         optimizer = optim.Adam(
-            params=model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-9
+            params=model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-9
         )
 
     """Optionally: Build learning rate scheduler"""
19 changes: 9 additions & 10 deletions tardis_pytorch/dist_pytorch/trainer.py
@@ -158,13 +158,13 @@ def _train(self):
             self.optimizer.zero_grad()
 
             if self.node_input > 0:
-                edge = self.model(coords=edge, node_features=node.to(self.device))
+                edge, indices = self.model(coords=edge, node_features=node.to(self.device))
             else:
-                edge = self.model(coords=edge)
+                edge, indices = self.model(coords=edge)
 
             # Back-propagate
             loss = self.criterion(
-                edge[1].T, graph[:, edge[0][1], edge[0][2]]
+                edge[1:, 0], graph[0, indices[3][:, 0], indices[3][:, 1]].type(torch.float32)
             )  # Calc. loss
             loss.backward()  # One backward pass
             self.optimizer.step()  # Update the parameters
@@ -204,23 +204,22 @@ def _validate(self):
             with torch.no_grad():
                 # Predict graph
                 if self.node_input > 0:
-                    edge = self.model(
+                    edge, indices = self.model(
                         coords=edge, node_features=node.to(self.device)
                     )
                 else:
-                    edge = self.model(coords=edge)
+                    edge, indices = self.model(coords=edge)
 
                 # Calculate validation loss
                 loss = self.criterion(
-                    edge[1].T, graph[:, edge[0][1], edge[0][2]]
+                    edge[1:, 0], graph[0, indices[3][:, 0], indices[3][:, 1]].type(torch.float32)
                 )  # Calc. loss
 
                 # Calculate F1 metric
+                pred_edge = np.zeros(indices[2])
+                pred_edge[indices[3][:, 0], indices[3][:, 1]] = edge[1:, 0].cpu().detach().numpy()
                 acc, prec, recall, f1, th = eval_graph_f1(
-                    logits=torch.sparse_coo_tensor(edge[0], edge[1], edge[2])
-                    .to_dense()
-                    .cpu()
-                    .detach()[0, ..., 0],
+                    logits=torch.from_numpy(pred_edge),
                     targets=graph[0, ...].cpu().detach(),
                     threshold=0.5,
                 )
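Because the model now returns per-pair values instead of a sparse COO tensor, validation scatters the predictions back into a dense N x N matrix before calling `eval_graph_f1`. A dummy-data sketch mirroring the `pred_edge` lines above (shapes follow `indices[2]` as `(N, N)` and `indices[3]` as an `(M, 2)` pair list):

```python
# Dummy-data sketch of scattering sparse per-pair predictions into a dense
# matrix for F1 scoring; values and index pairs are random placeholders.
import numpy as np
import torch

N, knn = 100, 8
pairs = np.stack([np.repeat(np.arange(N), knn),
                  np.random.randint(0, N, N * knn)], axis=1)  # (i, j) per edge
scores = torch.rand(N * knn)                                  # edge[1:, 0] analogue

pred_edge = np.zeros((N, N))
pred_edge[pairs[:, 0], pairs[:, 1]] = scores.numpy()
print(pred_edge.shape, float(pred_edge.max()))
```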
2 changes: 1 addition & 1 deletion tardis_pytorch/utils/predictor.py
@@ -382,7 +382,7 @@ def preprocess_DIST(self, id_name: str):
         # Post-process predicted image patches
         if self.predict in ["Filament", "Microtubule"]:
             self.pc_hd, self.pc_ld = self.post_processes.build_point_cloud(
-                image=self.image, EDT=True, down_sampling=5
+                image=self.image, EDT=True, down_sampling=10
             )
         else:
             self.pc_hd, self.pc_ld = self.post_processes.build_point_cloud(
