diff --git a/tardis_pytorch/dist_pytorch/sparse_model/modules.py b/tardis_pytorch/dist_pytorch/sparse_model/modules.py
index 05035d2d..97a64034 100644
--- a/tardis_pytorch/dist_pytorch/sparse_model/modules.py
+++ b/tardis_pytorch/dist_pytorch/sparse_model/modules.py
@@ -91,30 +91,34 @@ def forward(self, x: torch.tensor, indices: list) -> Union[torch.tensor, list]:
 
         if self.axis == 1:  # Row-wise update
             i_element = [indices[1][1][i].copy() + [i] for i in range(len(indices[1][1]))]
-            for i in range(len(indices[1][1])):
-                i_el = i_element[i]
+            i_th = [list(compress(i_el, [True if k in indices[1][1][j] else False for k in i_el])) for i_el in i_element
+                    for j in i_el]
+            j_th = [list(compress(indices[1][1][j], [True if k in i_el else False for k in indices[1][1][j]])) for i_el
+                    in i_element for j in i_el]
+            i_element = [i for s in i_element for i in s]
 
-                list_ith = [list(compress(i_el, [True if k in indices[1][1][j] else False for k in i_el])) for _id, j in
-                            enumerate(i_el)]
-                list_jth = [list(compress(indices[1][1][j], [True if k in i_el else False for k in indices[1][1][j]]))
-                            for _id, j in enumerate(i_el)]
+            shape_ = [len(x) for x in j_th]
+            cumulative_sizes = np.cumsum([0] + shape_)
 
-                resoult = [torch.sum(a[:, j, :] * b[:, i, :], dim=1) for j, i in zip(list_jth, list_ith)]
+            df_ij = a[:, torch.from_numpy(np.concatenate(i_th)), :] * b[:, torch.from_numpy(np.concatenate(j_th)), :]
+            df_ij = [torch.sum(df_ij[:, cumulative_sizes[0]:cumulative_sizes[1]], dim=1) for i in range(len(shape_))]
 
-                k[:, i_el, :] = torch.stack(resoult, dim=1)
+            k[:, i_element, :] = torch.cat(df_ij)
         else:
             i_element = [indices[2][1][i].copy() + [i] for i in range(len(indices[2][1]))]
-            for i in range(len(indices[2][1])):
-                i_el = i_element[i]
+            i_th = [list(compress(i_el, [True if k in indices[1][1][j] else False for k in i_el])) for i_el in i_element
+                    for j in i_el]
+            j_th = [list(compress(indices[2][1][j], [True if k in i_el else False for k in indices[2][1][j]])) for i_el
+                    in i_element for j in i_el]
+            i_element = [i for s in i_element for i in s]
 
-                list_ith = [list(compress(i_el, [True if k in indices[2][1][j] else False for k in i_el])) for _id, j in
-                            enumerate(i_el)]
-                list_jth = [list(compress(indices[2][1][j], [True if k in i_el else False for k in indices[2][1][j]]))
-                            for _id, j in enumerate(i_el)]
+            shape_ = [len(x) for x in j_th]
+            cumulative_sizes = np.cumsum([0] + shape_)
 
-                resoult = [torch.sum(a[:, j, :] * b[:, i, :], dim=1) for j, i in zip(list_jth, list_ith)]
+            df_ij = a[:, torch.from_numpy(np.concatenate(i_th)), :] * b[:, torch.from_numpy(np.concatenate(j_th)), :]
+            df_ij = [torch.sum(df_ij[:, cumulative_sizes[0]:cumulative_sizes[1]], dim=1) for i in range(len(shape_))]
 
-                k[:, i_el, :] = torch.stack(resoult, dim=1)
+            k[:, i_element, :] = torch.cat(df_ij)
 
         # if self.axis == 1:  # Row-wise
         #     idx[0] = idx[0].reshape(org_shape[1], self.k, 2)
diff --git a/tardis_pytorch/tardis/csv_to_am.py b/tardis_pytorch/tardis/csv_to_am.py
new file mode 100644
index 00000000..b4e69905
--- /dev/null
+++ b/tardis_pytorch/tardis/csv_to_am.py
@@ -0,0 +1,9 @@
+######################################################################
+# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation #
+#                                                                    #
+# New York Structural Biology Center                                 #
+# Simons Machine Learning Center                                     #
+#                                                                    #
+# Robert Kiewisz, Tristan Bepler                                     #
+# MIT License 2021 - 2023                                            #
+######################################################################
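
Note on the modules.py hunk: the change replaces the per-row Python loop with one gather, one element-wise product, and a per-segment sum taken over cumulative offsets. The sketch below is not part of the patch; it only illustrates that technique under stated assumptions. The helper name segmented_pairwise_sum and the toy shapes are invented for illustration, and the sketch slices each segment with offsets[i]:offsets[i + 1] and stacks along dim=1, whereas the committed lines slice cumulative_sizes[0]:cumulative_sizes[1] for every i and use torch.cat, which appears to reduce only the first segment and to concatenate along the batch dimension rather than the index dimension. The else branch's i_th also still filters against indices[1][1], unlike the removed list_ith which used indices[2][1]; both points may be worth double-checking.

    import numpy as np
    import torch


    def segmented_pairwise_sum(a, b, i_th, j_th):
        # a, b: [B, N, C] tensors; i_th, j_th: equal-length lists of index lists,
        # where pair n should yield sum_m a[:, i_th[n][m], :] * b[:, j_th[n][m], :].
        sizes = [len(s) for s in j_th]
        offsets = np.cumsum([0] + sizes)

        # One gather and one element-wise product for all pairs at once.
        i_idx = torch.from_numpy(np.concatenate(i_th)).long()
        j_idx = torch.from_numpy(np.concatenate(j_th)).long()
        prod = a[:, i_idx, :] * b[:, j_idx, :]  # [B, sum(sizes), C]

        # Per-segment reduction over the cumulative offsets.
        sums = [prod[:, offsets[i]:offsets[i + 1]].sum(dim=1) for i in range(len(sizes))]
        return torch.stack(sums, dim=1)  # [B, len(i_th), C]


    # Toy usage: two segments of different lengths.
    a = torch.randn(2, 6, 4)
    b = torch.randn(2, 6, 4)
    out = segmented_pairwise_sum(a, b, i_th=[[0, 1], [2, 3, 4]], j_th=[[5, 0], [1, 2, 3]])
    print(out.shape)  # torch.Size([2, 2, 4])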