Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KITTI 3D format support #1616

Closed
4 changes: 2 additions & 2 deletions .github/workflows/codeql.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:

# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@429e1977040da7a23b6822b13c129cd1ba93dbb2 # v3.26.2
uses: github/codeql-action/init@4dd16135b69a43b6c8efb853346f8437d92d3c93 # v3.26.6
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
Expand All @@ -73,7 +73,7 @@ jobs:
python -m build

- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@429e1977040da7a23b6822b13c129cd1ba93dbb2 # v3.26.2
uses: github/codeql-action/analyze@4dd16135b69a43b6c8efb853346f8437d92d3c93 # v3.26.6
with:
category: "/language:${{matrix.language}}"
- name: Generate Security Report
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/publish_to_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ jobs:
file_glob: true
- name: Publish package distributions to PyPI
if: ${{ steps.check-tag.outputs.match != '' }}
uses: pypa/gh-action-pypi-publish@v1.9.0
uses: pypa/gh-action-pypi-publish@v1.10.1
with:
password: ${{ secrets.PYPI_API_TOKEN }}
- name: Publish package distributions to TestPyPI
if: ${{ steps.check-tag.outputs.match == '' }}
uses: pypa/gh-action-pypi-publish@v1.9.0
uses: pypa/gh-action-pypi-publish@v1.10.1
with:
password: ${{ secrets.TESTPYPI_API_TOKEN }}
repository-url: https://test.pypi.org/legacy/
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/scorecard.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,6 @@ jobs:

# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@429e1977040da7a23b6822b13c129cd1ba93dbb2 # v3.26.2
uses: github/codeql-action/upload-sarif@4dd16135b69a43b6c8efb853346f8437d92d3c93 # v3.26.6
with:
sarif_file: results.sarif
2 changes: 1 addition & 1 deletion src/datumaro/components/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def merge(cls, envs: Sequence["Environment"]) -> "Environment":
merged = Environment()

def _register(registry: PluginRegistry):
merged.register_plugins(plugin for plugin in registry)
merged.register_plugins(list(registry._items.values()))

