Skip to content

Commit

Permalink
update for embeddign
Browse files Browse the repository at this point in the history
  • Loading branch information
RRobert92 committed Jun 27, 2023
1 parent dbbd2b9 commit afd1259
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tardis_pytorch/dist_pytorch/sparse_model/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,14 @@ def forward(self, input_coord: torch.tensor) -> Union[torch.tensor, list]:
_dist[torch.where(mask)] = 0
_dist.fill_diagonal_(0)

indices = torch.where(_dist > 0)
indices = torch.where(_dist > 0) # every ij elemtent from distance embedding

"""[2-3] Get Col/Row wise indices for ij element"""
# Columns list
unique_elements, inverse_indices = torch.unique(
indices[0], return_inverse=True
)
# Number of elements per row
counts = torch.bincount(inverse_indices)

# get row-wise indices
Expand Down

0 comments on commit afd1259

Please sign in to comment.