bit faster sparse triangulation
RRobert92 committed Jun 23, 2023
1 parent 3096736 commit d6a635b
Showing 1 changed file with 16 additions and 20 deletions.
36 changes: 16 additions & 20 deletions tardis_pytorch/dist_pytorch/sparse_model/modules.py
@@ -90,35 +90,31 @@ def forward(self, x: torch.tensor, indices: list) -> Union[torch.tensor, list]:
         k = torch.zeros_like(a)

         if self.axis == 1:  # Row-wise update
+            i_element = [indices[1][1][i].copy() + [i] for i in range(len(indices[1][1]))]
             for i in range(len(indices[1][1])):
-                i_element = indices[1][1][i].copy()
-                i_element.append(i)
+                i_el = i_element[i]

-                for _id, j in enumerate(indices[1][1][i]):
-                    fil = [True if m in i_element else False for m in indices[1][1][j]]
-                    j_th = list(compress(indices[1][1][j], fil))
+                list_ith = [list(compress(i_el, [True if k in indices[1][1][j] else False for k in i_el]))
+                            for _id, j in enumerate(i_el)]
+                list_jth = [list(compress(indices[1][1][j], [True if k in i_el else False for k in indices[1][1][j]]))
+                            for _id, j in enumerate(i_el)]

-                    fil = [True if m in indices[1][1][j] else False for m in i_element]
-                    i_th = list(compress(i_element, fil))
+                resoult = [torch.sum(a[:, j, :] * b[:, i, :], dim=1) for j, i in zip(list_jth, list_ith)]

-                    k[:, indices[1][0][i][_id], :] = torch.sum(
-                        a[:, i_th, :] * b[:, j_th, :], dim=1
-                    )
+                k[:, i_el, :] = torch.stack(resoult, dim=1)
         else:
+            i_element = [indices[2][1][i].copy() + [i] for i in range(len(indices[2][1]))]
             for i in range(len(indices[2][1])):
-                i_element = indices[2][1][i].copy()
-                i_element.append(i)
+                i_el = i_element[i]

-                for _id, j in enumerate(indices[2][1][i]):
-                    fil = [True if m in i_element else False for m in indices[2][1][j]]
-                    j_th = list(compress(indices[2][1][j], fil))
+                list_ith = [list(compress(i_el, [True if k in indices[2][1][j] else False for k in i_el]))
+                            for _id, j in enumerate(i_el)]
+                list_jth = [list(compress(indices[2][1][j], [True if k in i_el else False for k in indices[2][1][j]]))
+                            for _id, j in enumerate(i_el)]

-                    fil = [True if m in indices[2][1][j] else False for m in i_element]
-                    i_th = list(compress(i_element, fil))
+                resoult = [torch.sum(a[:, j, :] * b[:, i, :], dim=1) for j, i in zip(list_jth, list_ith)]

-                    k[:, indices[2][0][i][_id], :] = torch.sum(
-                        a[:, i_th, :] * b[:, j_th, :], dim=1
-                    )
+                k[:, i_el, :] = torch.stack(resoult, dim=1)

         # if self.axis == 1:  # Row-wise
         #     idx[0] = idx[0].reshape(org_shape[1], self.k, 2)
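In essence, the change replaces the inner per-neighbour writes into k (one torch.sum and one single-slot assignment per (i, j) pair) with list comprehensions that collect all per-node reductions up front and a single torch.stack write per row/column block, and it hoists the i_element construction out of the loop. A minimal sketch of that write-batching pattern is below; the shapes and the neighbours list are made up for illustration and do not match the repository's actual indices layout.

import torch

# Hypothetical toy data; not the layout used by modules.py.
batch, nodes, channels = 2, 6, 4
a = torch.randn(batch, nodes, channels)
b = torch.randn(batch, nodes, channels)
neighbours = [[1, 2], [0, 3], [0, 4], [1, 5], [2, 5], [3, 4]]  # made-up adjacency

# Loop version: one reduction and one single-slot write into k per node.
k_loop = torch.zeros_like(a)
for i, nbrs in enumerate(neighbours):
    k_loop[:, i, :] = torch.sum(a[:, nbrs, :] * b[:, nbrs, :], dim=1)

# Batched version: collect all per-node reductions first, then write them back
# with one stacked assignment instead of one write per node.
rows = [torch.sum(a[:, nbrs, :] * b[:, nbrs, :], dim=1) for nbrs in neighbours]
k_stack = torch.zeros_like(a)
k_stack[:, list(range(nodes)), :] = torch.stack(rows, dim=1)

assert torch.allclose(k_loop, k_stack)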
