
Commit

check for duplicate embedded elements
thorstenwagner committed Oct 30, 2023
1 parent 5b04733 commit 14cd845
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion tomotwin/modules/inference/embedor.py
@@ -557,14 +557,22 @@ def load_weights_(self):
         self.model = torch.compile(self.model, mode="reduce-overhead")
         self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.rank])
 
+    def get_unique_indicis(self, a):
+        # got this idea from https://github.com/pytorch/pytorch/issues/36748
+        unique, inverse = torch.unique(a, sorted=True, return_inverse=True)
+        perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
+        inverse, perm = inverse.flip([0]), perm.flip([0])
+        perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
+        return perm
+
     def embed(self, volume_data: VolumeDataset) -> np.array:
         """Calculates the embeddings. The volumes should have the dimension NxBSxBSxBS"""
 
         # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         torch.backends.cudnn.benchmark = True
 
         dataset = TorchVolumeDataset(volumes=volume_data)
-        sampler_data = torch.utils.data.DistributedSampler(dataset, rank=self.rank, shuffle=True)
+        sampler_data = torch.utils.data.DistributedSampler(dataset, rank=self.rank, shuffle=False)
 
         volume_loader = DataLoader(
             dataset=dataset,
             batch_size=self.batchsize,
@@ -604,6 +612,8 @@ def embed(self, volume_data: VolumeDataset) -> np.array:

         if self.rank == 0:
             items_indicis = torch.cat(items_gather_list)
+            unique_elements = self.get_unique_indicis(items_indicis)
+            items_indicis = items_indicis[unique_elements]

         else:
             items_indicis = None

@@ -622,6 +632,7 @@ def embed(self, volume_data: VolumeDataset) -> np.array:

         if self.rank == 0:
             embeddings = torch.cat(embeddings_gather_list)
+            embeddings = embeddings[unique_elements]

             embeddings = embeddings[torch.argsort(items_indicis)]  # sort embeddings after gathering
             embeddings = embeddings.data.cpu().numpy()
         else:
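
The new get_unique_indicis helper uses a torch.unique/scatter_ trick (credited to pytorch/pytorch#36748 in the code comment) to return one index per distinct value in a 1-D tensor, intended to be the position of that value's first occurrence. A standalone sketch of the same trick, not taken from the repository and with a hypothetical function name:

import torch

def first_occurrence_indices(a: torch.Tensor) -> torch.Tensor:
    unique, inverse = torch.unique(a, sorted=True, return_inverse=True)
    perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
    # Flip both tensors so the last sequential write for each unique value comes
    # from its first occurrence; that index is what survives the scatter.
    inverse, perm = inverse.flip([0]), perm.flip([0])
    return inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)

a = torch.tensor([3, 1, 3, 2, 1, 0])
print(first_occurrence_indices(a))  # tensor([5, 1, 3, 0]) -> first positions of values 0, 1, 2, 3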
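
Why deduplication is needed at all: with DistributedDataParallel, DistributedSampler pads the dataset so every rank processes the same number of volumes, so the index and embedding tensors gathered on rank 0 contain repeated items and are ordered by rank rather than by dataset position. A minimal sketch of the gather/dedup/reorder pattern this commit adds, using made-up numbers and a plain-Python first-occurrence lookup as a stand-in for get_unique_indicis:

import torch

# Simulated gather from 3 ranks over a 7-volume dataset: the sampler pads to
# 9 samples, so volumes 0 and 1 are embedded twice.
per_rank_indices = [torch.tensor([0, 3, 6]), torch.tensor([1, 4, 0]), torch.tensor([2, 5, 1])]
per_rank_embeddings = [10.0 * idx.float().unsqueeze(1) for idx in per_rank_indices]  # fake embeddings

items = torch.cat(per_rank_indices)    # [0, 3, 6, 1, 4, 0, 2, 5, 1]: duplicates, rank order
embs = torch.cat(per_rank_embeddings)

# Keep the first occurrence of every dataset index (stand-in for get_unique_indicis).
keep = torch.tensor([items.tolist().index(v) for v in items.unique(sorted=True).tolist()])

items, embs = items[keep], embs[keep]  # drop the padded duplicates
embs = embs[torch.argsort(items)]      # restore dataset order 0..6
print(embs.squeeze(1))                 # tensor([ 0., 10., 20., 30., 40., 50., 60.])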