for env in envs:
_register(env.extractors)
Expand Down
15 changes: 8 additions & 7 deletions src/datumaro/components/hl_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,13 +282,14 @@ def merge(
merger = get_merger(merge_policy, **kwargs)
merged = merger(*datasets)
env = Environment.merge(
dataset.env
for dataset in datasets
if hasattr(
dataset, "env"
) # TODO: Sometimes, there is dataset which is not exactly "Dataset",
# e.g., VocClassificationBase. this should be fixed and every object from
# Dataset.import_from should have "Dataset" type.
[
dataset.env
for dataset in datasets
if hasattr(dataset, "env")
# TODO: Sometimes, there is dataset which is not exactly "Dataset",
# e.g., VocClassificationBase. this should be fixed and every object from
# Dataset.import_from should have "Dataset" type.
]
)
if report_path:
merger.save_merge_report(report_path)
Expand Down
Empty file.
122 changes: 122 additions & 0 deletions src/datumaro/plugins/data_formats/kitti_3d/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (C) 2024 Intel Corporation
#
# SPDX-License-Identifier: MIT

import glob
import os.path as osp
from typing import List, Optional, Type, TypeVar

from datumaro.components.annotation import AnnotationType, Bbox, LabelCategories
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.errors import InvalidAnnotationError
from datumaro.components.importer import ImportContext
from datumaro.components.media import Image
from datumaro.util.image import find_images

from .format import Kitti3dPath

T = TypeVar("T")


class Kitti3dBase(SubsetBase):
    """Reads a KITTI 3D detection subset from a dataset root directory.

    Annotations are plain-text files under ``<root>/label_2/``; matching
    images are looked up under ``<root>/image_2/``. Each annotation line has
    15 mandatory fields (type, truncated, occluded, alpha, 2D bbox,
    3D dimensions/location, rotation_y) plus an optional 16th "score" field,
    per the KITTI object development kit.
    """

    # http://www.cvlibs.net/datasets/kitti/raw_data.php
    # https://s3.eu-central-1.amazonaws.com/avg-kitti/devkit_raw_data.zip
    # Check cpp header implementation for field meaning

    def __init__(
        self,
        path: str,
        *,
        subset: Optional[str] = None,
        ctx: Optional[ImportContext] = None,
    ):
        # `path` is the dataset ROOT: _load() resolves image_2/ and label_2/
        # beneath it, so it must be a directory. The previous `osp.isfile`
        # check could never pass for a valid dataset root.
        assert osp.isdir(path), path
        super().__init__(subset=subset, ctx=ctx)

        self._path = path
        self._categories = {AnnotationType.label: LabelCategories()}
        self._items = self._load()

    def _load(self) -> List[DatasetItem]:
        """Builds one DatasetItem per annotation file under label_2/.

        Raises:
            InvalidAnnotationError: if a line has an unexpected field count
                or a field cannot be parsed to its expected type.
        """
        items = []
        image_dir = osp.join(self._path, Kitti3dPath.IMAGE_DIR)
        # Map item id (relative path without extension) -> image file path.
        image_path_by_id = {
            osp.splitext(osp.relpath(p, image_dir))[0]: p
            for p in find_images(image_dir, recursive=True)
        }

        ann_dir = osp.join(self._path, Kitti3dPath.LABEL_DIR)
        label_categories = self._categories[AnnotationType.label]
        # NOTE: `recursive=True` was inert here without a "**" pattern,
        # so a non-recursive glob is equivalent and clearer.
        for labels_path in sorted(glob.glob(osp.join(ann_dir, "*.txt"))):
            item_id = osp.splitext(osp.relpath(labels_path, ann_dir))[0]
            anns = []

            with open(labels_path, "r", encoding="utf-8") as f:
                lines = f.readlines()

            for line_idx, line in enumerate(lines):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank/trailing lines
                if len(fields) not in (15, 16):
                    # Raise a proper error instead of `assert`, which is
                    # stripped under `python -O` and gives no context.
                    raise InvalidAnnotationError(
                        f"Unexpected number of fields ({len(fields)}) in line "
                        f"{line_idx} of '{labels_path}': expected 15 or 16"
                    )

                label_name = fields[0]
                label_id = label_categories.find(label_name)[0]
                if label_id is None:
                    label_id = label_categories.add(label_name)

                # 2D bounding box corners in image pixel coordinates.
                x1 = self._parse_field(fields[4], float, "bbox left-top x")
                y1 = self._parse_field(fields[5], float, "bbox left-top y")
                x2 = self._parse_field(fields[6], float, "bbox right-bottom x")
                y2 = self._parse_field(fields[7], float, "bbox right-bottom y")

                attributes = {}
                attributes["truncated"] = self._parse_field(fields[1], float, "truncated")
                attributes["occluded"] = self._parse_field(fields[2], int, "occluded")
                attributes["alpha"] = self._parse_field(fields[3], float, "alpha")

                # 3D box extents and camera-frame position (meters).
                height_3d = self._parse_field(fields[8], float, "height (in meters)")
                width_3d = self._parse_field(fields[9], float, "width (in meters)")
                length_3d = self._parse_field(fields[10], float, "length (in meters)")

                x_3d = self._parse_field(fields[11], float, "x (in meters)")
                y_3d = self._parse_field(fields[12], float, "y (in meters)")
                z_3d = self._parse_field(fields[13], float, "z (in meters)")

                yaw_angle = self._parse_field(fields[14], float, "rotation_y")

                attributes["dimensions"] = [height_3d, width_3d, length_3d]
                attributes["location"] = [x_3d, y_3d, z_3d]
                attributes["rotation_y"] = yaw_angle

                # Optional 16th field is a detection confidence score.
                if len(fields) == 16:
                    attributes["score"] = self._parse_field(fields[15], float, "score")

                anns.append(
                    Bbox(
                        x=x1,
                        y=y1,
                        w=x2 - x1,
                        h=y2 - y1,
                        id=line_idx,
                        attributes=attributes,
                        label=label_id,
                    )
                )
                self._ann_types.add(AnnotationType.bbox)

            # Consume the image entry so unmatched images are not reused.
            image = image_path_by_id.pop(item_id, None)
            if image:
                image = Image.from_file(path=image)

            items.append(
                DatasetItem(id=item_id, annotations=anns, media=image, subset=self._subset)
            )

        return items

    def _parse_field(self, value: str, desired_type: Type[T], field_name: str) -> T:
        """Converts a raw text field, wrapping failures in InvalidAnnotationError."""
        try:
            return desired_type(value)
        except Exception as e:
            raise InvalidAnnotationError(
                f"Can't parse {field_name} from '{value}'. Expected {desired_type}"
            ) from e
12 changes: 12 additions & 0 deletions src/datumaro/plugins/data_formats/kitti_3d/format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (C) 2024 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os.path as osp


class Kitti3dPath:
    """Well-known sub-directory layout of a KITTI 3D dataset root."""

    IMAGE_DIR = "image_2"  # camera images
    LABEL_DIR = "label_2"  # per-frame annotation .txt files
    CALIB_DIR = "calib"  # calibration files
    PCD_DIR = osp.join("velodyne_points", "data")  # lidar point clouds
43 changes: 43 additions & 0 deletions src/datumaro/plugins/data_formats/kitti_3d/importer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (C) 2024 Intel Corporation
#
# SPDX-License-Identifier: MIT

from typing import List

from datumaro.components.errors import DatasetImportError
from datumaro.components.format_detection import FormatDetectionConfidence, FormatDetectionContext
from datumaro.components.importer import Importer

from .format import Kitti3dPath


class Kitti3dImporter(Importer):
    """Detects the KITTI 3D format by probing a label_2/*.txt annotation file."""

    _ANNO_EXT = ".txt"

    @classmethod
    def detect(cls, context: FormatDetectionContext) -> FormatDetectionConfidence:
        with context.require_any():
            with context.alternative():
                ann_path = context.require_file(f"{Kitti3dPath.LABEL_DIR}/*.txt")
                cls._check_ann_file(ann_path, context)

        return FormatDetectionConfidence.MEDIUM

    @classmethod
    def _check_ann_file(cls, fpath: str, context: FormatDetectionContext) -> bool:
        # Only the first non-consumed line is inspected: it either matches the
        # expected field count or the probe fails with an explanation.
        with context.probe_text_file(
            fpath, "Requirements for the annotation file of Kitti 3D format"
        ) as ann_file:
            for raw_line in ann_file:
                tokens = raw_line.rstrip("\n").split(" ")
                if len(tokens) not in (15, 16):
                    raise DatasetImportError(
                        f"Kitti 3D format txt file should have 15 or 16 fields for "
                        f"each line, but the read line has {len(tokens)} fields: "
                        f"fields={tokens}."
                    )
                return True
            raise DatasetImportError("Empty file is not allowed.")

    @classmethod
    def get_file_extensions(cls) -> List[str]:
        """Returns the annotation file extensions this importer recognizes."""
        return [cls._ANNO_EXT]
32 changes: 19 additions & 13 deletions src/datumaro/plugins/specs.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,4 @@
[
{
"import_path": "datumaro.plugins.accuracy_checker_plugin.ac_launcher.AcLauncher",
"plugin_name": "ac",
"plugin_type": "Launcher",
"extra_deps": [
"tensorflow",
"openvino.tools.accuracy_checker"
]
},
{
"import_path": "datumaro.plugins.configurable_validator.ConfigurableValidator",
"plugin_name": "configurable",
Expand Down Expand Up @@ -799,6 +790,21 @@
]
}
},
{
"import_path": "datumaro.plugins.data_formats.kitti_3d.base.Kitti3dBase",
"plugin_name": "kitti3d",
"plugin_type": "DatasetBase"
},
{
"import_path": "datumaro.plugins.data_formats.kitti_3d.importer.Kitti3dImporter",
"plugin_name": "kitti3d",
"plugin_type": "Importer",
"metadata": {
"file_extensions": [
".txt"
]
}
},
{
"import_path": "datumaro.plugins.data_formats.kitti_raw.base.KittiRawBase",
"plugin_name": "kitti_raw",
Expand Down Expand Up @@ -1855,13 +1861,13 @@
"plugin_type": "Transform"
},
{
"import_path": "datumaro.plugins.transforms.Correct",
"plugin_name": "correct",
"import_path": "datumaro.plugins.transforms.Clean",
"plugin_name": "clean",
"plugin_type": "Transform"
},
{
"import_path": "datumaro.plugins.transforms.Clean",
"plugin_name": "clean",
"import_path": "datumaro.plugins.transforms.Correct",
"plugin_name": "correct",
"plugin_type": "Transform"
},
{
Expand Down
17 changes: 16 additions & 1 deletion tests/unit/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import pytest

import datumaro.components.lazy_plugin
from datumaro.components.environment import Environment, PluginRegistry
from datumaro.components.environment import DEFAULT_ENVIRONMENT, Environment, PluginRegistry
from datumaro.components.exporter import Exporter

real_find_spec = datumaro.components.lazy_plugin.find_spec

Expand Down Expand Up @@ -77,3 +78,17 @@ def test_extra_deps_req(self, fxt_tf_failure_env):
)

assert "tf_detection_api" not in loaded_plugin_names

def test_merge_default_env(self):
    """Merging the default environment with itself should return it unchanged."""
    result = Environment.merge([DEFAULT_ENVIRONMENT, DEFAULT_ENVIRONMENT])
    assert result is DEFAULT_ENVIRONMENT

def test_merge_custom_env(self):
    """A plugin registered in only one environment must survive the merge."""

    class TestPlugin(Exporter):
        pass

    first, second = Environment(), Environment()
    first.exporters.register("test_plugin", TestPlugin)

    merged = Environment.merge([first, second])
    assert "test_plugin" in merged.exporters
Loading