Skip to content

Commit

Permalink
add classification_to_kvp unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
nisyad-ms committed Aug 30, 2024
1 parent 598c251 commit 756715d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
18 changes: 18 additions & 0 deletions tests/test_ic_od_to_kvp_wrapper/test_classification_as_kvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,24 @@ def test_simple(self):

print(kvp_dataset)

self.assertIsInstance(kvp_dataset, ClassificationAsKeyValuePairDataset)
self.assertEqual(kvp_dataset.dataset_info.type, DatasetTypes.KEY_VALUE_PAIR)
self.assertIn("name", kvp_dataset.dataset_info.schema)
self.assertIn("description", kvp_dataset.dataset_info.schema)
self.assertIn("fieldSchema", kvp_dataset.dataset_info.schema)

self.assertEqual(kvp_dataset.dataset_info.schema["fieldSchema"],
{"className": {
"type": "string",
"description": "Class name that the image belongs to.",
"classes": {
"1-class": {},
"2-class": {},
"3-class": {},
}
}
})


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import typing
from copy import deepcopy
from typing import Any, Dict, List

from vision_datasets.common import DatasetTypes, KeyValuePairDatasetInfo, VisionDataset
from vision_datasets.key_value_pair import (
Expand Down Expand Up @@ -37,11 +38,11 @@ def __init__(self, detection_dataset: VisionDataset):
"""
Initializes an instance of the ClassificationAsKeyValuePairDataset class.
Args:
detection_dataset (VisionDataset): The classification dataset to convert to key-value pair dataset.
detection_dataset (VisionDataset): The detection dataset to convert to key-value pair dataset.
"""

if detection_dataset is None or detection_dataset.dataset_info.type not in {DatasetTypes.IMAGE_OBJECT_DETECTION}:
raise ValueError
raise ValueError("DetectionAsKeyValuePairDataset only supports Image Object Detection datasets.")

# Generate schema and update dataset info
detection_dataset = deepcopy(detection_dataset)
Expand All @@ -60,7 +61,6 @@ def __init__(self, detection_dataset: VisionDataset):
annotations = []
for id, img in enumerate(detection_dataset.dataset_manifest.images, 1):
bboxes = [box.label_data for box in img.labels]
# label_names = [self.class_id_to_names[box[0]] for box in bboxes]

kvp_label_data = self.construct_kvp_label_data(bboxes)
img_ids = [self.img_id_to_pos[img.id]] # 0-based index
Expand All @@ -73,20 +73,20 @@ def __init__(self, detection_dataset: VisionDataset):
dataset_manifest = KeyValuePairDatasetManifest(detection_dataset.dataset_manifest.images, annotations, schema, additional_info=detection_dataset.dataset_manifest.additional_info)
super().__init__(dataset_info, dataset_manifest, dataset_resources=detection_dataset.dataset_resources)

def construct_schema(self, class_names: typing.List[str]) -> typing.Dict[str, typing.Any]:
schema: typing.Dict[str, typing.Any] = BASE_DETECTION_SCHEMA # initialize with base schema
def construct_schema(self, class_names: List[str]) -> Dict[str, Any]:
schema: Dict[str, Any] = BASE_DETECTION_SCHEMA # initialize with base schema
schema["fieldSchema"][f"{BBOXES_KEY}"]["items"]["classes"] = {c: {} for c in class_names}
return schema

def construct_kvp_label_data(self, bboxes: typing.List[typing.List[int]]):
def construct_kvp_label_data(self, bboxes: List[List[int]]):
"""
Convert the detection dataset label_name to the desired format for KVP annnotation as defined by the BASE_DETECTION_SCHEMA.
E.g. {"fields": {"bboxes": {"value": [{"value": "class1", "groundings" : [[10,10,20,20]]},
{"value": "class2", "groundings" : [[0,0,20,20], [20,20,30,30]]}]
"text": None}
"""

label_wise_bboxes = self._sort_bboxes_label_wise(bboxes)
label_wise_bboxes = self.sort_bboxes_label_wise(bboxes)

return {
f"{KeyValuePairLabelManifest.LABEL_KEY}": {
Expand All @@ -97,7 +97,7 @@ def construct_kvp_label_data(self, bboxes: typing.List[typing.List[int]]):
f"{KeyValuePairLabelManifest.TEXT_INPUT_KEY}": None
}

def _sort_bboxes_label_wise(self, bboxes: typing.List[typing.List[int]]):
def sort_bboxes_label_wise(self, bboxes: List[List[int]]) -> Dict[str, List[List[int]]]:
"""
Convert a list of bounding boxes to a dictionary with class name as key and list of bounding boxes as value.
Expand Down

0 comments on commit 756715d

Please sign in to comment.