Merge pull request #3 from Digital-Dermatology/rerunning-optimized
Optimized SelfClean rerunning + Updated ReadMe
FabianGroeger96 authored Apr 4, 2024
2 parents 2b18f62 + c44b37d commit c812821
Showing 10 changed files with 282 additions and 158 deletions.
58 changes: 49 additions & 9 deletions README.md
@@ -1,23 +1,63 @@
# SelfClean
# 🧼🔎 SelfClean

[**SelfClean Paper**](https://arxiv.org/abs/2305.17048) | [**Data Cleaning Protocol Paper**](https://arxiv.org/abs/2309.06961)

![SelfClean Teaser](https://github.com/Digital-Dermatology/SelfClean/raw/main/assets/SelfClean_Teaser.png)

<h2 align="center">
A holistic self-supervised data cleaning strategy to detect irrelevant samples, near duplicates, and label errors.

[![PyPI version](https://badge.fury.io/py/selfclean.svg)](https://badge.fury.io/py/selfclean)
![Contribution](https://img.shields.io/badge/Contribution-Welcome-brightgreen)
**NOTE:** Make sure to have `git-lfs` installed before pulling the repository to ensure the pre-trained models are pulled correctly ([git-lfs install instructions](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage)).

</h2>
## Installation

A holistic self-supervised data cleaning strategy to detect irrelevant samples, near duplicates, and label errors.
> Install SelfClean via [PyPI](https://pypi.org/project/selfclean/):
```bash
# upgrade pip to its latest version
pip install -U pip

# install selfclean
pip install selfclean

# alternatively, use an explicit Python version (XX)
python3.XX -m pip install selfclean
```

## Getting Started

You can run SelfClean in a few lines of code:

```python
from selfclean import SelfClean

selfclean = SelfClean()

# run on a PyTorch dataset
issues = selfclean.run_on_dataset(
dataset=copy.copy(dataset),
)
# run on an image folder
issues = selfclean.run_on_image_folder(
input_path="path/to/images",
)

# get the data quality issue rankings
df_near_duplicates = issues.get_issues("near_duplicates", return_as_df=True)
df_irrelevants = issues.get_issues("irrelevants", return_as_df=True)
df_label_errors = issues.get_issues("label_errors", return_as_df=True)
```
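
For a fully self-contained smoke test, here is a minimal sketch using torchvision's `FakeData` as a stand-in dataset (the same class the integration tests import); the reduced `epochs` and `batch_size` values are illustrative assumptions, not tuned recommendations:

```python
import copy

from torchvision.datasets import FakeData

from selfclean import SelfClean

# small synthetic dataset as a stand-in; SelfClean sets its own
# image transforms on the dataset before embedding it
dataset = FakeData(size=64)

selfclean = SelfClean()
issues = selfclean.run_on_dataset(
    dataset=copy.copy(dataset),
    epochs=1,        # keep the self-supervised pre-training short
    batch_size=16,
)

df_label_errors = issues.get_issues("label_errors", return_as_df=True)
print(df_label_errors.head())
```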

**Examples:**
In `examples/`, we provide example notebooks that show how to analyze and clean datasets with SelfClean.
These examples cover benchmark datasets such as:

- <a href="https://github.com/fastai/imagenette">Imagenette</a> 🖼️ (Open in <a href="https://nbviewer.org/github/Digital-Dermatology/SelfClean/blob/main/examples/Investigate_Imagenette.ipynb">NBViewer</a> | <a href="https://github.com/Digital-Dermatology/SelfClean/blob/main/examples/Investigate_Imagenette.ipynb">GitHub</a> | <a href="https://colab.research.google.com/github/Digital-Dermatology/SelfClean/blob/main/examples/Investigate_Imagenette.ipynb">Colab</a>)
- <a href="https://www.robots.ox.ac.uk/~vgg/data/pets/">Oxford-IIIT Pet</a> 🐶 (Open in <a href="https://nbviewer.org/github/Digital-Dermatology/SelfClean/blob/main/examples/Investigate_OxfordIIITPet.ipynb">NBViewer</a> | <a href="https://github.com/Digital-Dermatology/SelfClean/blob/main/examples/Investigate_OxfordIIITPet.ipynb">GitHub</a> | <a href="https://colab.research.google.com/github/Digital-Dermatology/SelfClean/blob/main/examples/Investigate_OxfordIIITPet.ipynb">Colab</a>)

## Development Environment
Run `make` for a list of possible targets.

### Installation
Run these commands to install the project:
Run these commands to install the requirements for the development environment:
```bash
make init
make install
@@ -28,7 +68,7 @@ To run linters on all files:
pre-commit run --all-files
```

### Code and test conventions
We use the following packages for code and test conventions:
- `black` for code style
- `isort` for import sorting
- `pytest` for running tests
99 changes: 65 additions & 34 deletions examples/Investigate_Imagenette.ipynb


115 changes: 73 additions & 42 deletions examples/Investigate_OxfordIIITPet.ipynb


4 changes: 3 additions & 1 deletion src/cleaner/base_cleaner.py
@@ -4,6 +4,8 @@
import numpy as np
from torch.utils.data import Dataset

from ..cleaner.issue_manager import IssueManager


class BaseCleaner(ABC):
@abstractmethod
@@ -18,5 +20,5 @@ def fit(
raise NotImplementedError()

@abstractmethod
def predict(self) -> dict:
def predict(self) -> IssueManager:
raise NotImplementedError()
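
With `predict` now typed as returning an `IssueManager` rather than a plain `dict`, callers go through the manager's accessor; a hedged sketch based on the `get_issues` usage shown in the README above:

```python
# `cleaner` is assumed to be an already-fitted SelfCleanCleaner
issues = cleaner.predict()  # IssueManager instead of a raw dict

# ranked issue candidates, retrieved as in the README example
df_near_duplicates = issues.get_issues("near_duplicates", return_as_df=True)
```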
113 changes: 57 additions & 56 deletions src/cleaner/selfclean.py
@@ -37,7 +37,7 @@
"model": {
"out_dim": 4096,
"emb_dim": 192,
"base_model": "pretrained_imagenet_vit_tiny",
"base_model": "pretrained_imagenet_dino",
"model_type": "VIT",
"use_bn_in_head": False,
"norm_last_layer": True,
@@ -131,7 +131,7 @@ def run_on_image_folder(
ssl_pre_training: bool = True,
save_every_n_epochs: int = 10,
work_dir: Optional[str] = None,
num_workers: int = os.cpu_count(),
num_workers: Optional[int] = os.cpu_count(),
pretraining_type: PretrainingType = PretrainingType.DINO,
hyperparameters: dict = DINO_STANDARD_HYPERPARAMETERS,
# embedding
@@ -168,12 +168,12 @@ def run_on_dataset(
def run_on_dataset(
self,
dataset,
epochs: int = 100,
epochs: int = 10,
batch_size: int = 64,
ssl_pre_training: bool = True,
save_every_n_epochs: int = 10,
work_dir: Optional[str] = None,
num_workers: int = os.cpu_count(),
num_workers: Optional[int] = os.cpu_count(),
pretraining_type: PretrainingType = PretrainingType.DINO,
hyperparameters: dict = DINO_STANDARD_HYPERPARAMETERS,
# embedding
@@ -211,7 +211,7 @@ def _run(
ssl_pre_training: bool = True,
save_every_n_epochs: int = 10,
work_dir: Optional[str] = None,
num_workers: int = os.cpu_count(),
num_workers: Optional[int] = os.cpu_count(),
pretraining_type: PretrainingType = PretrainingType.DINO,
hyperparameters: dict = DINO_STANDARD_HYPERPARAMETERS,
# embedding
@@ -222,58 +222,59 @@ def _run(
wandb_logging: bool = False,
wandb_project_name: str = "SelfClean",
):
if self.model is None:
if pretraining_type is PretrainingType.DINO:
self.model = self.train_dino(
dataset=dataset,
epochs=epochs,
batch_size=batch_size,
ssl_pre_training=ssl_pre_training,
save_every_n_epochs=save_every_n_epochs,
work_dir=work_dir,
hyperparameters=hyperparameters,
num_workers=num_workers,
additional_run_info=additional_run_info,
wandb_logging=wandb_logging,
wandb_project_name=wandb_project_name,
)
elif (
pretraining_type is PretrainingType.IMAGENET
or pretraining_type is PretrainingType.IMAGENET_VIT
):
self.model = Embedder.load_pretrained(pretraining_type.value)
else:
raise ValueError(f"Unknown pretraining type: {pretraining_type}")
if not self.cleaner.is_fitted:
if self.model is None:
if pretraining_type is PretrainingType.DINO:
self.model = self.train_dino(
dataset=dataset,
epochs=epochs,
batch_size=batch_size,
ssl_pre_training=ssl_pre_training,
save_every_n_epochs=save_every_n_epochs,
work_dir=work_dir,
hyperparameters=hyperparameters,
num_workers=num_workers,
additional_run_info=additional_run_info,
wandb_logging=wandb_logging,
wandb_project_name=wandb_project_name,
)
elif (
pretraining_type is PretrainingType.IMAGENET
or pretraining_type is PretrainingType.IMAGENET_VIT
):
self.model = Embedder.load_pretrained(pretraining_type.value)
else:
raise ValueError(f"Unknown pretraining type: {pretraining_type}")

set_dataset_transformation(dataset=dataset, transform=self.base_transform)
torch_dataset = DataLoader(
dataset,
batch_size=batch_size,
drop_last=False,
shuffle=False,
)
emb_space, labels = embed_dataset(
torch_dataset=torch_dataset,
model=self.model,
n_layers=n_layers,
normalize=apply_l2_norm,
memmap=self.memmap,
memmap_path=self.memmap_path,
tqdm_desc="Creating dataset representation",
return_only_embedding_and_labels=True,
)
# for default datasets we can set the paths manually
paths = None
if hasattr(dataset, "_image_files") and paths is None:
paths = dataset._image_files
set_dataset_transformation(dataset=dataset, transform=self.base_transform)
torch_dataset = DataLoader(
dataset,
batch_size=batch_size,
drop_last=False,
shuffle=False,
)
emb_space, labels = embed_dataset(
torch_dataset=torch_dataset,
model=self.model,
n_layers=n_layers,
normalize=apply_l2_norm,
memmap=self.memmap,
memmap_path=self.memmap_path,
tqdm_desc="Creating dataset representation",
return_only_embedding_and_labels=True,
)
# for default datasets we can set the paths manually
paths = None
if hasattr(dataset, "_image_files") and paths is None:
paths = dataset._image_files

self.cleaner.fit(
emb_space=np.asarray(emb_space),
labels=np.asarray(labels),
paths=np.asarray(paths) if paths is not None else paths,
dataset=dataset,
class_labels=dataset.classes if hasattr(dataset, "classes") else None,
)
self.cleaner.fit(
emb_space=np.asarray(emb_space),
labels=np.asarray(labels),
paths=np.asarray(paths) if paths is not None else paths,
dataset=dataset,
class_labels=dataset.classes if hasattr(dataset, "classes") else None,
)
return self.cleaner.predict()

def train_dino(
@@ -285,7 +286,7 @@ def train_dino(
save_every_n_epochs: int = 10,
work_dir: Optional[str] = None,
hyperparameters: dict = DINO_STANDARD_HYPERPARAMETERS,
num_workers: int = os.cpu_count(),
num_workers: Optional[int] = os.cpu_count(),
# logging
additional_run_info: str = "",
wandb_logging: bool = False,
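
The `if not self.cleaner.is_fitted:` guard introduced above is the core of the rerunning optimization: SSL pre-training, embedding, and fitting happen at most once per `SelfClean` instance, and later calls fall through to `self.cleaner.predict()`. A usage sketch under that assumption:

```python
selfclean = SelfClean()

# first run: trains or loads the encoder, embeds the dataset, fits the cleaner
issues = selfclean.run_on_dataset(dataset=copy.copy(dataset))

# second run on the same instance: cleaner.is_fitted is True, so the
# expensive steps are skipped and only the issue rankings are recomputed
issues_again = selfclean.run_on_dataset(dataset=copy.copy(dataset))
```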
24 changes: 14 additions & 10 deletions src/cleaner/selfclean_cleaner.py
@@ -77,6 +77,7 @@ def __init__(
self.plot_distribution = plot_distribution
self.plot_top_N = plot_top_N
self.figsize = figsize
self.is_fitted = False
super().__init__(**kwargs)

def fit(
@@ -153,40 +154,43 @@ def fit(
self.p_distances[:] = self.distance_matrix[
~np.tril(np.ones((self.N, self.N), dtype=bool))
]
self.is_fitted = True
return self

def predict(self) -> IssueManager:
pred_nd_scores, pred_nd_indices = self.get_near_duplicate_ranking()
pred_oods_scores, pred_oods_indices = self.get_irrelevant_ranking()
pred_irr_scores, pred_irr_indices = self.get_irrelevant_ranking()
pred_lbl_errs_scores, pred_lbl_errs_indices = self.get_label_error_ranking()

# transform labels using class names if given
if self.labels is not None:
self.labels = [
# transform labels using class names if given
labels = [
self.class_labels[x] if self.class_labels is not None else x
for x in self.labels
]
else:
labels = self.labels

if self.plot_top_N is not None and self.dataset is not None:
plot_inspection_result(
pred_dups_indices=pred_nd_indices,
pred_oods_indices=pred_oods_indices,
pred_dup_indices=pred_nd_indices,
pred_irr_indices=pred_irr_indices,
pred_lbl_errs_indices=pred_lbl_errs_indices,
dataset=self.dataset,
labels=self.labels,
labels=labels,
plot_top_N=self.plot_top_N,
output_path=self.output_path,
figsize=self.figsize,
)

meta_data_dict = {
"path": self.paths,
"label": self.labels,
"label": labels,
}
return_dict = {
"irrelevants": {
"indices": pred_oods_indices,
"scores": pred_oods_scores,
"indices": pred_irr_indices,
"scores": pred_irr_scores,
},
"near_duplicates": {
"indices": pred_nd_indices,
Expand All @@ -200,7 +204,7 @@ def predict(self) -> IssueManager:
return_dict = self.perform_auto_cleaning(
return_dict=return_dict,
pred_near_duplicate_scores=pred_nd_scores,
pred_irrelevant_scores=pred_oods_scores,
pred_irrelevant_scores=pred_irr_scores,
pred_label_error_scores=pred_lbl_errs_scores,
output_path=self.output_path,
)
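
Note the behavioral fix in the `predict` hunk above: `self.labels` is no longer overwritten in place, which matters once `predict()` can be called repeatedly on a fitted cleaner; with the old code, a second call would index `class_labels` with already-mapped names. A small illustration with hypothetical labels:

```python
class_labels = ["cat", "dog"]
labels = [0, 1, 1]

# new behavior: map into a local variable, leaving `labels` untouched
mapped = [class_labels[x] if class_labels is not None else x for x in labels]
print(mapped)  # ['cat', 'dog', 'dog']
print(labels)  # [0, 1, 1] -- still valid for a second predict() call
```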
2 changes: 1 addition & 1 deletion src/ssl_library
10 changes: 5 additions & 5 deletions src/utils/plotting.py
@@ -10,18 +10,18 @@


def plot_inspection_result(
pred_dups_indices: np.ndarray,
pred_oods_indices: np.ndarray,
pred_dup_indices: np.ndarray,
pred_irr_indices: np.ndarray,
dataset: Dataset,
plot_top_N: int,
pred_lbl_errs_indices: Optional[np.ndarray] = None,
labels: Optional[np.ndarray] = None,
labels: Optional[Union[np.ndarray, list]] = None,
output_path: Optional[Union[str, Path]] = None,
figsize: tuple = (10, 8),
):
rows = 4 if pred_lbl_errs_indices is not None else 3
fig, ax = plt.subplots(rows, plot_top_N, figsize=figsize)
for i, (idx1, idx2) in enumerate(pred_dups_indices[:plot_top_N]):
for i, (idx1, idx2) in enumerate(pred_dup_indices[:plot_top_N]):
ax[0, i].imshow(
transforms.ToPILImage()(denormalize_image(dataset[int(idx1)][0]))
)
@@ -35,7 +35,7 @@ def plot_inspection_result(
ax[0, i].set_title(f"Ranking: {i+1}, Idx: {int(idx1)}", fontsize=6)
ax[1, i].set_title(f"Idx: {int(idx2)}", fontsize=6)

for i, idx in enumerate(pred_oods_indices[:plot_top_N]):
for i, idx in enumerate(pred_irr_indices[:plot_top_N]):
ax[2, i].imshow(
transforms.ToPILImage()(denormalize_image(dataset[int(idx)][0]))
)
13 changes: 13 additions & 0 deletions tests/integration_tests/test_selfclean_IT.py
@@ -3,6 +3,7 @@
import shutil
import tempfile
import unittest
from pathlib import Path

from torchvision.datasets import FakeData

@@ -32,6 +33,18 @@ def test_run_with_files_dino_in_workdir(self):
)
self._check_output(out_dict)

def test_run_with_files_dino_with_output_path(self):
temp_work_dir = tempfile.TemporaryDirectory()
selfclean = SelfClean(output_path=str(Path(temp_work_dir.name) / "output"))
out_dict = selfclean.run_on_image_folder(
input_path=testfiles_path,
pretraining_type=PretrainingType.DINO,
work_dir=temp_work_dir.name,
epochs=1,
num_workers=4,
)
self._check_output(out_dict)

def test_run_with_files_dino_wo_pretraining(self):
selfclean = SelfClean()
out_dict = selfclean.run_on_image_folder(
2 changes: 2 additions & 0 deletions tests/unittests/cleaner/test_selfclean_cleaner.py
@@ -14,7 +14,9 @@ def setUp(self):

def test_fit(self):
cleaner = SelfCleanCleaner(memmap=False)
self.assertEqual(cleaner.is_fitted, False)
cleaner.fit(emb_space=self.emb_space, labels=self.labels)
self.assertEqual(cleaner.is_fitted, True)
self.assertIsInstance(cleaner, BaseCleaner)
self.assertIsNotNone(cleaner.distance_matrix)
self.assertIsNotNone(cleaner.p_distances)
