diff --git a/tomotwin/modules/inference/embedor.py b/tomotwin/modules/inference/embedor.py
index ae81957..c8d40cc 100644
--- a/tomotwin/modules/inference/embedor.py
+++ b/tomotwin/modules/inference/embedor.py
@@ -557,6 +557,14 @@ def load_weights_(self):
         self.model = torch.compile(self.model, mode="reduce-overhead")
         self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.rank])
 
+    def get_unique_indicis(self, a):
+        # got this idea from https://github.com/pytorch/pytorch/issues/36748
+        unique, inverse = torch.unique(a, sorted=True, return_inverse=True)
+        perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
+        inverse, perm = inverse.flip([0]), perm.flip([0])
+        perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
+        return perm
+
     def embed(self, volume_data: VolumeDataset) -> np.array:
         """Calculates the embeddings. The volumes showed have the dimension NxBSxBSxBS"""
 
@@ -564,7 +572,7 @@ def embed(self, volume_data: VolumeDataset) -> np.array:
         torch.backends.cudnn.benchmark = True
         dataset = TorchVolumeDataset(volumes=volume_data)
-        sampler_data = torch.utils.data.DistributedSampler(dataset, rank=self.rank, shuffle=True)
+        sampler_data = torch.utils.data.DistributedSampler(dataset, rank=self.rank, shuffle=False)
         volume_loader = DataLoader(
             dataset=dataset,
             batch_size=self.batchsize,
@@ -604,6 +612,8 @@ def embed(self, volume_data: VolumeDataset) -> np.array:
         if self.rank == 0:
             items_indicis = torch.cat(items_gather_list)
+            unique_elements = self.get_unique_indicis(items_indicis)
+            items_indicis = items_indicis[unique_elements]
         else:
             items_indicis = None
@@ -622,6 +632,7 @@ def embed(self, volume_data: VolumeDataset) -> np.array:
         if self.rank == 0:
             embeddings = torch.cat(embeddings_gather_list)
+            embeddings = embeddings[unique_elements]
             embeddings = embeddings[torch.argsort(items_indicis)]  # sort embeddings after gathering
             embeddings = embeddings.data.cpu().numpy()
         else:
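
For illustration only, not part of the patch: with its default drop_last=False, torch.utils.data.DistributedSampler pads the dataset by repeating indices so every rank processes the same number of samples, which means the index tensor gathered on rank 0 can contain duplicates. The added get_unique_indicis uses the flip/scatter trick from the linked PyTorch issue to pick one position per sample index, and those positions are then used to drop the duplicated embeddings before the final argsort. A minimal standalone sketch with made-up values; only the dedup logic mirrors the patch, the tensors and numbers are hypothetical:

import torch

def get_unique_indicis(a: torch.Tensor) -> torch.Tensor:
    # Flip/scatter trick from https://github.com/pytorch/pytorch/issues/36748:
    # returns, for every unique value in `a`, the position of one of its occurrences.
    unique, inverse = torch.unique(a, sorted=True, return_inverse=True)
    perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
    inverse, perm = inverse.flip([0]), perm.flip([0])
    perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
    return perm

# Hypothetical gathered sample indices: the sampler padded the last rank,
# so sample indices 0 and 1 show up twice.
items = torch.tensor([0, 2, 4, 1, 3, 5, 0, 1])
# Fake embeddings: sample i gets the value i * 10, so the expected result is obvious.
embeddings = (items * 10).float().unsqueeze(1)

keep = get_unique_indicis(items)                 # one position per unique sample index
items, embeddings = items[keep], embeddings[keep]
embeddings = embeddings[torch.argsort(items)]    # restore dataset order, as in the patch
print(embeddings.squeeze(1))                     # tensor([ 0., 10., 20., 30., 40., 50.])

Which duplicate survives does not matter here, since duplicated positions hold embeddings of the same sample; what matters is that each sample index appears exactly once before sorting.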