Skip to content

Commit

Permalink
Merge pull request #7 from Angtian/Angtian-patch-3
Browse files Browse the repository at this point in the history
Update gather_features.py
  • Loading branch information
Angtian authored Sep 4, 2023
2 parents 8698129 + 2a81fd0 commit 3855017
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions cu_layers/CuNeMo/gather_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ def gather_features(features, weights, sample_indexs, mesh_n_list):
assert features.dim() == 3 and weights.dim() == 2 and sample_indexs.dim() == 2
total_verts = mesh_n_list.sum()

assert 0 <= sample_indexs.max() < mesh_n_list.shape[0]
assert -1 <= sample_indexs.max() < mesh_n_list.shape[0]
assert sample_indexs.dtype == mesh_n_list.dtype == torch.int32

samples_cum_num = torch.cumsum(torch.gather(mesh_n_list[None].expand(sample_indexs.shape[0], -1), dim=1, index=sample_indexs.long()), dim=1).type(torch.int32)
valid_mask = torch.logical_not(sample_indexs < 0).type(torch.int32)
samples_cum_num = torch.cumsum(valid_mask * torch.gather(mesh_n_list[None].expand(sample_indexs.shape[0], -1), dim=1, index=sample_indexs.long().clamp(min=0)), dim=1).type(torch.int32)
input_valid = samples_cum_num[:, -1]
sample_shifts = torch.cat((torch.zeros((samples_cum_num.shape[0], 1), dtype=torch.int32, device=basic_device), samples_cum_num[:, :-1]), dim=1)

Expand Down

0 comments on commit 3855017

Please sign in to comment.