diff --git a/CHANGELOG.md b/CHANGELOG.md index 83b1d02579..57412029e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 () - Migrate DVC v3.0.0 () +- Add Prune API + () ### Enhancements - Enhance import performance for built-in plugins diff --git a/requirements-core.txt b/requirements-core.txt index 30a0788cb9..1cbf503015 100644 --- a/requirements-core.txt +++ b/requirements-core.txt @@ -54,3 +54,6 @@ tabulate # Model inference launcher from the dedicated inference server ovmsclient tritonclient[all] + +# prune +scikit-learn diff --git a/src/datumaro/cli/commands/__init__.py b/src/datumaro/cli/commands/__init__.py index 86168fe596..f6eb65ba85 100644 --- a/src/datumaro/cli/commands/__init__.py +++ b/src/datumaro/cli/commands/__init__.py @@ -16,6 +16,7 @@ info, merge, patch, + prune, stats, transform, validate, @@ -36,11 +37,12 @@ def get_non_project_commands(): ("dinfo", info, "Print dataset info"), ("download", download, "Download a publicly available dataset"), ("explain", explain, "Run Explainable AI algorithm for model"), + ("explore", explore, "Explore similar datasetitems of query"), ("filter", filter, "Filter dataset items"), ("generate", generate, "Generate synthetic dataset"), ("merge", merge, "Merge datasets"), ("patch", patch, "Update dataset from another one"), - ("explore", explore, "Explore similar datasetitems of query"), + ("prune", prune, "Prune dataset"), ("stats", stats, "Compute dataset statistics"), ("transform", transform, "Modify dataset items"), ("validate", validate, "Validate dataset"), diff --git a/src/datumaro/cli/commands/explore.py b/src/datumaro/cli/commands/explore.py index 700788a115..35d99f4426 100644 --- a/src/datumaro/cli/commands/explore.py +++ b/src/datumaro/cli/commands/explore.py @@ -8,8 +8,8 @@ import os.path as osp import shutil +from datumaro.components.algorithms.hash_key_inference.explorer import Explorer from 
# Pruning strategies accepted by `-m/--method`; mirrors the keys understood by
# `Prune.get_pruned`.
_PRUNE_METHODS = ("random", "cluster_random", "centroid", "query_clust", "entropy", "ndr")


def build_parser(parser_ctor=argparse.ArgumentParser):
    """Create the argument parser of the ``prune`` command.

    Parameters:
        parser_ctor: Callable used to construct the parser (``ArgumentParser``
            itself, or the ``add_parser`` method of a sub-parser group).
    Returns:
        The configured parser. Its ``command`` default is ``prune_command``.
    """
    parser = parser_ctor(
        help="Prune dataset and make a representative subset",
        # The description contains '%(prog)s', so argparse %-formats the whole
        # text: literal percent signs have to be escaped as '%%'.
        description="""
        Apply data pruning to a dataset.|n
        The command can be useful if you have to extract representative subset.
        |n
        The current project (-p/--project) is used as a context for plugins
        and models. It is used when there is a dataset path in target.
        When not specified, the current project's working tree is used.|n
        |n
        By default, datasets are updated in-place. The '-o/--output-dir'
        option can be used to specify another output directory. When
        updating in-place, use the '--overwrite' parameter (in-place
        updates fail by default to prevent data loss), unless a project
        target is modified.|n
        |n
        The command can be applied to a dataset or a project build target,
        a stage or the combined 'project' target, in which case all the
        targets will be affected.|n
        |n
        Examples:|n
        - Prune dataset with selecting random and ratio 80%%:|n
        |s|s%(prog)s -m random -r 0.8|n
        - Prune dataset with clustering in image hash and ratio 50%%:|n
        |s|s%(prog)s -m query_clust --hash-type img -r 0.5|n
        - Prune dataset based on entropy with clustering in image hash and ratio 50%%:|n
        |s|s%(prog)s -m entropy --hash-type img -r 0.5|n
        """,
        formatter_class=MultilineFormatter,
    )

    parser.add_argument("target", nargs="?", help="Target dataset revpath (default: project)")
    parser.add_argument(
        "-m",
        "--method",
        dest="method",
        default="random",
        choices=_PRUNE_METHODS,
        help="Method to apply to the dataset (default: %(default)s)",
    )
    parser.add_argument(
        "-r",
        "--ratio",
        type=float,
        dest="ratio",
        default=0.5,
        help="How much to remain dataset after pruning (default: %(default)s)",
    )
    parser.add_argument(
        "--hash-type",
        type=str,
        dest="hash_type",
        default="img",
        help="Hashtype to extract feature from data information between image and text(label)",
    )
    parser.add_argument(
        "-p",
        "--project",
        dest="project_dir",
        help="Directory of the project to operate on (default: current dir)",
    )

    parser.add_argument(
        "-o",
        "--output-dir",
        dest="dst_dir",
        help="""
            Output directory. Can be omitted for main project targets
            (i.e. data sources and the 'project' target, but not
            intermediate stages) and dataset targets.
            If not specified, the results will be saved inplace.
            """,
    )
    # NOTE(review): this flag is parsed but `prune_command` never reads it, so
    # in-place saves currently succeed without it. Confirm the intended policy.
    parser.add_argument(
        "--overwrite", action="store_true", help="Overwrite existing files in the save directory"
    )
    parser.set_defaults(command=prune_command)

    return parser


