diff --git a/src/datumaro/plugins/data_formats/common_semantic_segmentation.py b/src/datumaro/plugins/data_formats/common_semantic_segmentation.py index 4e9f55f625..7845ffc406 100644 --- a/src/datumaro/plugins/data_formats/common_semantic_segmentation.py +++ b/src/datumaro/plugins/data_formats/common_semantic_segmentation.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: MIT import errno -import glob import os.path as osp from typing import List, Optional @@ -69,11 +68,11 @@ def __init__( self._image_prefix = image_prefix self._mask_prefix = mask_prefix - meta_file = glob.glob(osp.join(path, "**", DATASET_META_FILE), recursive=True) - if is_meta_file(meta_file[0]): - self._root_dir = osp.dirname(meta_file[0]) + meta_file = osp.join(path, DATASET_META_FILE) + if is_meta_file(meta_file): + self._root_dir = osp.dirname(meta_file) - label_map = parse_meta_file(meta_file[0]) + label_map = parse_meta_file(meta_file) self._categories = make_categories(label_map) else: raise FileNotFoundError(errno.ENOENT, "Dataset meta info file was not found", path) @@ -163,11 +162,10 @@ def build_cmdline_parser(cls, **kwargs): @classmethod def detect(cls, context: FormatDetectionContext) -> FormatDetectionConfidence: - path = context.require_file(f"**/{DATASET_META_FILE}") - path = osp.dirname(path) + context.require_file(DATASET_META_FILE) - context.require_file(osp.join(path, CommonSemanticSegmentationPath.IMAGES_DIR, "**", "*")) - context.require_file(osp.join(path, CommonSemanticSegmentationPath.MASKS_DIR, "**", "*")) + context.require_file(osp.join(CommonSemanticSegmentationPath.IMAGES_DIR, "**", "*")) + context.require_file(osp.join(CommonSemanticSegmentationPath.MASKS_DIR, "**", "*")) return FormatDetectionConfidence.MEDIUM diff --git a/tests/unit/data_formats/conftest.py b/tests/unit/data_formats/conftest.py index c4c8c32f95..12351ee037 100644 --- a/tests/unit/data_formats/conftest.py +++ b/tests/unit/data_formats/conftest.py @@ -10,6 +10,8 @@ from datumaro import Dataset +from tests.utils.test_utils import TestDir + @pytest.fixture def fxt_dummy_dataset(): @@ -35,12 +37,12 @@ def fxt_export_kwargs(): @pytest.fixture def fxt_dataset_dir_with_subset_dirs(test_dir: str, request: pytest.FixtureRequest): fxt_dataset_dir = request.param + with TestDir(f"{test_dir}_with_subsets") as new_test_dir: + for subset in ["train", "val", "test"]: + dst = os.path.join(new_test_dir, subset) + shutil.copytree(fxt_dataset_dir, dst) - for subset in ["train", "val", "test"]: - dst = os.path.join(test_dir, subset) - shutil.copytree(fxt_dataset_dir, dst) - - yield test_dir + yield new_test_dir @pytest.fixture diff --git a/tests/unit/data_formats/test_common_semantic_segmentation_format.py b/tests/unit/data_formats/test_common_semantic_segmentation_format.py index b7297e4885..1d03e9ba01 100644 --- a/tests/unit/data_formats/test_common_semantic_segmentation_format.py +++ b/tests/unit/data_formats/test_common_semantic_segmentation_format.py @@ -2,8 +2,10 @@ # # SPDX-License-Identifier: MIT +import os +import shutil from collections import OrderedDict -from typing import Any, Dict, Optional +from typing import Any, Dict import numpy as np import pytest @@ -11,6 +13,7 @@ from datumaro.components.annotation import Mask from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem +from datumaro.components.errors import DatasetImportError from datumaro.components.media import Image from datumaro.plugins.data_formats.common_semantic_segmentation import ( CommonSemanticSegmentationImporter, @@ -143,6 +146,40 @@ def test_can_import( fxt_dataset_dir, fxt_expected_dataset, fxt_import_kwargs, request ) + @pytest.mark.parametrize( + ["fxt_dataset_dir", "fxt_expected_dataset", "fxt_import_kwargs"], + [ + (DUMMY_DATASET_DIR, "fxt_dataset", {}), + ( + DUMMY_NON_STANDARD_DATASET_DIR, + "fxt_non_standard_dataset", + {"image_prefix": "image_", "mask_prefix": "gt_"}, + ), + ], + indirect=["fxt_expected_dataset"], + ids=IDS, + ) + def test_cannot_import_nested( + self, + fxt_dataset_dir: str, + fxt_expected_dataset: Dataset, + fxt_import_kwargs: Dict[str, Any], + request: pytest.FixtureRequest, + test_dir: str, + ): + shutil.copytree(fxt_dataset_dir, test_dir, dirs_exist_ok=True) + subdir_name = "subdir" + subdir = os.path.join(test_dir, subdir_name) + os.makedirs(subdir) + for _file in os.listdir(test_dir): + if _file != subdir_name: + file_path = os.path.join(test_dir, _file) + shutil.move(file_path, subdir) + with pytest.raises(DatasetImportError) as exc_info: + super().test_can_import(test_dir, fxt_expected_dataset, fxt_import_kwargs, request) + assert exc_info.value.__cause__ is not None + assert isinstance(exc_info.value.__cause__, FileNotFoundError) + class CommonSemanticSegmentationWithSubsetDirsImporterTest(TestDataFormatBase): IMPORTER = CommonSemanticSegmentationWithSubsetDirsImporter