diff --git a/lightly/embedding/embedding.py b/lightly/embedding/embedding.py index 08d5fc484..edf41d7df 100644 --- a/lightly/embedding/embedding.py +++ b/lightly/embedding/embedding.py @@ -106,11 +106,12 @@ def embed(self, self.model.eval() embeddings, labels, filenames = None, None, [] + dataset = dataloader.dataset if lightly._is_prefetch_generator_available(): dataloader = BackgroundGenerator(dataloader, max_prefetch=3) pbar = tqdm( - total=len(dataloader.dataset), + total=len(dataset), unit='imgs' ) @@ -157,7 +158,7 @@ def embed(self, embeddings = embeddings.cpu().numpy() labels = labels.cpu().numpy() - sorted_filenames = dataloader.dataset.get_filenames() + sorted_filenames = dataset.get_filenames() sorted_embeddings = sort_items_by_keys( filenames, embeddings, sorted_filenames )