Skip to content

Commit

Permalink
possible to only detect single issue types
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianGroeger96 committed Jul 2, 2024
1 parent 8a16682 commit e5f5405
Show file tree
Hide file tree
Showing 12 changed files with 391 additions and 306 deletions.
185 changes: 92 additions & 93 deletions examples/Investigate_Imagenette.ipynb

Large diffs are not rendered by default.

204 changes: 98 additions & 106 deletions examples/Investigate_OxfordIIITPet.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def parse_requirements(filename):
name=PACKAGE_NAME,
packages=proj_packages,
package_dir={PACKAGE_NAME: SOURCE_DIRECTORY},
version="0.0.22",
version="0.0.24",
author="Fabian Groeger",
author_email="fabian.groeger@unibas.ch",
description="A holistic self-supervised data cleaning strategy to detect irrelevant samples, near duplicates and label errors.",
Expand Down
50 changes: 27 additions & 23 deletions src/cleaner/auto_cleaning_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import scipy.stats
from loguru import logger

from ..cleaner.issue_manager import IssueManager
from ..utils.plotting import (
plot_frac_cut,
plot_sensitivity,
Expand Down Expand Up @@ -33,46 +34,49 @@ def __init__(

def perform_auto_cleaning(
self,
issue_manger: IssueManager,
return_dict: dict,
pred_near_duplicate_scores: np.ndarray,
pred_irrelevant_scores: np.ndarray,
pred_label_error_scores: Optional[np.ndarray],
output_path: Optional[Union[str, Path]] = None,
):
if self.auto_cleaning:
# Near Duplicates
if output_path is not None:
self.cleaner_kwargs["path"] = (
f"{output_path.stem}_auto_dups{output_path.suffix}"
near_duplicate_issues = issue_manger["near_duplicates"]
if near_duplicate_issues is not None:
if output_path is not None:
self.cleaner_kwargs["path"] = (
f"{output_path.stem}_auto_dups{output_path.suffix}"
)
self.cleaner_kwargs["alpha"] = self.near_duplicate_cut_off
issues_dup = self.fraction_cut(
scores=near_duplicate_issues["scores"],
**self.cleaner_kwargs,
)
self.cleaner_kwargs["alpha"] = self.near_duplicate_cut_off
issues_dup = self.fraction_cut(
scores=pred_near_duplicate_scores,
**self.cleaner_kwargs,
)
return_dict["near_duplicates"]["auto_issues"] = issues_dup
return_dict["near_duplicates"]["auto_issues"] = issues_dup

# Irrelevant Samples
if output_path is not None:
self.cleaner_kwargs["path"] = (
f"{output_path.stem}_auto_oods{output_path.suffix}"
irrelevant_issues = issue_manger["irrelevants"]
if irrelevant_issues is not None:
if output_path is not None:
self.cleaner_kwargs["path"] = (
f"{output_path.stem}_auto_oods{output_path.suffix}"
)
self.cleaner_kwargs["alpha"] = self.irrelevant_cut_off
issues_ood = self.fraction_cut(
scores=irrelevant_issues["scores"],
**self.cleaner_kwargs,
)
self.cleaner_kwargs["alpha"] = self.irrelevant_cut_off
issues_ood = self.fraction_cut(
scores=pred_irrelevant_scores,
**self.cleaner_kwargs,
)
return_dict["irrelevants"]["auto_issues"] = issues_ood
return_dict["irrelevants"]["auto_issues"] = issues_ood

# Label Errors
if pred_label_error_scores is not None:
label_error_issues = issue_manger["label_errors"]
if label_error_issues is not None:
if output_path is not None:
self.cleaner_kwargs["path"] = (
f"{output_path.stem}_auto_lbls{output_path.suffix}"
)
self.cleaner_kwargs["alpha"] = self.label_error_cut_off
issues_lbl = self.fraction_cut(
scores=pred_label_error_scores,
scores=label_error_issues["scores"],
**self.cleaner_kwargs,
)
return_dict["label_errors"]["auto_issues"] = issues_lbl
Expand Down
10 changes: 3 additions & 7 deletions src/cleaner/issue_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,19 @@ class IssueTypes(Enum):
class IssueManager:
def __init__(self, issue_dict: dict, meta_data_dict: Optional[dict] = None):
self.issue_dict = issue_dict
for issue_type in IssueTypes:
assert (
issue_type.value in self.issue_dict
), f"{issue_type.value} not found in given dict."
self.meta_data_dict = meta_data_dict if meta_data_dict is not None else {}

def get_issues(
self,
issue_type: Union[str, IssueTypes],
return_as_df: bool = False,
) -> Union[np.ndarray, pd.DataFrame]:
) -> Union[np.ndarray, pd.DataFrame, None]:
if issue_type is type(IssueTypes):
issue_type = issue_type.value

sel_issues = self.issue_dict.get(issue_type)
sel_issues = self.issue_dict.get(issue_type, None)
if sel_issues is None:
raise ValueError(f"Issue type: {issue_type} not found.")
return sel_issues

if return_as_df:
logger.warning("Returning as dataframe requires extensive memory.")
Expand Down
73 changes: 41 additions & 32 deletions src/cleaner/selfclean_cleaner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import tempfile
from pathlib import Path
from typing import Callable, Optional, Union
from typing import Callable, List, Optional, Union

import numpy as np
import scienceplots # noqa: F401
Expand All @@ -13,7 +13,7 @@
from ..cleaner.auto_cleaning_mixin import AutoCleaningMixin
from ..cleaner.base_cleaner import BaseCleaner
from ..cleaner.irrelevants.lad_mixin import LADIrrelevantMixin
from ..cleaner.issue_manager import IssueManager
from ..cleaner.issue_manager import IssueManager, IssueTypes
from ..cleaner.label_errors.intra_extra_distance_mixin import (
IntraExtraDistanceLabelErrorMixin,
)
Expand Down Expand Up @@ -159,10 +159,34 @@ def fit(
self.is_fitted = True
return self

def predict(self) -> IssueManager:
pred_nd_scores, pred_nd_indices = self.get_near_duplicate_ranking()
pred_irr_scores, pred_irr_indices = self.get_irrelevant_ranking()
pred_lbl_errs_scores, pred_lbl_errs_indices = self.get_label_error_ranking()
def predict(
self,
issues_to_detect: List[IssueTypes] = [
IssueTypes.NEAR_DUPLICATES,
IssueTypes.IRRELEVANTS,
IssueTypes.LABEL_ERRORS,
],
) -> IssueManager:
return_dict = {}
if IssueTypes.NEAR_DUPLICATES in issues_to_detect:
pred_nd_scores, pred_nd_indices = self.get_near_duplicate_ranking()
return_dict["near_duplicates"] = {
"indices": pred_nd_indices,
"scores": pred_nd_scores,
}
if IssueTypes.IRRELEVANTS in issues_to_detect:
pred_irr_scores, pred_irr_indices = self.get_irrelevant_ranking()
return_dict["irrelevants"] = {
"indices": pred_irr_indices,
"scores": pred_irr_scores,
}
if IssueTypes.LABEL_ERRORS in issues_to_detect:
pred_lbl_errs_scores, pred_lbl_errs_indices = self.get_label_error_ranking()
if pred_lbl_errs_scores is not None and pred_lbl_errs_indices is not None:
return_dict["label_errors"] = {
"indices": pred_lbl_errs_indices,
"scores": pred_lbl_errs_scores,
}

if self.labels is not None:
# transform labels using class names if given
Expand All @@ -172,42 +196,27 @@ def predict(self) -> IssueManager:
]
else:
labels = self.labels
# create the manager for the issues to pass to plotting and return
issue_manager = IssueManager(
issue_dict=return_dict,
meta_data_dict={
"path": self.paths,
"label": labels,
},
)

if self.plot_top_N is not None and self.dataset is not None:
plot_inspection_result(
pred_dup_indices=pred_nd_indices,
pred_irr_indices=pred_irr_indices,
pred_lbl_errs_indices=pred_lbl_errs_indices,
issue_manger=issue_manager,
dataset=self.dataset,
labels=labels,
plot_top_N=self.plot_top_N,
output_path=self.output_path,
figsize=self.figsize,
)

meta_data_dict = {
"path": self.paths,
"label": labels,
}
return_dict = {
"irrelevants": {
"indices": pred_irr_indices,
"scores": pred_irr_scores,
},
"near_duplicates": {
"indices": pred_nd_indices,
"scores": pred_nd_scores,
},
"label_errors": {
"indices": pred_lbl_errs_indices,
"scores": pred_lbl_errs_scores,
},
}
return_dict = self.perform_auto_cleaning(
issue_manger=issue_manager,
return_dict=return_dict,
pred_near_duplicate_scores=pred_nd_scores,
pred_irrelevant_scores=pred_irr_scores,
pred_label_error_scores=pred_lbl_errs_scores,
output_path=self.output_path,
)
return IssueManager(issue_dict=return_dict, meta_data_dict=meta_data_dict)
return issue_manager
2 changes: 1 addition & 1 deletion src/ssl_library
101 changes: 66 additions & 35 deletions src/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,61 +6,92 @@
from torch.utils.data import Dataset
from torchvision import transforms

from ..cleaner.issue_manager import IssueManager
from ..ssl_library.src.utils.logging import create_subtitle, denormalize_image


def plot_inspection_result(
pred_dup_indices: np.ndarray,
pred_irr_indices: np.ndarray,
issue_manger: IssueManager,
dataset: Dataset,
plot_top_N: int,
pred_lbl_errs_indices: Optional[np.ndarray] = None,
labels: Optional[Union[np.ndarray, list]] = None,
output_path: Optional[Union[str, Path]] = None,
figsize: tuple = (10, 8),
):
rows = 4 if pred_lbl_errs_indices is not None else 3
rows = len(issue_manger.keys)
if issue_manger["near_duplicates"] is not None:
rows += 1
fig, ax = plt.subplots(rows, plot_top_N, figsize=figsize)
for i, (idx1, idx2) in enumerate(pred_dup_indices[:plot_top_N]):
ax[0, i].imshow(
transforms.ToPILImage()(denormalize_image(dataset[int(idx1)][0]))
)
ax[1, i].imshow(
transforms.ToPILImage()(denormalize_image(dataset[int(idx2)][0]))
)
ax[0, i].set_xticks([])
ax[0, i].set_yticks([])
ax[1, i].set_xticks([])
ax[1, i].set_yticks([])
ax[0, i].set_title(f"Ranking: {i+1}, Idx: {int(idx1)}", fontsize=6)
ax[1, i].set_title(f"Idx: {int(idx2)}", fontsize=6)

for i, idx in enumerate(pred_irr_indices[:plot_top_N]):
ax[2, i].imshow(
transforms.ToPILImage()(denormalize_image(dataset[int(idx)][0]))
)
ax[2, i].set_title(f"Ranking: {i+1}, Idx: {int(idx)}", fontsize=6)
ax[2, i].set_xticks([])
ax[2, i].set_yticks([])
ax_idx = 0

if pred_lbl_errs_indices is not None:
for i, idx in enumerate(pred_lbl_errs_indices[:plot_top_N]):
near_duplicate_issues = issue_manger["near_duplicates"]
if near_duplicate_issues is not None:
for i, (idx1, idx2) in enumerate(near_duplicate_issues["indices"][:plot_top_N]):
ax[ax_idx, i].imshow(
transforms.ToPILImage()(denormalize_image(dataset[int(idx1)][0]))
)
ax[ax_idx + 1, i].imshow(
transforms.ToPILImage()(denormalize_image(dataset[int(idx2)][0]))
)
ax[ax_idx, i].set_xticks([])
ax[ax_idx, i].set_yticks([])
ax[ax_idx + 1, i].set_xticks([])
ax[ax_idx + 1, i].set_yticks([])
ax[ax_idx, i].set_title(f"Ranking: {i+1}, Idx: {int(idx1)}", fontsize=6)
ax[ax_idx + 1, i].set_title(f"Idx: {int(idx2)}", fontsize=6)
ax_idx += 2

irrelevant_issues = issue_manger["irrelevants"]
if irrelevant_issues is not None:
for i, idx in enumerate(irrelevant_issues["indices"][:plot_top_N]):
ax[ax_idx, i].imshow(
transforms.ToPILImage()(denormalize_image(dataset[int(idx)][0]))
)
ax[ax_idx, i].set_title(f"Ranking: {i+1}, Idx: {int(idx)}", fontsize=6)
ax[ax_idx, i].set_xticks([])
ax[ax_idx, i].set_yticks([])
ax_idx += 1

label_error_issues = issue_manger["label_errors"]
if label_error_issues is not None:
for i, idx in enumerate(label_error_issues["indices"][:plot_top_N]):
class_label = labels[idx] if labels is not None else None
ax[3, i].imshow(
ax[ax_idx, i].imshow(
transforms.ToPILImage()(denormalize_image(dataset[int(idx)][0]))
)
ax[3, i].set_title(
ax[ax_idx, i].set_title(
f"Ranking: {i+1}\nIdx: {int(idx)}\nLbl: {class_label}",
fontsize=6,
)
ax[3, i].set_xticks([])
ax[3, i].set_yticks([])
ax[ax_idx, i].set_xticks([])
ax[ax_idx, i].set_yticks([])

ax_idx = 0
grid = plt.GridSpec(rows, plot_top_N)
create_subtitle(fig, grid[0, ::], "Near-Duplicate Ranking", fontsize=12)
create_subtitle(fig, grid[2, ::], "Irrelevant Samples Ranking", fontsize=12)
if pred_lbl_errs_indices is not None:
create_subtitle(fig, grid[3, ::], "Label Error Ranking", fontsize=12)
if near_duplicate_issues is not None:
create_subtitle(
fig,
grid[ax_idx, ::],
"Near-Duplicate Ranking",
fontsize=12,
)
ax_idx += 2
if irrelevant_issues is not None:
create_subtitle(
fig,
grid[ax_idx, ::],
"Irrelevant Samples Ranking",
fontsize=12,
)
ax_idx += 1
if label_error_issues is not None:
create_subtitle(
fig,
grid[ax_idx, ::],
"Label Error Ranking",
fontsize=12,
)

fig.tight_layout()
if output_path is not None:
plt.savefig(output_path, bbox_inches="tight")
Expand Down
14 changes: 14 additions & 0 deletions tests/integration_tests/test_selfclean_IT.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ def test_run_with_dataset(self):
)
self._check_output(out_dict)

def test_run_with_plotting(self):
    """Run SelfClean on a fake dataset with the plotting options switched on.

    The cleaner is configured with ``plot_distribution=True`` and
    ``plot_top_N=7``; the returned result is handed to ``_check_output``
    for validation.
    """
    # Synthetic dataset of 20 samples, so no real data is needed.
    dataset = FakeData(size=20)

    # Plotting-related options are the point of this test.
    cleaner = SelfClean(
        memmap=False,
        plot_distribution=True,
        plot_top_N=7,
    )

    # A single epoch keeps the run short.
    result = cleaner.run_on_dataset(
        dataset=dataset,
        epochs=1,
        num_workers=4,
    )

    self._check_output(result)

def _check_output(self, out_dict):
for issue_type in ["irrelevants", "near_duplicates", "label_errors"]:
v = out_dict.get_issues(issue_type)
Expand Down
5 changes: 2 additions & 3 deletions tests/unittests/cleaner/test_auto_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@ def test_predict_auto_cleaning_without_labels(self):
cleaner = SelfCleanCleaner(memmap=False, auto_cleaning=True)
cleaner.fit(emb_space=self.emb_space)
out_dict = cleaner.predict()
for issue_type in ["irrelevants", "near_duplicates", "label_errors"]:
for issue_type in ["irrelevants", "near_duplicates"]:
v = out_dict.get_issues(issue_type)
self.assertIsNotNone(v)
self.assertTrue("indices" in v)
self.assertTrue("scores" in v)
self.assertIsNone(out_dict.get_issues("label_errors")["indices"])
self.assertIsNone(out_dict.get_issues("label_errors")["scores"])
self.assertIsNone(out_dict.get_issues("label_errors"))

def test_predict_auto_cleaning_with_plotting(self):
cleaner = SelfCleanCleaner(
Expand Down
Loading

0 comments on commit e5f5405

Please sign in to comment.