diff --git a/CHANGELOG.md b/CHANGELOG.md index 161b2d54e9..5721c979be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,16 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## \[Q3 2024 Release 1.9.0\] +## \[Q4 2024 Release 1.9.1\] +### New features + +### Enhancements +- Support multiple labels for kaggle format + () + +### Bug fixes + +## Q3 2024 Release 1.9.0 ### New features - Add a new CLI command: datum format () diff --git a/src/datumaro/plugins/data_formats/kaggle/base.py b/src/datumaro/plugins/data_formats/kaggle/base.py index d21b1434c1..06d2ef9a15 100644 --- a/src/datumaro/plugins/data_formats/kaggle/base.py +++ b/src/datumaro/plugins/data_formats/kaggle/base.py @@ -77,13 +77,31 @@ def _parse_bbox_coords(self, bbox_str): # expected to output [x1, y1, x2, y2] return [float(coord.strip()) for coord in coords] - def _load_annotations(self, datas: list, indices: Dict[str, int], bbox_flag: bool): + def _load_annotations( + self, datas: list, indices: Dict[str, Union[int, Dict[str, int]]], bbox_flag: bool + ): if "label" in indices: - label_name = str(datas[indices["label"]]) - label, cat = self._label_cat.find(label_name) - if not cat: - self._label_cat.add(label_name) - label, _ = self._label_cat.find(label_name) + label_indices = indices["label"] + if isinstance(label_indices, dict): + labels = [] + list_values = datas[1:] + index_to_label = {v: k for k, v in label_indices.items()} + present_labels = [ + index_to_label[i + 1] for i, value in enumerate(list_values) if value == "1" + ] + + for label_name in present_labels: + label, cat = self._label_cat.find(label_name) + if not cat: + self._label_cat.add(label_name) + label, _ = self._label_cat.find(label_name) + labels.append(Label(label=label)) + else: + label_name = str(datas[indices["label"]]) + label, cat = self._label_cat.find(label_name) + if not cat: + self._label_cat.add(label_name) + label, _ = self._label_cat.find(label_name) else: _, cat = self._label_cat.find("object") if not cat: @@ -91,7 +109,11 @@ def _load_annotations(self, datas: list, indices: Dict[str, int], bbox_flag: boo label = 0 if "label" in indices and not bbox_flag: + label_indices = indices["label"] + if isinstance(label_indices, dict): + return labels return Label(label=label) + if bbox_flag: if "bbox" in indices: coords = self._parse_bbox_coords(datas[indices["bbox"]]) @@ -125,7 +147,14 @@ def _load_items(self, ann_file: str, columns: Dict[str, Union[str, list]]): indices = {"media": df_fields.index(columns["media"])} if "label" in columns: - indices.update({"label": df_fields.index(columns["label"])}) + label_columns = columns["label"] + if isinstance(label_columns, list): + indices_label = {} + for label in label_columns: + indices_label[label] = df_fields.index(label) + indices.update({"label": indices_label}) + else: + indices.update({"label": df_fields.index(label_columns)}) bbox_flag = False bbox_index = columns.get("bbox") @@ -165,16 +194,30 @@ def _load_items(self, ann_file: str, columns: Dict[str, Union[str, list]]): continue ann = self._load_annotations(data_info, indices, bbox_flag) - self._ann_types.add(ann.type) - if item_id in items: - items[item_id].annotations.append(ann) + if isinstance(ann, list): + for label in ann: + self._ann_types.add(label.type) + if item_id in items: + for label in ann: + items[item_id].annotations.append(label) + else: + items[item_id] = DatasetItem( + id=item_id, + subset=self._subset, + media=Image.from_file(path=media_path), + annotations=ann, + ) else: - items[item_id] = DatasetItem( - id=item_id, - subset=self._subset, - media=Image.from_file(path=media_path), - annotations=[ann], - ) + self._ann_types.add(ann.type) + if item_id in items: + items[item_id].annotations.append(ann) + else: + items[item_id] = DatasetItem( + id=item_id, + subset=self._subset, + media=Image.from_file(path=media_path), + annotations=[ann], + ) return items.values() def categories(self): diff --git a/tests/assets/kaggle_dataset/image_csv_multi_label/ann.csv b/tests/assets/kaggle_dataset/image_csv_multi_label/ann.csv new file mode 100644 index 0000000000..57b6540a15 --- /dev/null +++ b/tests/assets/kaggle_dataset/image_csv_multi_label/ann.csv @@ -0,0 +1,7 @@ +image_name,dog,cat,person +1.jpg,1,0,0 +2.jpg,0,1,0 +3.jpg,0,0,1 +4.jpg,1,1,0 +5.jpg,1,0,1 +6.jpg,0,1,1 diff --git a/tests/assets/kaggle_dataset/image_csv_multi_label/ann_wo_ext.csv b/tests/assets/kaggle_dataset/image_csv_multi_label/ann_wo_ext.csv new file mode 100644 index 0000000000..dd01be80e0 --- /dev/null +++ b/tests/assets/kaggle_dataset/image_csv_multi_label/ann_wo_ext.csv @@ -0,0 +1,7 @@ +image_name,dog,cat,person +1,1,0,0 +2,0,1,0 +3,0,0,1 +4,1,1,0 +5,1,0,1 +6,0,1,1 diff --git a/tests/assets/kaggle_dataset/image_csv_multi_label/images/1.jpg b/tests/assets/kaggle_dataset/image_csv_multi_label/images/1.jpg new file mode 100644 index 0000000000..8689b95631 Binary files /dev/null and b/tests/assets/kaggle_dataset/image_csv_multi_label/images/1.jpg differ diff --git a/tests/assets/kaggle_dataset/image_csv_multi_label/images/2.jpg b/tests/assets/kaggle_dataset/image_csv_multi_label/images/2.jpg new file mode 100644 index 0000000000..8689b95631 Binary files /dev/null and b/tests/assets/kaggle_dataset/image_csv_multi_label/images/2.jpg differ diff --git a/tests/assets/kaggle_dataset/image_csv_multi_label/images/3.jpg b/tests/assets/kaggle_dataset/image_csv_multi_label/images/3.jpg new file mode 100644 index 0000000000..8689b95631 Binary files /dev/null and b/tests/assets/kaggle_dataset/image_csv_multi_label/images/3.jpg differ diff --git a/tests/assets/kaggle_dataset/image_csv_multi_label/images/4.jpg b/tests/assets/kaggle_dataset/image_csv_multi_label/images/4.jpg new file mode 100644 index 0000000000..8689b95631 Binary files /dev/null and b/tests/assets/kaggle_dataset/image_csv_multi_label/images/4.jpg differ diff --git a/tests/assets/kaggle_dataset/image_csv_multi_label/images/5.jpg b/tests/assets/kaggle_dataset/image_csv_multi_label/images/5.jpg new file mode 100644 index 0000000000..8689b95631 Binary files /dev/null and b/tests/assets/kaggle_dataset/image_csv_multi_label/images/5.jpg differ diff --git a/tests/assets/kaggle_dataset/image_csv_multi_label/images/6.jpg b/tests/assets/kaggle_dataset/image_csv_multi_label/images/6.jpg new file mode 100644 index 0000000000..8689b95631 Binary files /dev/null and b/tests/assets/kaggle_dataset/image_csv_multi_label/images/6.jpg differ diff --git a/tests/unit/data_formats/test_kaggle.py b/tests/unit/data_formats/test_kaggle.py index 90071c71fc..262c741e74 100644 --- a/tests/unit/data_formats/test_kaggle.py +++ b/tests/unit/data_formats/test_kaggle.py @@ -20,6 +20,9 @@ from tests.utils.test_utils import compare_datasets DUMMY_DATASET_IMAGE_CSV_DIR = get_test_asset_path("kaggle_dataset", "image_csv") +DUMMY_DATASET_IMAGE_CSV_MULTI_LB_DIR = get_test_asset_path( + "kaggle_dataset", "image_csv_multi_label" +) DUMMY_DATASET_IMAGE_CSV_DET_DIR = get_test_asset_path("kaggle_dataset", "image_csv_det") DUMMY_DATASET_IMAGE_TXT_DIR = get_test_asset_path("kaggle_dataset", "image_txt") DUMMY_DATASET_IMAGE_TXT_DET_DIR = get_test_asset_path("kaggle_dataset", "image_txt_det") @@ -72,6 +75,51 @@ def fxt_img_dataset() -> Dataset: ) +@pytest.fixture +def fxt_img_multi_label_dataset() -> Dataset: + return Dataset.from_iterable( + [ + DatasetItem( + id="1", + subset="default", + media=Image.from_numpy(data=np.ones((5, 10, 3))), + annotations=[Label(label=0)], + ), + DatasetItem( + id="2", + subset="default", + media=Image.from_numpy(data=np.ones((5, 10, 3))), + annotations=[Label(label=1)], + ), + DatasetItem( + id="3", + subset="default", + media=Image.from_numpy(data=np.ones((5, 10, 3))), + annotations=[Label(label=2)], + ), + DatasetItem( + id="4", + subset="default", + media=Image.from_numpy(data=np.ones((5, 10, 3))), + annotations=[Label(label=0), Label(label=1)], + ), + DatasetItem( + id="5", + subset="default", + media=Image.from_numpy(data=np.ones((5, 10, 3))), + annotations=[Label(label=0), Label(label=2)], + ), + DatasetItem( + id="6", + subset="default", + media=Image.from_numpy(data=np.ones((5, 10, 3))), + annotations=[Label(label=1), Label(label=2)], + ), + ], + categories=["dog", "cat", "person"], + ) + + @pytest.fixture def fxt_img_det_dataset() -> Dataset: return Dataset.from_iterable( @@ -321,6 +369,8 @@ def fxt_coco_dataset() -> Dataset: IDS = [ "IMAGE_CSV", "IMAGE_CSV_WO_EXT", + "IMAGE_CSV_MULTI_LB", + "IMAGE_CSV_MULTI_LB_WO_EXT", "IMAGE_CSV_DET", "IMAGE_CSV_DET2", "IMAGE_CSV_DET3", @@ -372,6 +422,26 @@ def test_can_detect(self, fxt_dataset_dir: str): "columns": {"media": "image_name", "label": "label_name"}, }, ), + ( + DUMMY_DATASET_IMAGE_CSV_MULTI_LB_DIR, + "images", + "fxt_img_multi_label_dataset", + KaggleImageCsvBase, + { + "ann_file": osp.join(DUMMY_DATASET_IMAGE_CSV_MULTI_LB_DIR, "ann.csv"), + "columns": {"media": "image_name", "label": ["dog", "cat", "person"]}, + }, + ), + ( + DUMMY_DATASET_IMAGE_CSV_MULTI_LB_DIR, + "images", + "fxt_img_multi_label_dataset", + KaggleImageCsvBase, + { + "ann_file": osp.join(DUMMY_DATASET_IMAGE_CSV_MULTI_LB_DIR, "ann_wo_ext.csv"), + "columns": {"media": "image_name", "label": ["dog", "cat", "person"]}, + }, + ), ( DUMMY_DATASET_IMAGE_CSV_DET_DIR, "images",