Skip to content

Commit

Permalink
update triang
Browse files Browse the repository at this point in the history
  • Loading branch information
RRobert92 committed Jun 25, 2023
1 parent d6a635b commit dbbd2b9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 16 deletions.
36 changes: 20 additions & 16 deletions tardis_pytorch/dist_pytorch/sparse_model/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,30 +91,34 @@ def forward(self, x: torch.tensor, indices: list) -> Union[torch.tensor, list]:

if self.axis == 1: # Row-wise update
i_element = [indices[1][1][i].copy() + [i] for i in range(len(indices[1][1]))]
for i in range(len(indices[1][1])):
i_el = i_element[i]
i_th = [list(compress(i_el, [True if k in indices[1][1][j] else False for k in i_el])) for i_el in i_element
for j in i_el]
j_th = [list(compress(indices[1][1][j], [True if k in i_el else False for k in indices[1][1][j]])) for i_el
in i_element for j in i_el]
i_element = [i for s in i_element for i in s]

list_ith = [list(compress(i_el, [True if k in indices[1][1][j] else False for k in i_el])) for _id, j in
enumerate(i_el)]
list_jth = [list(compress(indices[1][1][j], [True if k in i_el else False for k in indices[1][1][j]]))
for _id, j in enumerate(i_el)]
shape_ = [len(x) for x in j_th]
cumulative_sizes = np.cumsum([0] + shape_)

resoult = [torch.sum(a[:, j, :] * b[:, i, :], dim=1) for j, i in zip(list_jth, list_ith)]
df_ij = a[:, torch.from_numpy(np.concatenate(i_th)), :] * b[:, torch.from_numpy(np.concatenate(j_th)), :]
df_ij = [torch.sum(df_ij[:, cumulative_sizes[0]:cumulative_sizes[1]], dim=1) for i in range(len(shape_))]

k[:, i_el, :] = torch.stack(resoult, dim=1)
k[:, i_element, :] = torch.cat(df_ij)
else:
i_element = [indices[2][1][i].copy() + [i] for i in range(len(indices[2][1]))]
for i in range(len(indices[2][1])):
i_el = i_element[i]
i_th = [list(compress(i_el, [True if k in indices[1][1][j] else False for k in i_el])) for i_el in i_element
for j in i_el]
j_th = [list(compress(indices[2][1][j], [True if k in i_el else False for k in indices[2][1][j]])) for i_el
in i_element for j in i_el]
i_element = [i for s in i_element for i in s]

list_ith = [list(compress(i_el, [True if k in indices[2][1][j] else False for k in i_el])) for _id, j in
enumerate(i_el)]
list_jth = [list(compress(indices[2][1][j], [True if k in i_el else False for k in indices[2][1][j]]))
for _id, j in enumerate(i_el)]
shape_ = [len(x) for x in j_th]
cumulative_sizes = np.cumsum([0] + shape_)

resoult = [torch.sum(a[:, j, :] * b[:, i, :], dim=1) for j, i in zip(list_jth, list_ith)]
df_ij = a[:, torch.from_numpy(np.concatenate(i_th)), :] * b[:, torch.from_numpy(np.concatenate(j_th)), :]
df_ij = [torch.sum(df_ij[:, cumulative_sizes[0]:cumulative_sizes[1]], dim=1) for i in range(len(shape_))]

k[:, i_el, :] = torch.stack(resoult, dim=1)
k[:, i_element, :] = torch.cat(df_ij)

# if self.axis == 1: # Row-wise
# idx[0] = idx[0].reshape(org_shape[1], self.k, 2)
Expand Down
9 changes: 9 additions & 0 deletions tardis_pytorch/tardis/csv_to_am.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#######################################################################
# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation #
# #
# New York Structural Biology Center #
# Simons Machine Learning Center #
# #
# Robert Kiewisz, Tristan Bepler #
# MIT License 2021 - 2023 #
#######################################################################

0 comments on commit dbbd2b9

Please sign in to comment.