Add Prune API #1058

Merged: 17 commits, Jul 5, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/1059>)
- Migrate DVC v3.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/1072>)
- Add Prune API
(<https://github.com/openvinotoolkit/datumaro/pull/1058>)

### Enhancements
- Enhance import performance for built-in plugins
3 changes: 3 additions & 0 deletions requirements-core.txt
@@ -54,3 +54,6 @@ tabulate
# Model inference launcher from the dedicated inference server
ovmsclient
tritonclient[all]

# prune
scikit-learn
4 changes: 3 additions & 1 deletion src/datumaro/cli/commands/__init__.py
@@ -16,6 +16,7 @@
info,
merge,
patch,
prune,
stats,
transform,
validate,
@@ -36,11 +37,12 @@ def get_non_project_commands():
("dinfo", info, "Print dataset info"),
("download", download, "Download a publicly available dataset"),
("explain", explain, "Run Explainable AI algorithm for model"),
("explore", explore, "Explore similar datasetitems of query"),
("filter", filter, "Filter dataset items"),
("generate", generate, "Generate synthetic dataset"),
("merge", merge, "Merge datasets"),
("patch", patch, "Update dataset from another one"),
("explore", explore, "Explore similar datasetitems of query"),
("prune", prune, "Prune dataset"),
("stats", stats, "Compute dataset statistics"),
("transform", transform, "Modify dataset items"),
("validate", validate, "Validate dataset"),
121 changes: 121 additions & 0 deletions src/datumaro/cli/commands/prune.py
@@ -0,0 +1,121 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import argparse
import logging as log

from datumaro.components.errors import ProjectNotFoundError
from datumaro.components.prune import Prune
from datumaro.util.scope import scope_add, scoped

from ..util import MultilineFormatter
from ..util.project import load_project, parse_full_revpath


def build_parser(parser_ctor=argparse.ArgumentParser):
parser = parser_ctor(
help="Prune dataset and make a representative subset",
description="""
Apply data pruning to a dataset.|n
The command is useful when you need to extract a representative subset of a dataset.
|n
The current project (-p/--project) is used as a context for plugins
and models. It is used when the target is a dataset path.
When not specified, the current project's working tree is used.|n
|n
By default, datasets are updated in-place. The '-o/--output-dir'
option can be used to specify another output directory. When
updating in-place, use the '--overwrite' parameter (in-place
updates fail by default to prevent data loss), unless a project
target is modified.|n
|n
The command can be applied to a dataset or a project build target,
a stage or the combined 'project' target, in which case all the
targets will be affected.|n
|n
Examples:|n
- Prune dataset by random selection, keeping 80%:|n
|s|s%(prog)s -m random -r 0.8|n
- Prune dataset by clustering image hashes, keeping 50%:|n
|s|s%(prog)s -m query_clust --hash-type img -r 0.5|n
- Prune dataset by entropy over image-hash clusters, keeping 50%:|n
|s|s%(prog)s -m entropy --hash-type img -r 0.5|n
""",
formatter_class=MultilineFormatter,
)

parser.add_argument("target", nargs="?", help="Target dataset revpath (default: project)")
parser.add_argument("-m", "--method", dest="method", help="Method to apply to the dataset")
parser.add_argument(
"-r", "--ratio", type=float, dest="ratio", help="How much to remain dataset after pruning"
)
parser.add_argument(
"--hash-type",
type=str,
dest="hash_type",
default="img",
help="Hashtype to extract feature from data information between image and text(label)",
)
parser.add_argument(
"-p",
"--project",
dest="project_dir",
help="Directory of the project to operate on (default: current dir)",
)

parser.add_argument(
"-o",
"--output-dir",
dest="dst_dir",
help="""
Output directory. Can be omitted for main project targets
(i.e. data sources and the 'project' target, but not
intermediate stages) and dataset targets.
If not specified, the results will be saved in-place.
""",
)
parser.add_argument(
"--overwrite", action="store_true", help="Overwrite existing files in the save directory"
)
parser.set_defaults(command=prune_command)

return parser


def get_sensitive_args():
return {
prune_command: [
"target",
"method",
"ratio",
"hash_type",
"project_dir",
"dst_dir",
]
}


@scoped
def prune_command(args):
project = None
try:
project = scope_add(load_project(args.project_dir))
except ProjectNotFoundError:
if args.project_dir:
raise

targets = [args.target] if args.target else list(project.working_tree.sources)

# Only the first target is parsed into a dataset and pruned.
source_dataset = parse_full_revpath(targets[0], project)[0]

prune = Prune(source_dataset, cluster_method=args.method, hash_type=args.hash_type)

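# Persist the source dataset together with its computed hash-key metadata.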
source_dataset.save(source_dataset.data_path, save_media=True, save_hashkey_meta=True)

result = prune.get_pruned(args.ratio)

dst_dir = args.dst_dir or source_dataset.data_path
result.save(dst_dir, save_media=True)

log.info("Results have been saved to '%s'", dst_dir)
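
For programmatic use, the same flow as prune_command can be driven from Python. A minimal sketch: the Prune constructor, get_pruned(), and save() calls mirror this diff, while the dataset path and format name are hypothetical:

from datumaro.components.dataset import Dataset
from datumaro.components.prune import Prune

# Load a dataset (path and format here are hypothetical).
dataset = Dataset.import_from("./my_dataset", "datumaro")

# Build the pruner: random selection over image hashes, as in the CLI example.
prune = Prune(dataset, cluster_method="random", hash_type="img")

# Keep 80% of the items and save the result (mirrors prune_command above).
result = prune.get_pruned(0.8)
result.save("./pruned_dataset", save_media=True)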
64 changes: 28 additions & 36 deletions src/datumaro/components/explorer.py
@@ -10,28 +10,40 @@
from datumaro.components.dataset import Dataset
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.errors import DatumaroError, MediaTypeError
from datumaro.components.media import MediaElement
from datumaro.plugins.explorer import ExplorerLauncher
from datumaro.util.hashkey_util import calculate_hamming, select_uninferenced_dataset


def calculate_hamming(B1, B2):
"""
:param B1: vector [n]
:param B2: vector [r*n]
:return: hamming distance [r]
"""
return np.count_nonzero(B1 != B2, axis=1)
class HashInference:
def __init__(self, *datasets: Sequence[Dataset]) -> None:
pass

@property
def model(self):
if self._model is None:
self._model = ExplorerLauncher(model_name="clip_visual_ViT-B_32")
return self._model

@property
def text_model(self):
if self._text_model is None:
self._text_model = ExplorerLauncher(model_name="clip_text_ViT-B_32")
return self._text_model

def select_uninferenced_dataset(dataset):
uninferenced_dataset = Dataset(media_type=MediaElement)
for item in dataset:
if not any(isinstance(annotation, HashKey) for annotation in item.annotations):
uninferenced_dataset.put(item)
return uninferenced_dataset
def _compute_hash_key(self, datasets, datasets_to_infer):
for dataset_to_infer in datasets_to_infer:
if dataset_to_infer:
dataset_to_infer.run_model(self.model, append_annotation=True)
for dataset, dataset_to_infer in zip(datasets, datasets_to_infer):
updated_items = [
dataset.get(item.id, item.subset).wrap(annotations=item.annotations)
for item in dataset_to_infer
]
dataset.update(updated_items)
return datasets


class Explorer:
class Explorer(HashInference):
def __init__(
self,
*datasets: Sequence[Dataset],
@@ -54,7 +66,7 @@ def __init__(
item_list = []

datasets_to_infer = [select_uninferenced_dataset(dataset) for dataset in datasets]
datasets = self.compute_hash_key(datasets, datasets_to_infer)
datasets = self._compute_hash_key(datasets, datasets_to_infer)

for dataset in datasets:
for item in dataset:
@@ -75,26 +87,6 @@ def __init__(
self._database_keys = np.stack(database_keys, axis=0)
self._item_list = item_list

@property
def model(self):
if self._model is None:
self._model = ExplorerLauncher(model_name="clip_visual_ViT-B_32")
return self._model

@property
def text_model(self):
if self._text_model is None:
self._text_model = ExplorerLauncher(model_name="clip_text_ViT-B_32")
return self._text_model

def compute_hash_key(self, datasets, datasets_to_infer):
for dataset in datasets_to_infer:
if len(dataset) > 0:
dataset.run_model(self.model, append_annotation=True)
for dataset, dataset_to_infer in zip(datasets, datasets_to_infer):
dataset.update(dataset_to_infer)
return datasets

def explore_topk(
self,
query: Union[DatasetItem, str],
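
The calculate_hamming helper removed from explorer.py above now comes from datumaro.util.hashkey_util (see the import at the top of this diff). Per the removed docstring, it counts the mismatching positions between one hash vector and each row of a batch of hashes; a self-contained sketch:

import numpy as np

def calculate_hamming(B1, B2):
    """Hamming distance between one hash vector B1 [n] and r hashes B2 [r, n]."""
    return np.count_nonzero(B1 != B2, axis=1)

# Usage: query vs. three database hashes; printed distances are [0, 1, 4].
query = np.array([0, 1, 1, 0])
database = np.array([[0, 1, 1, 0], [1, 1, 1, 0], [1, 0, 0, 1]])
print(calculate_hamming(query, database))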