Skip to content

Commit

Permalink
added possibility to approximate NDs
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianGroeger96 committed Sep 12, 2024
1 parent e5f5405 commit 15bcdf2
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 7 deletions.
4 changes: 3 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
FROM pytorch/pytorch:1.9.1-cuda11.1-cudnn8-runtime

RUN apt-get update && apt-get install -y apt-transport-https
RUN apt-get update
RUN apt-get install -y apt-transport-https
RUN apt-get install -y libtcmalloc-minimal4
RUN apt-get install -y libomp-dev
RUN apt-get install -y sox
RUN apt-get install -y git

Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@ scikit-image
codecov
jupyter
loguru
faiss-cpu
faiss-gpu
33 changes: 33 additions & 0 deletions src/cleaner/near_duplicates/embedding_distance_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@


class EmbeddingDistanceMixin(BaseNearDuplicateMixin):
def __init__(
self,
approx_no_neighbors: int = 100,
**kwargs,
):
super().__init__(**kwargs)
self.approx_no_neighbors = approx_no_neighbors

def get_near_duplicate_ranking(self) -> Tuple[np.ndarray, np.ndarray]:
if self.memmap:
score_file = self.memmap_path / "near_duplicate_scores.dat"
Expand Down Expand Up @@ -77,3 +85,28 @@ def get_near_duplicate_ranking(self) -> Tuple[np.ndarray, np.ndarray]:
title="Distribution of near-duplicates",
)
return scores_near_dup, indices_near_dup

def get_approx_near_duplicate_ranking(self):
import copy

import faiss
import pandas as pd

# faiss expects all arrays to be `float32`
_emb_space = copy.deepcopy(self.emb_space)
_emb_space = _emb_space.astype("float32")
# create a `faiss` index with cosine distance
index = faiss.IndexFlat(self.D, faiss.METRIC_INNER_PRODUCT)
faiss.normalize_L2(_emb_space)
index.add(_emb_space)
# search the nearest neighbors
distances, indices = index.search(_emb_space, self.approx_no_neighbors)
# create the return dataframe
df = pd.DataFrame()
df[[f"nn_idx_{x}" for x in range(self.approx_no_neighbors)]] = indices
df[[f"nn_dist_{x}" for x in range(self.approx_no_neighbors)]] = distances
df = df.reindex(sorted(df.columns, key=lambda x: int(x.split("_")[-1])), axis=1)
df = df.drop(columns=["nn_dist_0"])
df = df.rename(columns={"nn_idx_0": "seed_idx"})
del _emb_space, index
return df
2 changes: 1 addition & 1 deletion src/cleaner/selfclean.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(
def run_on_image_folder(
self,
input_path: Union[str, Path],
epochs: int = 100,
epochs: int = 10,
batch_size: int = 64,
ssl_pre_training: bool = True,
save_every_n_epochs: int = 10,
Expand Down
17 changes: 12 additions & 5 deletions src/cleaner/selfclean_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
# memory management
memmap: bool = True,
memmap_path: Union[Path, str, None] = None,
approximate_nn: bool = False,
# plotting
plot_distribution: bool = False,
plot_top_N: Optional[int] = None,
Expand All @@ -57,6 +58,7 @@ def __init__(
fix_random_seeds(seed=random_seed)

self.memmap = memmap
self.approximate_nn = approximate_nn
self.chunk_size = chunk_size
self.precision_type_distance = precision_type_distance

Expand Down Expand Up @@ -89,6 +91,7 @@ def fit(
dataset: Optional[Dataset] = None,
class_labels: Optional[list] = None,
):
self.emb_space = emb_space
self.labels = labels
self.dataset = dataset
self.paths = paths
Expand Down Expand Up @@ -169,11 +172,15 @@ def predict(
) -> IssueManager:
return_dict = {}
if IssueTypes.NEAR_DUPLICATES in issues_to_detect:
pred_nd_scores, pred_nd_indices = self.get_near_duplicate_ranking()
return_dict["near_duplicates"] = {
"indices": pred_nd_indices,
"scores": pred_nd_scores,
}
if not self.approximate_nn:
pred_nd_scores, pred_nd_indices = self.get_near_duplicate_ranking()
return_dict["near_duplicates"] = {
"indices": pred_nd_indices,
"scores": pred_nd_scores,
}
else:
approx_result_df = self.get_approx_near_duplicate_ranking()
return_dict["approx_near_duplicates"] = approx_result_df
if IssueTypes.IRRELEVANTS in issues_to_detect:
pred_irr_scores, pred_irr_indices = self.get_irrelevant_ranking()
return_dict["irrelevants"] = {
Expand Down
42 changes: 42 additions & 0 deletions tests/unittests/cleaner/test_selfclean_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,48 @@ def test_predict_multi_issues(self):
v = out_dict.get_issues(issue_type)
self.assertIsNone(v)

def test_approx_nearest_duplicates(self):
cleaner = SelfCleanCleaner(
memmap=False,
approximate_nn=True,
approx_no_neighbors=10,
)
cleaner.fit(emb_space=self.emb_space, labels=self.labels)
out_dict = cleaner.predict(issues_to_detect=[IssueTypes.NEAR_DUPLICATES])
for issue_type in ["approx_near_duplicates"]:
v = out_dict.get_issues(issue_type)
self.assertIsNotNone(v)
self.assertEqual(len([x for x in v.columns if "nn_idx_" in x]), 10 - 1)
self.assertEqual(len([x for x in v.columns if "nn_dist_" in x]), 10 - 1)
for issue_type in ["near_duplicates", "irrelevants", "label_errors"]:
v = out_dict.get_issues(issue_type)
self.assertIsNone(v)

def test_approx_nearest_duplicates_w_exact(self):
cleaner = SelfCleanCleaner(
memmap=False,
approximate_nn=True,
approx_no_neighbors=len(self.emb_space),
)
cleaner.fit(emb_space=self.emb_space, labels=self.labels)
out_dict = cleaner.predict(issues_to_detect=[IssueTypes.NEAR_DUPLICATES])
df_approx_nn = out_dict.get_issues("approx_near_duplicates")

# fit without approximation
cleaner.approximate_nn = False
out_dict = cleaner.predict(issues_to_detect=[IssueTypes.NEAR_DUPLICATES])
df_nn = out_dict.get_issues("near_duplicates", return_as_df=True)

# check if they align
for index in range(len(self.emb_space)):
nn = df_nn[
(df_nn["indices_1"] == index) | (df_nn["indices_2"] == index)
].iloc[0]
nn_approx = df_approx_nn[df_approx_nn["seed_idx"] == index].iloc[0]
idx = nn["indices_1"] if nn["indices_1"] != index else nn["indices_2"]
idx_approx = nn_approx["nn_idx_1"]
self.assertEqual(idx, idx_approx)


if __name__ == "__main__":
unittest.main()

0 comments on commit 15bcdf2

Please sign in to comment.