def get_sensitive_args():
    """Return, per command, the argument names that may carry user data."""
    return {
        prune_command: [
            "target",
            "method",
            "ratio",
            "hash_type",
            "project_dir",
            "dst_dir",
        ]
    }


@scoped
def prune_command(args):
    """Prune one dataset and save the selected subset.

    The dataset is the explicit ``args.target`` or, when no target is given,
    the first source of the project's working tree. The source dataset is
    first saved back to its own path with ``save_hashkey_meta=True``, then the
    pruned result is written to ``args.dst_dir`` (or in place).

    Parameters:
        args: Parsed arguments produced by the parser from ``build_parser``.
    Raises:
        ProjectNotFoundError: if no project can be loaded and either a project
            directory was requested explicitly or no target was given (the
            project is then the only place to take a dataset from).
    """
    project = None
    try:
        project = scope_add(load_project(args.project_dir))
    except ProjectNotFoundError:
        # Without a target the sources must come from the project, so a missing
        # project is fatal there as well (previously: AttributeError on None).
        if args.project_dir or not args.target:
            raise

    targets = [args.target] if args.target else list(project.working_tree.sources)
    if len(targets) > 1:
        log.warning(
            "Only the first target '%s' is pruned; %d other target(s) are ignored",
            targets[0],
            len(targets) - 1,
        )

    # Only the first target is used, so only that one is loaded.
    source_dataset = parse_full_revpath(targets[0], project)[0]

    prune = Prune(source_dataset, cluster_method=args.method, hash_type=args.hash_type)

    # Save the source back with hash-key metadata enabled before pruning.
    source_dataset.save(source_dataset.data_path, save_media=True, save_hashkey_meta=True)

    result = prune.get_pruned(args.ratio)

    dst_dir = args.dst_dir or source_dataset.data_path
    result.save(dst_dir, save_media=True)

    log.info("Results have been saved to '%s'", dst_dir)
