diff --git a/tardis_pytorch/dist_pytorch/sparse_model/embedding.py b/tardis_pytorch/dist_pytorch/sparse_model/embedding.py index 5cfae3c7..dbc659ce 100644 --- a/tardis_pytorch/dist_pytorch/sparse_model/embedding.py +++ b/tardis_pytorch/dist_pytorch/sparse_model/embedding.py @@ -258,12 +258,14 @@ def forward(self, input_coord: torch.tensor) -> Union[torch.tensor, list]: _dist[torch.where(mask)] = 0 _dist.fill_diagonal_(0) - indices = torch.where(_dist > 0) + indices = torch.where(_dist > 0) # every ij elemtent from distance embedding """[2-3] Get Col/Row wise indices for ij element""" + # Columns list unique_elements, inverse_indices = torch.unique( indices[0], return_inverse=True ) + # Number of elements per row counts = torch.bincount(inverse_indices) # get row-wise indices