Skip to content

Commit

Permalink
Fix ambiguous coco format detector (openvinotoolkit#1442)
Browse files Browse the repository at this point in the history
<!-- Contributing guide:
https://github.com/openvinotoolkit/datumaro/blob/develop/CONTRIBUTING.md
-->

### Summary

Before, all the coco datasets such as panoptic, stuff, caption are
detected as just `coco`.
With this PR, a specific coco task is detected.
For instance,
```code
formats = dm.Environment().detect_dataset(path="./tests/assets/coco_dataset/coco_instances")
```
returns `coco_instances`.

If the dataset directory contains multiple task types of COCO, it looks
like
```code
formats = dm.Environment().detect_dataset(path="./tests/assets/coco_dataset/coco")
```
returns `coco`.

So we can choose one task for importing operation.
We here remain the original `coco` to import all annotation files like
below.

![image](https://github.com/openvinotoolkit/datumaro/assets/89109581/9761b244-de21-4d6c-8f6e-76dcfda2bb96)


*note*. Previously, our coco importer can load any annotation filename
like `{unknowntask}_{subsetname}`, but with this PR, it is restricted to
import `{instances,panoptic,caption,labels,...}_{subsetname}`.

<!--
Resolves openvinotoolkit#111 and openvinotoolkit#222.
Depends on openvinotoolkit#1000 (for series of dependent commits).

This PR introduces this capability to make the project better in this
and that.

- Added this feature
- Removed that feature
- Fixed the problem openvinotoolkit#1234
-->

### How to test
<!-- Describe the testing procedure for reviewers, if changes are
not fully covered by unit tests or manual testing can be complicated.
-->

### Checklist
<!-- Put an 'x' in all the boxes that apply -->
- [ ] I have added unit tests to cover my changes.​
- [ ] I have added integration tests to cover my changes.​
- [x] I have added the description of my changes into
[CHANGELOG](https://github.com/openvinotoolkit/datumaro/blob/develop/CHANGELOG.md).​
- [ ] I have updated the
[documentation](https://github.com/openvinotoolkit/datumaro/tree/develop/docs)
accordingly

### License

- [ ] I submit _my code changes_ under the same [MIT
License](https://github.com/openvinotoolkit/datumaro/blob/develop/LICENSE)
that covers the project.
  Feel free to contact the maintainers if that's a concern.
- [ ] I have updated the license header for each file (see an example
below).

```python
# Copyright (C) 2024 Intel Corporation
#
# SPDX-License-Identifier: MIT
```
  • Loading branch information
wonjuleee authored Apr 17, 2024
1 parent 6c048ea commit 7530ab4
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 23 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/1422>)

### Enhancements
- Fix ambiguous COCO format detector
(<https://github.com/openvinotoolkit/datumaro/pull/1442>)

### Bug fixes

Expand Down
45 changes: 27 additions & 18 deletions src/datumaro/plugins/data_formats/coco/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,21 @@ def detect(
cls,
context: FormatDetectionContext,
) -> FormatDetectionConfidence:
# The `coco` format is inherently ambiguous with `coco_instances`,
# `coco_stuff`, etc. To remove the ambiguity (and thus make it possible
# to use autodetection with the COCO dataset), disable autodetection
# for the single-task formats.
if len(cls._TASKS) == 1:
num_tasks = 0
for task in cls._TASKS.keys():
try:
context.require_files(f"annotations/{task.name}_*{cls._ANNO_EXT}")
num_tasks += 1
except Exception:
pass
if num_tasks > 1:
log.warning(
"Multiple COCO tasks are detected. The detected format will be `coco` instead."
)
return FormatDetectionConfidence.MEDIUM
else:
context.raise_unsupported()

with context.require_any():
for task in cls._TASKS.keys():
with context.alternative():
context.require_file(f"annotations/{task.name}_*{cls._ANNO_EXT}")

def __call__(self, path, stream: bool = False, **extra_params):
subsets = self.find_sources(path)

Expand Down Expand Up @@ -140,8 +143,6 @@ def detect_coco_task(filename):
subsets = {}
for subset_path in subset_paths:
ann_type = detect_coco_task(osp.basename(subset_path))
if ann_type is None and len(cls._TASKS) == 1:
ann_type = list(cls._TASKS)[0]

if ann_type not in cls._TASKS:
log.warning(
Expand Down Expand Up @@ -175,32 +176,40 @@ class CocoImageInfoImporter(CocoImporter):
_TASK = CocoTask.image_info
_TASKS = {_TASK: CocoImporter._TASKS[_TASK]}

@classmethod
def detect(
cls,
context: FormatDetectionContext,
) -> FormatDetectionConfidence:
context.require_file(f"annotations/{cls._TASK.name}_*{cls._ANNO_EXT}")
return FormatDetectionConfidence.LOW


class CocoCaptionsImporter(CocoImporter):
class CocoCaptionsImporter(CocoImageInfoImporter):
_TASK = CocoTask.captions
_TASKS = {_TASK: CocoImporter._TASKS[_TASK]}


class CocoInstancesImporter(CocoImporter):
class CocoInstancesImporter(CocoImageInfoImporter):
_TASK = CocoTask.instances
_TASKS = {_TASK: CocoImporter._TASKS[_TASK]}


class CocoPersonKeypointsImporter(CocoImporter):
class CocoPersonKeypointsImporter(CocoImageInfoImporter):
_TASK = CocoTask.person_keypoints
_TASKS = {_TASK: CocoImporter._TASKS[_TASK]}


class CocoLabelsImporter(CocoImporter):
class CocoLabelsImporter(CocoImageInfoImporter):
_TASK = CocoTask.labels
_TASKS = {_TASK: CocoImporter._TASKS[_TASK]}


class CocoPanopticImporter(CocoImporter):
class CocoPanopticImporter(CocoImageInfoImporter):
_TASK = CocoTask.panoptic
_TASKS = {_TASK: CocoImporter._TASKS[_TASK]}


class CocoStuffImporter(CocoImporter):
class CocoStuffImporter(CocoImageInfoImporter):
_TASK = CocoTask.stuff
_TASKS = {_TASK: CocoImporter._TASKS[_TASK]}
32 changes: 29 additions & 3 deletions tests/unit/test_coco_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from copy import deepcopy
from functools import partial
from io import StringIO
from itertools import product
from unittest import TestCase, skip

import numpy as np
Expand Down Expand Up @@ -54,7 +53,6 @@
CocoStuffExporter,
)
from datumaro.plugins.data_formats.coco.format import CocoPath
from datumaro.plugins.data_formats.coco.importer import CocoImporter
from datumaro.util import dump_json_file, parse_json_file

from ..requirements import Requirements, mark_requirement
Expand Down Expand Up @@ -232,6 +230,10 @@ def test_can_import_instances(self, format, subset, path, stream, helper_tc):
check_is_stream(dataset, stream)
compare_datasets(helper_tc, expected, dataset, require_media=True)

@skip(
"COCO format is required to specify the task in annotation file "
" for resolving ambiguity problem."
)
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
@pytest.mark.parametrize("stream", [True, False])
def test_can_import_instances_with_any_annotation_filename(self, stream, test_dir, helper_tc):
Expand Down Expand Up @@ -415,6 +417,10 @@ def test_can_import_captions(self, format, subset, path, stream, helper_tc):
check_is_stream(dataset, stream)
compare_datasets(helper_tc, expected, dataset, require_media=True)

@skip(
"COCO format is required to specify the task in annotation file "
" for resolving ambiguity problem."
)
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
@pytest.mark.parametrize("stream", [True, False])
def test_can_import_captions_with_any_annotation_filename(self, stream, test_dir, helper_tc):
Expand Down Expand Up @@ -501,6 +507,10 @@ def test_can_import_labels(self, format, subset, path, stream, helper_tc):
check_is_stream(dataset, stream)
compare_datasets(helper_tc, expected, dataset, require_media=True)

@skip(
"COCO format is required to specify the task in annotation file "
" for resolving ambiguity problem."
)
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
@pytest.mark.parametrize("stream", [True, False])
def test_can_import_labels_with_any_annotation_filename(self, stream, test_dir, helper_tc):
Expand Down Expand Up @@ -662,6 +672,10 @@ def test_can_import_keypoints(self, format, subset, path, stream, helper_tc):
check_is_stream(dataset, stream)
compare_datasets(helper_tc, expected, dataset, require_media=True)

@skip(
"COCO format is required to specify the task in annotation file "
" for resolving ambiguity problem."
)
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
@pytest.mark.parametrize("stream", [True, False])
def test_can_import_keypoints_with_any_annotation_filename(self, stream, test_dir, helper_tc):
Expand Down Expand Up @@ -808,6 +822,10 @@ def test_can_import_image_info(self, stream, format, subset, path, helper_tc):
check_is_stream(dataset, stream)
compare_datasets(helper_tc, expected, dataset, require_media=True)

@skip(
"COCO format is required to specify the task in annotation file "
" for resolving ambiguity problem."
)
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
@pytest.mark.parametrize("stream", [True, False])
def test_can_import_image_info_with_any_annotation_filename(self, stream, test_dir, helper_tc):
Expand Down Expand Up @@ -911,6 +929,10 @@ def test_can_import_panoptic(self, format, subset, path, stream, helper_tc):
)
compare_datasets(helper_tc, expected, dataset, require_media=True)

@skip(
"COCO format is required to specify the task in annotation file "
" for resolving ambiguity problem."
)
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
@pytest.mark.parametrize("stream", [True, False])
def test_can_import_panoptic_with_any_annotation_filename(self, stream, test_dir, helper_tc):
Expand Down Expand Up @@ -1075,6 +1097,10 @@ def test_can_import_stuff(self, format, subset, path, stream, helper_tc):
check_is_stream(dataset, stream)
compare_datasets(helper_tc, expected, dataset, require_media=True)

@skip(
"COCO format is required to specify the task in annotation file "
" for resolving ambiguity problem."
)
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
@pytest.mark.parametrize("stream", [True, False])
def test_can_import_stuff_with_any_annotation_filename(self, stream, test_dir, helper_tc):
Expand Down Expand Up @@ -1148,7 +1174,7 @@ def test_can_detect(self, subdir):
dataset_dir = osp.join(DUMMY_DATASET_DIR, subdir)

detected_formats = env.detect_dataset(dataset_dir)
assert [CocoImporter.NAME] == detected_formats
assert [subdir] == detected_formats

@mark_requirement(Requirements.DATUM_673)
@pytest.mark.parametrize(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,11 @@ def test_can_detect_with_nested_folder_and_multiply_matches(self):

with TestDir() as test_dir:
dataset_path = osp.join(test_dir, "a", "b")
dataset.export(dataset_path, "coco", save_media=True)
dataset.export(dataset_path, "coco_labels", save_media=True)

detected_format = Dataset.detect(test_dir, depth=2)

self.assertEqual("coco", detected_format)
self.assertEqual("coco_labels", detected_format)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_cannot_detect_for_non_existent_path(self):
Expand Down

0 comments on commit 7530ab4

Please sign in to comment.