Sequence[Dataset]) -> None: + pass + + @property + def model(self): + if self._model is None: + self._model = ExplorerLauncher(model_name="clip_visual_ViT-B_32") + return self._model + + @property + def text_model(self): + if self._text_model is None: + self._text_model = ExplorerLauncher(model_name="clip_text_ViT-B_32") + return self._text_model + + def _compute_hash_key(self, datasets, datasets_to_infer): + for dataset_to_infer in datasets_to_infer: + if dataset_to_infer: + dataset_to_infer.run_model(self.model, append_annotation=True) + for dataset, dataset_to_infer in zip(datasets, datasets_to_infer): + updated_items = [ + dataset.get(item.id, item.subset).wrap(annotations=item.annotations) + for item in dataset_to_infer + ] + dataset.update(updated_items) + return datasets diff --git a/src/datumaro/components/explorer.py b/src/datumaro/components/algorithms/hash_key_inference/explorer.py similarity index 72% rename from src/datumaro/components/explorer.py rename to src/datumaro/components/algorithms/hash_key_inference/explorer.py index a842a14eba..a89b5eafdd 100644 --- a/src/datumaro/components/explorer.py +++ b/src/datumaro/components/algorithms/hash_key_inference/explorer.py @@ -6,32 +6,18 @@ import numpy as np +from datumaro.components.algorithms.hash_key_inference.base import HashInference +from datumaro.components.algorithms.hash_key_inference.hashkey_util import ( + calculate_hamming, + select_uninferenced_dataset, +) from datumaro.components.annotation import HashKey from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.errors import DatumaroError, MediaTypeError -from datumaro.components.media import MediaElement -from datumaro.plugins.explorer import ExplorerLauncher -def calculate_hamming(B1, B2): - """ - :param B1: vector [n] - :param B2: vector [r*n] - :return: hamming distance [r] - """ - return np.count_nonzero(B1 != B2, axis=1) - - -def 
select_uninferenced_dataset(dataset): - uninferenced_dataset = Dataset(media_type=MediaElement) - for item in dataset: - if not any(isinstance(annotation, HashKey) for annotation in item.annotations): - uninferenced_dataset.put(item) - return uninferenced_dataset - - -class Explorer: +class Explorer(HashInference): def __init__( self, *datasets: Sequence[Dataset], @@ -54,7 +40,7 @@ def __init__( item_list = [] datasets_to_infer = [select_uninferenced_dataset(dataset) for dataset in datasets] - datasets = self.compute_hash_key(datasets, datasets_to_infer) + datasets = self._compute_hash_key(datasets, datasets_to_infer) for dataset in datasets: for item in dataset: @@ -75,26 +61,6 @@ def __init__( self._database_keys = np.stack(database_keys, axis=0) self._item_list = item_list - @property - def model(self): - if self._model is None: - self._model = ExplorerLauncher(model_name="clip_visual_ViT-B_32") - return self._model - - @property - def text_model(self): - if self._text_model is None: - self._text_model = ExplorerLauncher(model_name="clip_text_ViT-B_32") - return self._text_model - - def compute_hash_key(self, datasets, datasets_to_infer): - for dataset in datasets_to_infer: - if len(dataset) > 0: - dataset.run_model(self.model, append_annotation=True) - for dataset, dataset_to_infer in zip(datasets, datasets_to_infer): - dataset.update(dataset_to_infer) - return datasets - def explore_topk( self, query: Union[DatasetItem, str], diff --git a/src/datumaro/components/algorithms/hash_key_inference/hashkey_util.py b/src/datumaro/components/algorithms/hash_key_inference/hashkey_util.py new file mode 100644 index 0000000000..9c9d09d8c0 --- /dev/null +++ b/src/datumaro/components/algorithms/hash_key_inference/hashkey_util.py @@ -0,0 +1,174 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import numpy as np + +from datumaro.components.annotation import HashKey +from datumaro.components.dataset import Dataset +from datumaro.components.media 
# Fallback prompt used when a dataset format has no dedicated template set.
templates = [
    "a photo of a {}.",
]

# The CIFAR prompts are nine patterns, each rendered with "a" and then "the".
_CIFAR_ARTICLES = ("a", "the")
_CIFAR_PATTERNS = (
    "a photo of %s {}.",
    "a blurry photo of %s {}.",
    "a black and white photo of %s {}.",
    "a low contrast photo of %s {}.",
    "a high contrast photo of %s {}.",
    "a bad photo of %s {}.",
    "a good photo of %s {}.",
    "a photo of %s small {}.",
    "a photo of %s big {}.",
)

cifar10_templates = [
    pattern % article for article in _CIFAR_ARTICLES for pattern in _CIFAR_PATTERNS
]

# Same prompts as CIFAR-10, kept as an independent list object.
cifar100_templates = list(cifar10_templates)

# Irregular wording, so these stay spelled out literally.
caltech101_templates = [
    "a photo of a {}.",
    "a painting of a {}.",
    "a plastic {}.",
    "a sculpture of a {}.",
    "a sketch of a {}.",
    "a tattoo of a {}.",
    "a toy {}.",
    "a rendition of a {}.",
    "a embroidered {}.",
    "a cartoon {}.",
    "a {} in a video game.",
    "a plushie {}.",
    "a origami {}.",
    "art of a {}.",
    "graffiti of a {}.",
    "a drawing of a {}.",
    "a doodle of a {}.",
    "a photo of the {}.",
    "a painting of the {}.",
    "the plastic {}.",
    "a sculpture of the {}.",
    "a sketch of the {}.",
    "a tattoo of the {}.",
    "the toy {}.",
    "a rendition of the {}.",
    "the embroidered {}.",
    "the cartoon {}.",
    "the {} in a video game.",
    "the plushie {}.",
    "the origami {}.",
    "art of the {}.",
    "graffiti of the {}.",
    "a drawing of the {}.",
    "a doodle of the {}.",
]

eurosat_templates = [
    "a centered satellite photo of %s." % subject for subject in ("{}", "a {}", "the {}")
]

flowers101_templates = ["a photo of a {}, a type of flower."]

food101_templates = ["a photo of {}, a type of food."]

kitti_templates = ["{}"]

# The Kinetics prompts are four media crossed with seven subject phrasings.
_KINETICS_MEDIA = ("photo", "video", "example", "demonstration")
_KINETICS_SUBJECTS = (
    "{}",
    "a person {}",
    "a person using {}",
    "a person doing {}",
    "a person during {}",
    "a person performing {}",
    "a person practicing {}",
)

kinetics_templates = [
    "a %s of %s." % (medium, subject)
    for medium in _KINETICS_MEDIA
    for subject in _KINETICS_SUBJECTS
]

mnist_templates = ['a photo of the number: "{}".']

# Dataset format name -> prompt templates used for that format.
format_templates = dict(
    cifar10=cifar10_templates,
    cifar100=cifar100_templates,
    caltech101=caltech101_templates,
    eurosat=eurosat_templates,
    flowers101=flowers101_templates,
    food101=food101_templates,
    kitti=kitti_templates,
    kinetics=kinetics_templates,
    mnist=mnist_templates,
)
def match_num_item_for_cluster(ratio, dataset_len, cluster_num_item_list):
    """Split the pruning budget between clusters proportionally to their size.

    The budget is ``ceil(dataset_len * ratio)``. Every cluster first receives
    the floor of its proportional share. Whatever is left of the budget is
    handed out, one item each, to clusters whose share rounded down to zero,
    largest cluster first. Clusters that already got a non-zero share are
    never topped up, so the returned total can stay below the budget.

    Parameters:
        ratio: Fraction of the dataset to keep.
        dataset_len: Number of items in the whole dataset.
        cluster_num_item_list: Number of items in each cluster.
    Returns:
        A list with the number of items to select from each cluster, in the
        same order as ``cluster_num_item_list``.
    """
    total_num_selected_item = math.ceil(dataset_len * ratio)

    total_cluster_items = sum(cluster_num_item_list)
    if total_cluster_items == 0:
        # Nothing to distribute; also avoids dividing by zero below.
        return [0] * len(cluster_num_item_list)

    cluster_weights = np.array(cluster_num_item_list) / total_cluster_items
    norm_cluster_num_item_list = (cluster_weights * total_num_selected_item).astype(int)
    remaining_items = total_num_selected_item - sum(norm_cluster_num_item_list)

    if remaining_items > 0:
        zero_cluster_indexes = np.where(norm_cluster_num_item_list == 0)[0]
        # Heaviest empty clusters first. The stable sort breaks ties by cluster
        # order, and slicing guarantees that at most `remaining_items` clusters
        # are topped up (ties used to increment every tied cluster).
        by_weight_desc = np.argsort(-cluster_weights[zero_cluster_indexes], kind="stable")
        for index in zero_cluster_indexes[by_weight_desc[:remaining_items]]:
            norm_cluster_num_item_list[index] += 1

    elif remaining_items < 0:
        # Only reachable through floating point round-off, since the shares
        # are rounded down.
        diff_num_item_list = np.argsort(cluster_weights - norm_cluster_num_item_list)
        for diff_idx in diff_num_item_list[: abs(remaining_items)]:
            norm_cluster_num_item_list[diff_idx] -= 1

    return norm_cluster_num_item_list.tolist()


class PruneBase(ABC):
    """Interface shared by all pruning strategies."""

    @abstractmethod
    def base(
        self,
        ratio: float,
        num_centers: Optional[int],
        labels: Optional[List[int]],
        database_keys: Optional[np.ndarray],
        item_list: "List[DatasetItem]",
        source: "Optional[Dataset]",
    ) -> "Tuple[List[DatasetItem], Optional[Dict]]":
        """It executes each method for pruning.

        Parameters:
            ratio: How much to remain dataset after pruning.
            num_centers: Number of centers for clustering.
            labels: Label of one annotation for each datasetitem.
            database_keys: Batch of the numpy formatted hash_key.
            item_list: List of datasetitem of dataset.
            source: Whole dataset.
        Returns:
            It returns a tuple of selected items and distance of each item and clusters.
        """
        raise NotImplementedError


class RandomSelect(PruneBase):
    """
    Select items randomly from the dataset.
    """

    def base(self, ratio, num_centers, labels, database_keys, item_list, source):
        """Pick ``ceil(len(item_list) * ratio)`` items with a fixed seed.

        Only ``ratio`` and ``item_list`` are used. A private generator seeded
        with 0 yields the same sample as seeding the module-level generator,
        without resetting the caller's global random state.

        Returns:
            The selected items and ``None`` (no distances are computed).
        """
        rng = random.Random(0)
        dataset_len = len(item_list)
        num_selected_item = math.ceil(dataset_len * ratio)
        random_indices = rng.sample(range(dataset_len), num_selected_item)
        return [item_list[idx] for idx in random_indices], None
def _closest_to_center(cluster_id, cluster_center, cluster_items_idx, database_keys, item_list, k):
    """Return the ``k`` cluster members closest to the cluster center.

    Closeness is what ``calculate_hamming`` returns for the center against
    each member's key.

    Returns:
        A list of ``(item, (cluster_id, item.id, item.subset, distance))``
        pairs, ordered by increasing distance.
    """
    dist = calculate_hamming(cluster_center, database_keys[cluster_items_idx])
    picked = []
    for rank in np.argsort(dist)[:k]:
        item = item_list[cluster_items_idx[rank]]
        picked.append((item, (cluster_id, item.id, item.subset, dist[rank])))
    return picked


class Centroid(PruneBase):
    """
    Select items through clustering with centers targeting the desired number.
    """

    def base(self, ratio, num_centers, labels, database_keys, item_list, source):
        """Cluster into as many groups as items to keep and take one item per group.

        KMeans is fitted with ``ceil(len(item_list) * ratio)`` clusters. For
        every non-empty cluster the member closest to the center is selected.

        Returns:
            The selected items and a list of
            ``(cluster_id, item_id, subset, distance)`` tuples.
        """
        num_selected_centers = math.ceil(len(item_list) * ratio)
        kmeans = KMeans(n_clusters=num_selected_centers, random_state=0)
        clusters = kmeans.fit_predict(database_keys)
        cluster_centers = kmeans.cluster_centers_

        selected_items = []
        dist_tuples = []
        for cluster_id in np.unique(clusters):
            cluster_items_idx = np.where(clusters == cluster_id)[0]
            for item, dist_tuple in _closest_to_center(
                cluster_id,
                cluster_centers[cluster_id],
                cluster_items_idx,
                database_keys,
                item_list,
                1,  # exactly one representative per cluster
            ):
                selected_items.append(item)
                dist_tuples.append(dist_tuple)
        return selected_items, dist_tuples


class ClusteredRandom(PruneBase):
    """
    Select items through clustering and choose randomly within each cluster.
    """

    def base(self, ratio, num_centers, labels, database_keys, item_list, source):
        """Cluster into ``num_centers`` groups and sample each group at random.

        The per-cluster quota comes from ``match_num_item_for_cluster``. A
        private generator seeded with 0 reproduces the shuffles obtained by
        seeding the module-level generator, without touching global state.

        Returns:
            The selected items and ``None`` (no distances are computed).
        """
        kmeans = KMeans(n_clusters=num_centers, random_state=0)
        clusters = kmeans.fit_predict(database_keys)
        cluster_ids, cluster_num_item_list = np.unique(clusters, return_counts=True)

        norm_cluster_num_item_list = match_num_item_for_cluster(
            ratio, len(database_keys), cluster_num_item_list
        )

        rng = random.Random(0)
        selected_items = []
        for cluster_id, num_selected_items in zip(cluster_ids, norm_cluster_num_item_list):
            # Shuffle a plain list rather than the NumPy index array itself.
            cluster_items_idx = np.where(clusters == cluster_id)[0].tolist()
            rng.shuffle(cluster_items_idx)
            selected_items.extend(item_list[idx] for idx in cluster_items_idx[:num_selected_items])
        return selected_items, None


class QueryClust(PruneBase):
    """
    Select items through clustering with inits that imply each label.
    """

    def base(self, ratio, num_centers, labels, database_keys, item_list, source):
        """Cluster with one initial centroid per label and keep the closest items.

        The first item carrying a given label provides that label's initial
        centroid; an item seeds at most one centroid. Centroids are ordered by
        label value, and the number of clusters equals the number of seeds
        found, so it always matches the ``init`` array handed to KMeans.

        Returns:
            The selected items and a list of
            ``(cluster_id, item_id, subset, distance)`` tuples.
        Raises:
            ValueError: if no item has a ``Label`` annotation.
        """
        # label -> row index (into item_list / database_keys) of its seed item
        seed_index_by_label = {}
        for index, item in enumerate(item_list):
            for anno in item.annotations:
                if isinstance(anno, Label) and anno.label not in seed_index_by_label:
                    seed_index_by_label[anno.label] = index
                    break
            if num_centers is not None and len(seed_index_by_label) >= num_centers:
                break

        if not seed_index_by_label:
            raise ValueError("query_clust needs at least one item with a Label annotation")

        seed_indexes = [seed_index_by_label[label] for label in sorted(seed_index_by_label)]
        centroids = database_keys[seed_indexes]
        kmeans = KMeans(n_clusters=len(seed_indexes), n_init=1, init=centroids, random_state=0)

        clusters = kmeans.fit_predict(database_keys)
        cluster_centers = kmeans.cluster_centers_
        cluster_ids, cluster_num_item_list = np.unique(clusters, return_counts=True)

        norm_cluster_num_item_list = match_num_item_for_cluster(
            ratio, len(database_keys), cluster_num_item_list
        )

        selected_items = []
        dist_tuples = []
        for cluster_id, num_selected_item in zip(cluster_ids, norm_cluster_num_item_list):
            cluster_items_idx = np.where(clusters == cluster_id)[0]
            for item, dist_tuple in _closest_to_center(
                cluster_id,
                cluster_centers[cluster_id],
                cluster_items_idx,
                database_keys,
                item_list,
                num_selected_item,
            ):
                selected_items.append(item)
                dist_tuples.append(dist_tuple)
        return selected_items, dist_tuples
# Label recorded for items that carry a hash key but no Label annotation.
_NO_LABEL = -1


class Entropy(PruneBase):
    """
    Select items through clustering and choose them based on label entropy in each cluster.
    """

    def base(self, ratio, num_centers, labels, database_keys, item_list, source):
        """Cluster into ``num_centers`` groups and sample rare labels more often.

        Inside a cluster each item is drawn, without replacement, with a
        probability inversely proportional to the frequency of its label in
        that cluster. ``labels`` must hold one entry per row of
        ``database_keys``. Sampling uses a private generator seeded with 0, so
        repeated runs select the same items.

        Returns:
            The selected items and ``None`` (no distances are computed).
        """
        kmeans = KMeans(n_clusters=num_centers, random_state=0)
        clusters = kmeans.fit_predict(database_keys)

        cluster_ids, cluster_num_item_list = np.unique(clusters, return_counts=True)
        norm_cluster_num_item_list = match_num_item_for_cluster(
            ratio, len(database_keys), cluster_num_item_list
        )

        all_labels = np.array(labels)
        rng = np.random.RandomState(0)
        selected_item_indexes = []
        for cluster_id, num_selected_item in zip(cluster_ids, norm_cluster_num_item_list):
            cluster_items_idx = np.where(clusters == cluster_id)[0]

            _, inv, cnts = np.unique(
                all_labels[cluster_items_idx], return_inverse=True, return_counts=True
            )
            probs = (1 / cnts)[inv]
            probs /= probs.sum()

            choices = rng.choice(len(inv), size=num_selected_item, p=probs, replace=False)
            selected_item_indexes.extend(cluster_items_idx[choices])

        return [item_list[idx] for idx in selected_item_indexes], None


class NDRSelect(PruneBase):
    """
    Select items based on NDR among each subset.
    """

    def base(self, ratio, num_centers, labels, database_keys, item_list, source):
        """Run NDR on every subset of ``source`` and collect what it returns.

        Returns:
            The items of each processed subset and ``None``.
        """
        selected_items = []
        for subset_ in list(source.subsets().keys()):
            subset_len = len(source.get_subset(subset_))
            # NOTE(review): `ratio` is documented as the share to keep, yet
            # `1 - ratio` is what is passed as `num_cut`. Confirm against the
            # NDR API whether `num_cut` is the kept size or the removed size.
            num_selected_subset_item = math.ceil(subset_len * (1 - ratio))
            ndr_result = ndr.NDR(source, working_subset=subset_, num_cut=num_selected_subset_item)
            selected_items.extend(ndr_result.get_subset(subset_))

        return selected_items, None


class Prune(HashInference):
    def __init__(
        self,
        dataset: Dataset,
        cluster_method: str = "random",
        hash_type: str = "img",
    ) -> None:
        """
        Prune make a representative and manageable subset.

        Parameters:
            dataset: Dataset to prune.
            cluster_method: One of "random", "cluster_random", "centroid",
                "query_clust", "entropy" or "ndr".
            hash_type: "img" uses the image hash key only; "txt" appends a hash
                key inferred from the label prompts.
        """
        super().__init__(dataset)
        self._dataset = dataset
        self._cluster_method = cluster_method
        self._hash_type = hash_type

        self._model = None
        self._text_model = None
        self._num_centers = None

        # Row i of _database_keys, _item_list[i] and _labels[i] describe the
        # same dataset item.
        self._database_keys = None
        self._item_list = []
        self._labels = []

        self._prepare_data()

    def _prepare_data(self):
        """Collect the items, labels and unpacked hash keys the methods need.

        The "random" method only needs the item list, so no inference is run
        for it. Otherwise hash keys are inferred for items that lack one, and
        each hash key contributes one row to ``_database_keys`` together with
        the item and its first ``Label`` value (``_NO_LABEL`` if it has none).
        """
        if self._cluster_method == "random":
            self._item_list = list(self._dataset)
            return

        category_dict = self._prompting() if self._hash_type == "txt" else {}

        datasets_to_infer = select_uninferenced_dataset(self._dataset)
        datasets = self._compute_hash_key([self._dataset], [datasets_to_infer])[0]

        for category in datasets.categories().values():
            if isinstance(category, LabelCategories):
                self._num_centers = len(category._indices.keys())

        # Text hash key per label, inferred once instead of once per item.
        # NOTE(review): assumes infer_text returns the same key for the same
        # prompt; confirm.
        text_hash_keys = {}
        key_rows = []
        for item in datasets:
            item_label = next(
                (anno.label for anno in item.annotations if isinstance(anno, Label)),
                _NO_LABEL,
            )
            for annotation in item.annotations:
                if not isinstance(annotation, HashKey):
                    continue

                hash_key = annotation.hash_key
                if self._hash_type == "txt":
                    if item_label not in text_hash_keys:
                        # NOTE(review): prompts are keyed by the keys of
                        # LabelCategories._indices, while the lookup uses the
                        # stringified Label.label; these presumably agree only
                        # when the two coincide. Verify the intended mapping.
                        inputs = category_dict.get(str(item_label))
                        if isinstance(inputs, list):
                            inputs = " ".join(inputs)
                        text_hash_keys[item_label] = self.text_model.infer_text(inputs).hash_key
                    hash_key = np.concatenate([hash_key, text_hash_keys[item_label]])

                key_rows.append(np.unpackbits(hash_key, axis=-1).reshape(1, -1))
                self._item_list.append(item)
                self._labels.append(item_label)

        if key_rows:
            # One concatenation instead of re-allocating the matrix per item.
            self._database_keys = np.concatenate(key_rows, axis=0)

    def _prompting(self):
        """Build the text prompts of every label.

        Returns:
            A dict mapping each key of the dataset's ``LabelCategories`` index
            to the list of prompts rendered from the templates of the dataset
            format (or the default templates). Empty if the dataset has no
            ``LabelCategories``.
        """
        template = format_templates.get(self._dataset.format, templates)
        label_categories = next(
            (
                category
                for category in self._dataset.categories().values()
                if isinstance(category, LabelCategories)
            ),
            None,
        )
        if label_categories is None:
            return {}
        return {
            label: [temp.format(label) for temp in template]
            for label in label_categories._indices.keys()
        }

    def get_pruned(self, ratio: float = 0.5) -> Dataset:
        """Run the configured method and wrap the selected items in a dataset.

        Parameters:
            ratio: How much to remain dataset after pruning.
        Returns:
            A new dataset with the media type, source path and categories of
            the original one, containing only the selected items.
        Raises:
            KeyError: if the configured cluster method is unknown.
        """
        method = {
            "random": RandomSelect,
            "cluster_random": ClusteredRandom,
            "centroid": Centroid,
            "query_clust": QueryClust,
            "entropy": Entropy,
            "ndr": NDRSelect,
        }

        prune_method = method[self._cluster_method]()
        selected_items, dist_tuples = prune_method.base(
            ratio=ratio,
            num_centers=self._num_centers,
            labels=self._labels,
            database_keys=self._database_keys,
            item_list=self._item_list,
            source=self._dataset,
        )

        result_dataset = Dataset(media_type=self._dataset.media_type())
        result_dataset._source_path = self._dataset._source_path
        result_dataset.define_categories(self._dataset.categories())
        for item in selected_items:
            result_dataset.put(item)

        if dist_tuples:
            for center, id_, subset_, d in dist_tuples:
                log.info(
                    "item %s of subset %s has distance %s for cluster %s", id_, subset_, d, center
                )

        log.info(
            "Pruned dataset with %s from %s to %s",
            ratio,
            len(self._dataset),
            len(result_dataset),
        )
        return result_dataset
Caption("cat")], + ), + DatasetItem( + id=2, + subset="train", + media=Image.from_numpy(data=train_img), + annotations=[Label(1, id=1), Caption("cat")], + ), + DatasetItem( + id=3, + subset="test", + media=Image.from_numpy(data=test_img), + annotations=[Label(2, id=2), Caption("dog")], + ), + DatasetItem( + id=4, + subset="test", + media=Image.from_numpy(data=test_img), + annotations=[Label(2, id=2), Caption("dog")], + ), + ], + categories=["1", "2"], + ) + return dataset + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_prune_dataset_w_target(self, helper_tc: TestCaseHelper, fxt_dataset): + test_dir = scope_add(TestDir()) + proj_dir = osp.join(test_dir, "proj") + dataset_url = osp.join(test_dir, "dataset") + + fxt_dataset.export(dataset_url, "datumaro", save_media=True) + + run(helper_tc, "project", "create", "-o", proj_dir) + run(helper_tc, "project", "import", "-p", proj_dir, "-f", "datumaro", dataset_url) + run( + helper_tc, + "prune", + "source-1", + "-m", + "random", + "-r", + "0.5", + "-p", + proj_dir, + ) + + parsed_dataset = parse_dataset_pathspec(proj_dir) + result_subsets = [item.subset for item in parsed_dataset] + assert Counter(result_subsets) == {"test": 1, "train": 1} + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_prune_dataset_wo_target(self, helper_tc: TestCaseHelper, fxt_dataset): + test_dir = scope_add(TestDir()) + proj_dir = osp.join(test_dir, "proj") + dataset_url = osp.join(test_dir, "dataset") + + fxt_dataset.export(dataset_url, "datumaro", save_media=True) + + run(helper_tc, "project", "create", "-o", proj_dir) + run(helper_tc, "project", "import", "-p", proj_dir, "-f", "datumaro", dataset_url) + run(helper_tc, "prune", "-m", "random", "-r", "0.5", "-p", proj_dir) + + parsed_dataset = parse_dataset_pathspec(proj_dir) + result_subsets = [item.subset for item in parsed_dataset] + assert Counter(result_subsets) == {"test": 1, "train": 1} + + 
@mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_prune_wo_overwrite(self, helper_tc: TestCaseHelper, fxt_dataset): + test_dir = scope_add(TestDir()) + proj_dir = osp.join(test_dir, "proj") + dataset_url = osp.join(test_dir, "dataset") + dst_dir = osp.join(test_dir, "result") + + fxt_dataset.export(dataset_url, "datumaro", save_media=True) + + run(helper_tc, "project", "create", "-o", proj_dir) + run(helper_tc, "project", "import", "-p", proj_dir, "-f", "datumaro", dataset_url) + run(helper_tc, "prune", "-m", "random", "-r", "0.5", "-p", proj_dir, "-o", dst_dir) + + parsed_dataset = parse_dataset_pathspec(dst_dir) + result_subsets = [item.subset for item in parsed_dataset] + assert Counter(result_subsets) == {"test": 1, "train": 1} diff --git a/tests/unit/test_explorer.py b/tests/unit/test_explorer.py index d3e646d155..2cf012b7cf 100644 --- a/tests/unit/test_explorer.py +++ b/tests/unit/test_explorer.py @@ -5,11 +5,11 @@ import numpy as np +from datumaro.components.algorithms.hash_key_inference.explorer import Explorer from datumaro.components.annotation import Caption, Label from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.errors import MediaTypeError -from datumaro.components.explorer import Explorer from datumaro.components.media import Image from datumaro.plugins.data_formats.datumaro.exporter import DatumaroExporter diff --git a/tests/unit/test_prune.py b/tests/unit/test_prune.py new file mode 100644 index 0000000000..5143a6dad5 --- /dev/null +++ b/tests/unit/test_prune.py @@ -0,0 +1,250 @@ +from collections import Counter +from functools import partial + +import numpy as np +import pytest + +from datumaro.components.algorithms.hash_key_inference.prune import ( + Prune, + match_num_item_for_cluster, +) +from datumaro.components.annotation import Caption, Label +from datumaro.components.dataset import Dataset +from datumaro.components.dataset_base import 
DatasetItem +from datumaro.components.media import Image +from datumaro.plugins.data_formats.datumaro.exporter import DatumaroExporter + +from ..requirements import Requirements, mark_requirement + + +class PruneTest: + @pytest.fixture + def fxt_dataset(self) -> Dataset: + train_img = np.full((5, 5, 3), 255, dtype=np.uint8) + test_img = np.full((5, 5, 3), 0, dtype=np.uint8) + + dataset = Dataset.from_iterable( + [ + DatasetItem( + id=1, + subset="train", + media=Image.from_numpy(data=train_img), + annotations=[Label(1, id=1), Caption("cat")], + ), + DatasetItem( + id=2, + subset="train", + media=Image.from_numpy(data=train_img), + annotations=[Label(1, id=1), Caption("cat")], + ), + DatasetItem( + id=3, + subset="test", + media=Image.from_numpy(data=test_img), + annotations=[Label(2, id=2), Caption("dog")], + ), + DatasetItem( + id=4, + subset="test", + media=Image.from_numpy(data=test_img), + annotations=[Label(2, id=2), Caption("dog")], + ), + ], + categories=["1", "2"], + ) + return dataset + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_match_num_item_for_cluster(self): + """Check that a 0.5 ratio over five clusters yields the expected per-cluster item counts.""" + ratio = 0.5 + total_num_items = 100 + cluster_num_items = [20, 30, 15, 10, 25] + + result = match_num_item_for_cluster(ratio, total_num_items, cluster_num_items) + + # Assert the expected result based on the given inputs + expected_result = [10, 15, 7, 5, 12] + assert result == expected_result + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_prune_random(self, fxt_dataset, test_dir): + """ + Description: + Check that a dataset can be pruned with the random method. + + Input data: + Dataset with train and test subsets whose items share one image per subset. + + Expected results: + Pruned dataset in which each subset contains one dataset item. + + Steps: + 1. Prepare a dataset in which each subset contains identical images. + 2. Build Prune with the random method and call get_pruned to get a representative subset. + 3. Check that each subset contains one dataset item.
+ """ + converter = partial(DatumaroExporter.convert, save_media=True) + converter(fxt_dataset, test_dir) + imported_dataset = Dataset.import_from(test_dir, "datumaro") + prune = Prune(imported_dataset, cluster_method="random") + + result = prune.get_pruned(0.5) + result_subsets = [item.subset for item in result] + assert Counter(result_subsets) == {"test": 1, "train": 1} + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_prune_clustered_random(self, fxt_dataset, test_dir): + """ + Description: + Check that a dataset can be pruned with the cluster_random method. + + Input data: + Dataset with train and test subsets whose items share one image per subset. + + Expected results: + Pruned dataset in which each subset contains one dataset item. + + Steps: + 1. Prepare a dataset in which each subset contains identical images. + 2. Build Prune with the cluster_random method and call get_pruned to get a representative subset. + 3. Check that each subset contains one dataset item. + """ + converter = partial(DatumaroExporter.convert, save_media=True) + converter(fxt_dataset, test_dir) + imported_dataset = Dataset.import_from(test_dir, "datumaro") + prune = Prune(imported_dataset, cluster_method="cluster_random") + + result = prune.get_pruned(0.5) + result_subsets = [item.subset for item in result] + assert Counter(result_subsets) == {"test": 1, "train": 1} + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_prune_centroid(self, fxt_dataset, test_dir): + """ + Description: + Check that a dataset can be pruned with the centroid method. + + Input data: + Dataset with train and test subsets whose items share one image per subset. + + Expected results: + Pruned dataset in which each subset contains one dataset item. + + Steps: + 1. Prepare a dataset in which each subset contains identical images. + 2. Build Prune with the centroid method and call get_pruned to get a representative subset. + 3. Check that each subset contains one dataset item.
+ """ + converter = partial(DatumaroExporter.convert, save_media=True) + converter(fxt_dataset, test_dir) + imported_dataset = Dataset.import_from(test_dir, "datumaro") + prune = Prune(imported_dataset, cluster_method="centroid") + + result = prune.get_pruned(0.5) + result_subsets = [item.subset for item in result] + assert Counter(result_subsets) == {"test": 1, "train": 1} + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_prune_query_clust_img_hash(self, fxt_dataset, test_dir): + """ + Description: + Check that a dataset can be pruned with the query_clust method through image hashes. + + Input data: + Dataset with train and test subsets whose items share one image per subset. + + Expected results: + Pruned dataset in which each subset contains one dataset item. + + Steps: + 1. Prepare a dataset in which each subset contains identical images. + 2. Build Prune with the query_clust method and call get_pruned to get a representative subset. + 3. Check that each subset contains one dataset item. + """ + converter = partial(DatumaroExporter.convert, save_media=True) + converter(fxt_dataset, test_dir) + imported_dataset = Dataset.import_from(test_dir, "datumaro") + prune = Prune(imported_dataset, cluster_method="query_clust") + + result = prune.get_pruned(0.5) + result_subsets = [item.subset for item in result] + assert Counter(result_subsets) == {"test": 1, "train": 1} + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_prune_query_clust_txt_hash(self, fxt_dataset, test_dir): + """ + Description: + Check that a dataset can be pruned with the query_clust method through text hashes. + + Input data: + Dataset with train and test subsets whose items share one image per subset. + + Expected results: + Pruned dataset in which each subset contains one dataset item. + + Steps: + 1. Prepare a dataset in which each subset contains identical images. + 2. Build Prune with the query_clust method and call get_pruned to get a representative subset. + 3. Check that each subset contains one dataset item.
+ """ + converter = partial(DatumaroExporter.convert, save_media=True) + converter(fxt_dataset, test_dir) + imported_dataset = Dataset.import_from(test_dir, "datumaro") + prune = Prune(imported_dataset, cluster_method="query_clust", hash_type="txt") + + result = prune.get_pruned(0.5) + result_subsets = [item.subset for item in result] + assert Counter(result_subsets) == {"test": 1, "train": 1} + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_prune_entropy(self, fxt_dataset, test_dir): + """ + Description: + Check that a dataset can be pruned with the entropy method. + + Input data: + Dataset with train and test subsets whose items share one image per subset. + + Expected results: + Pruned dataset in which each subset contains one dataset item. + + Steps: + 1. Prepare a dataset in which each subset contains identical images. + 2. Build Prune with the entropy method and call get_pruned to get a representative subset. + 3. Check that each subset contains one dataset item. + """ + converter = partial(DatumaroExporter.convert, save_media=True) + converter(fxt_dataset, test_dir) + imported_dataset = Dataset.import_from(test_dir, "datumaro") + prune = Prune(imported_dataset, cluster_method="entropy") + + result = prune.get_pruned(0.5) + result_subsets = [item.subset for item in result] + assert Counter(result_subsets) == {"test": 1, "train": 1} + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_prune_ndr(self, fxt_dataset, test_dir): + """ + Description: + Check that a dataset can be pruned with the ndr method. + + Input data: + Dataset with train and test subsets whose items share one image per subset. + + Expected results: + Pruned dataset in which each subset contains one dataset item. + + Steps: + 1. Prepare a dataset in which each subset contains identical images. + 2. Build Prune with the ndr method and call get_pruned to get a representative subset. + 3. Check that each subset contains one dataset item.
+ """ + converter = partial(DatumaroExporter.convert, save_media=True) + converter(fxt_dataset, test_dir) + imported_dataset = Dataset.import_from(test_dir, "datumaro") + prune = Prune(imported_dataset, cluster_method="ndr") + + result = prune.get_pruned(0.5) + result_subsets = [item.subset for item in result] + assert Counter(result_subsets) == {"test": 1, "train": 1}