diff --git a/cu_layers/CuNeMo/gather_features.py b/cu_layers/CuNeMo/gather_features.py index d4ffb74..411ec5a 100644 --- a/cu_layers/CuNeMo/gather_features.py +++ b/cu_layers/CuNeMo/gather_features.py @@ -20,10 +20,11 @@ def gather_features(features, weights, sample_indexs, mesh_n_list): assert features.dim() == 3 and weights.dim() == 2 and sample_indexs.dim() == 2 total_verts = mesh_n_list.sum() - assert 0 <= sample_indexs.max() < mesh_n_list.shape[0] + assert -1 <= sample_indexs.max() < mesh_n_list.shape[0] assert sample_indexs.dtype == mesh_n_list.dtype == torch.int32 - samples_cum_num = torch.cumsum(torch.gather(mesh_n_list[None].expand(sample_indexs.shape[0], -1), dim=1, index=sample_indexs.long()), dim=1).type(torch.int32) + valid_mask = torch.logical_not(sample_indexs < 0).type(torch.int32) + samples_cum_num = torch.cumsum(valid_mask * torch.gather(mesh_n_list[None].expand(sample_indexs.shape[0], -1), dim=1, index=sample_indexs.long().clamp(min=0)), dim=1).type(torch.int32) input_valid = samples_cum_num[:, -1] sample_shifts = torch.cat((torch.zeros((samples_cum_num.shape[0], 1), dtype=torch.int32, device=basic_device), samples_cum_num[:, :-1]), dim=1)