From 7530ab44443a47fa33cc66e1e8cd44dd3912f87c Mon Sep 17 00:00:00 2001 From: Wonju Lee Date: Wed, 17 Apr 2024 14:00:35 +0900 Subject: [PATCH] Fix ambiguous coco format detector (#1442) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Summary Before, all the coco datasets such as panoptic, stuff, and caption were detected as just `coco`. With this PR, a specific coco task is detected. For instance, ```code formats = dm.Environment().detect_dataset(path="./tests/assets/coco_dataset/coco_instances") ``` returns `coco_instances`. If the dataset directory contains multiple COCO task types, ```code formats = dm.Environment().detect_dataset(path="./tests/assets/coco_dataset/coco") ``` returns `coco`, so we can choose one task for the import operation. Here we retain the original `coco` format to import all annotation files, as shown below. ![image](https://github.com/openvinotoolkit/datumaro/assets/89109581/9761b244-de21-4d6c-8f6e-76dcfda2bb96) *Note*: Previously, our coco importer could load any annotation filename like `{unknowntask}_{subsetname}`, but with this PR, it is restricted to importing `{instances,panoptic,caption,labels,...}_{subsetname}`. ### How to test ### Checklist - [ ] I have added unit tests to cover my changes. - [ ] I have added integration tests to cover my changes. - [x] I have added the description of my changes into [CHANGELOG](https://github.com/openvinotoolkit/datumaro/blob/develop/CHANGELOG.md). - [ ] I have updated the [documentation](https://github.com/openvinotoolkit/datumaro/tree/develop/docs) accordingly ### License - [ ] I submit _my code changes_ under the same [MIT License](https://github.com/openvinotoolkit/datumaro/blob/develop/LICENSE) that covers the project. Feel free to contact the maintainers if that's a concern. - [ ] I have updated the license header for each file (see an example below). 
```python # Copyright (C) 2024 Intel Corporation # # SPDX-License-Identifier: MIT ``` --- CHANGELOG.md | 2 + .../plugins/data_formats/coco/importer.py | 45 +++++++++++-------- tests/unit/test_coco_format.py | 32 +++++++++++-- tests/unit/test_dataset.py | 4 +- 4 files changed, 60 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dbdcd4a0f0..b875950731 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 () ### Enhancements +- Fix ambiguous COCO format detector + () ### Bug fixes diff --git a/src/datumaro/plugins/data_formats/coco/importer.py b/src/datumaro/plugins/data_formats/coco/importer.py index 90aab5f781..423517cb10 100644 --- a/src/datumaro/plugins/data_formats/coco/importer.py +++ b/src/datumaro/plugins/data_formats/coco/importer.py @@ -56,18 +56,21 @@ def detect( cls, context: FormatDetectionContext, ) -> FormatDetectionConfidence: - # The `coco` format is inherently ambiguous with `coco_instances`, - # `coco_stuff`, etc. To remove the ambiguity (and thus make it possible - # to use autodetection with the COCO dataset), disable autodetection - # for the single-task formats. - if len(cls._TASKS) == 1: + num_tasks = 0 + for task in cls._TASKS.keys(): + try: + context.require_files(f"annotations/{task.name}_*{cls._ANNO_EXT}") + num_tasks += 1 + except Exception: + pass + if num_tasks > 1: + log.warning( + "Multiple COCO tasks are detected. The detected format will be `coco` instead." 
+ ) + return FormatDetectionConfidence.MEDIUM + else: context.raise_unsupported() - with context.require_any(): - for task in cls._TASKS.keys(): - with context.alternative(): - context.require_file(f"annotations/{task.name}_*{cls._ANNO_EXT}") - def __call__(self, path, stream: bool = False, **extra_params): subsets = self.find_sources(path) @@ -140,8 +143,6 @@ def detect_coco_task(filename): subsets = {} for subset_path in subset_paths: ann_type = detect_coco_task(osp.basename(subset_path)) - if ann_type is None and len(cls._TASKS) == 1: - ann_type = list(cls._TASKS)[0] if ann_type not in cls._TASKS: log.warning( @@ -175,32 +176,40 @@ class CocoImageInfoImporter(CocoImporter): _TASK = CocoTask.image_info _TASKS = {_TASK: CocoImporter._TASKS[_TASK]} + @classmethod + def detect( + cls, + context: FormatDetectionContext, + ) -> FormatDetectionConfidence: + context.require_file(f"annotations/{cls._TASK.name}_*{cls._ANNO_EXT}") + return FormatDetectionConfidence.LOW + -class CocoCaptionsImporter(CocoImporter): +class CocoCaptionsImporter(CocoImageInfoImporter): _TASK = CocoTask.captions _TASKS = {_TASK: CocoImporter._TASKS[_TASK]} -class CocoInstancesImporter(CocoImporter): +class CocoInstancesImporter(CocoImageInfoImporter): _TASK = CocoTask.instances _TASKS = {_TASK: CocoImporter._TASKS[_TASK]} -class CocoPersonKeypointsImporter(CocoImporter): +class CocoPersonKeypointsImporter(CocoImageInfoImporter): _TASK = CocoTask.person_keypoints _TASKS = {_TASK: CocoImporter._TASKS[_TASK]} -class CocoLabelsImporter(CocoImporter): +class CocoLabelsImporter(CocoImageInfoImporter): _TASK = CocoTask.labels _TASKS = {_TASK: CocoImporter._TASKS[_TASK]} -class CocoPanopticImporter(CocoImporter): +class CocoPanopticImporter(CocoImageInfoImporter): _TASK = CocoTask.panoptic _TASKS = {_TASK: CocoImporter._TASKS[_TASK]} -class CocoStuffImporter(CocoImporter): +class CocoStuffImporter(CocoImageInfoImporter): _TASK = CocoTask.stuff _TASKS = {_TASK: CocoImporter._TASKS[_TASK]} diff --git 
a/tests/unit/test_coco_format.py b/tests/unit/test_coco_format.py index 3e8b51e48f..750a38836d 100644 --- a/tests/unit/test_coco_format.py +++ b/tests/unit/test_coco_format.py @@ -10,7 +10,6 @@ from copy import deepcopy from functools import partial from io import StringIO -from itertools import product from unittest import TestCase, skip import numpy as np @@ -54,7 +53,6 @@ CocoStuffExporter, ) from datumaro.plugins.data_formats.coco.format import CocoPath -from datumaro.plugins.data_formats.coco.importer import CocoImporter from datumaro.util import dump_json_file, parse_json_file from ..requirements import Requirements, mark_requirement @@ -232,6 +230,10 @@ def test_can_import_instances(self, format, subset, path, stream, helper_tc): check_is_stream(dataset, stream) compare_datasets(helper_tc, expected, dataset, require_media=True) + @skip( + "COCO format is required to specify the task in annotation file " + " for resolving ambiguity problem." + ) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @pytest.mark.parametrize("stream", [True, False]) def test_can_import_instances_with_any_annotation_filename(self, stream, test_dir, helper_tc): @@ -415,6 +417,10 @@ def test_can_import_captions(self, format, subset, path, stream, helper_tc): check_is_stream(dataset, stream) compare_datasets(helper_tc, expected, dataset, require_media=True) + @skip( + "COCO format is required to specify the task in annotation file " + " for resolving ambiguity problem." + ) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @pytest.mark.parametrize("stream", [True, False]) def test_can_import_captions_with_any_annotation_filename(self, stream, test_dir, helper_tc): @@ -501,6 +507,10 @@ def test_can_import_labels(self, format, subset, path, stream, helper_tc): check_is_stream(dataset, stream) compare_datasets(helper_tc, expected, dataset, require_media=True) + @skip( + "COCO format is required to specify the task in annotation file " + " for resolving ambiguity problem." 
+ ) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @pytest.mark.parametrize("stream", [True, False]) def test_can_import_labels_with_any_annotation_filename(self, stream, test_dir, helper_tc): @@ -662,6 +672,10 @@ def test_can_import_keypoints(self, format, subset, path, stream, helper_tc): check_is_stream(dataset, stream) compare_datasets(helper_tc, expected, dataset, require_media=True) + @skip( + "COCO format is required to specify the task in annotation file " + " for resolving ambiguity problem." + ) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @pytest.mark.parametrize("stream", [True, False]) def test_can_import_keypoints_with_any_annotation_filename(self, stream, test_dir, helper_tc): @@ -808,6 +822,10 @@ def test_can_import_image_info(self, stream, format, subset, path, helper_tc): check_is_stream(dataset, stream) compare_datasets(helper_tc, expected, dataset, require_media=True) + @skip( + "COCO format is required to specify the task in annotation file " + " for resolving ambiguity problem." + ) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @pytest.mark.parametrize("stream", [True, False]) def test_can_import_image_info_with_any_annotation_filename(self, stream, test_dir, helper_tc): @@ -911,6 +929,10 @@ def test_can_import_panoptic(self, format, subset, path, stream, helper_tc): ) compare_datasets(helper_tc, expected, dataset, require_media=True) + @skip( + "COCO format is required to specify the task in annotation file " + " for resolving ambiguity problem." 
+ ) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @pytest.mark.parametrize("stream", [True, False]) def test_can_import_panoptic_with_any_annotation_filename(self, stream, test_dir, helper_tc): @@ -1075,6 +1097,10 @@ def test_can_import_stuff(self, format, subset, path, stream, helper_tc): check_is_stream(dataset, stream) compare_datasets(helper_tc, expected, dataset, require_media=True) + @skip( + "COCO format is required to specify the task in annotation file " + " for resolving ambiguity problem." + ) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @pytest.mark.parametrize("stream", [True, False]) def test_can_import_stuff_with_any_annotation_filename(self, stream, test_dir, helper_tc): @@ -1148,7 +1174,7 @@ def test_can_detect(self, subdir): dataset_dir = osp.join(DUMMY_DATASET_DIR, subdir) detected_formats = env.detect_dataset(dataset_dir) - assert [CocoImporter.NAME] == detected_formats + assert [subdir] == detected_formats @mark_requirement(Requirements.DATUM_673) @pytest.mark.parametrize( diff --git a/tests/unit/test_dataset.py b/tests/unit/test_dataset.py index e2dc7fe0ce..0ed8c25584 100644 --- a/tests/unit/test_dataset.py +++ b/tests/unit/test_dataset.py @@ -305,11 +305,11 @@ def test_can_detect_with_nested_folder_and_multiply_matches(self): with TestDir() as test_dir: dataset_path = osp.join(test_dir, "a", "b") - dataset.export(dataset_path, "coco", save_media=True) + dataset.export(dataset_path, "coco_labels", save_media=True) detected_format = Dataset.detect(test_dir, depth=2) - self.assertEqual("coco", detected_format) + self.assertEqual("coco_labels", detected_format) @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_cannot_detect_for_non_existent_path(self):