Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
sooahleex committed Sep 23, 2024
1 parent f5182af commit a84eaa6
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 33 deletions.
88 changes: 55 additions & 33 deletions src/datumaro/plugins/data_formats/kitti_3d/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
# SPDX-License-Identifier: MIT

import glob
import logging
import os.path as osp
from typing import List, Optional, Type, TypeVar

from datumaro.components.annotation import AnnotationType, Bbox, LabelCategories
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.errors import InvalidAnnotationError
from datumaro.components.importer import ImportContext
from datumaro.components.media import Image
from datumaro.components.media import Image, PointCloud
from datumaro.util.image import find_images

from .format import Kitti3dPath
Expand All @@ -31,10 +32,12 @@ def __init__(
ctx: Optional[ImportContext] = None,
):
assert osp.isdir(path), path
super().__init__(subset=subset, ctx=ctx)
super().__init__(subset=subset, media_type=PointCloud, ctx=ctx)

self._path = path
self._categories = {AnnotationType.label: LabelCategories()}

common_attrs = {"truncated", "occluded", "alpha", "dimensions", "location", "rotation_y"}
self._categories = {AnnotationType.label: LabelCategories(attributes=common_attrs)}
self._items = self._load_items()

def _load_items(self) -> List[DatasetItem]:
Expand All @@ -47,48 +50,56 @@ def _load_items(self) -> List[DatasetItem]:

ann_dir = osp.join(self._path, Kitti3dPath.LABEL_DIR)
label_categories = self._categories[AnnotationType.label]

for labels_path in sorted(glob.glob(osp.join(ann_dir, "*.txt"), recursive=True)):
item_id = osp.splitext(osp.relpath(labels_path, ann_dir))[0]
anns = []

with open(labels_path, "r", encoding="utf-8") as f:
lines = f.readlines()
try:
with open(labels_path, "r", encoding="utf-8") as f:
lines = f.readlines()
except IOError as e:
logging.error(f"Error reading file {labels_path}: {e}")
continue

for line_idx, line in enumerate(lines):
line = line.split()
assert len(line) == 15 or len(line) == 16
if len(line) not in [15, 16]:
logging.warning(
f"Unexpected line length {len(line)} in file {labels_path} at line {line_idx + 1}"
)
continue

label_name = line[0]
label_id = label_categories.find(label_name)[0]
if label_id is None:
label_id = label_categories.add(label_name)

x1 = self._parse_field(line[4], float, "bbox left-top x")
y1 = self._parse_field(line[5], float, "bbox left-top y")
x2 = self._parse_field(line[6], float, "bbox right-bottom x")
y2 = self._parse_field(line[7], float, "bbox right-bottom y")

attributes = {}
attributes["truncated"] = self._parse_field(line[1], float, "truncated")
attributes["occluded"] = self._parse_field(line[2], int, "occluded")
attributes["alpha"] = self._parse_field(line[3], float, "alpha")

height_3d = self._parse_field(line[8], float, "height (in meters)")
width_3d = self._parse_field(line[9], float, "width (in meters)")
length_3d = self._parse_field(line[10], float, "length (in meters)")

x_3d = self._parse_field(line[11], float, "x (in meters)")
y_3d = self._parse_field(line[12], float, "y (in meters)")
z_3d = self._parse_field(line[13], float, "z (in meters)")

yaw_angle = self._parse_field(line[14], float, "rotation_y")

attributes["dimensions"] = [height_3d, width_3d, length_3d]
attributes["location"] = [x_3d, y_3d, z_3d]
attributes["rotation_y"] = yaw_angle

if len(line) == 16:
attributes["score"] = self._parse_field(line[15], float, "score")
try:
x1 = self._parse_field(line[4], float, "bbox left-top x")
y1 = self._parse_field(line[5], float, "bbox left-top y")
x2 = self._parse_field(line[6], float, "bbox right-bottom x")
y2 = self._parse_field(line[7], float, "bbox right-bottom y")

attributes = {
"truncated": self._parse_field(line[1], float, "truncated"),
"occluded": self._parse_field(line[2], int, "occluded"),
"alpha": self._parse_field(line[3], float, "alpha"),
"dimensions": [
self._parse_field(line[8], float, "height (in meters)"),
self._parse_field(line[9], float, "width (in meters)"),
self._parse_field(line[10], float, "length (in meters)"),
],
"location": [
self._parse_field(line[11], float, "x (in meters)"),
self._parse_field(line[12], float, "y (in meters)"),
self._parse_field(line[13], float, "z (in meters)"),
],
"rotation_y": self._parse_field(line[14], float, "rotation_y"),
}
except ValueError as e:
logging.error(f"Error parsing line {line_idx + 1} in file {labels_path}: {e}")
continue

anns.append(
Bbox(
Expand All @@ -108,7 +119,18 @@ def _load_items(self) -> List[DatasetItem]:
image = Image.from_file(path=image)

items.append(
DatasetItem(id=item_id, annotations=anns, media=image, subset=self._subset)
DatasetItem(
id=item_id,
subset=self._subset,
media=PointCloud.from_file(
path=osp.join(self._path, Kitti3dPath.PCD_DIR, item_id + ".bin"),
extra_images=[image],
),
attributes={
"calib_path": osp.join(self._path, Kitti3dPath.CALIB_DIR, item_id + ".txt")
},
annotations=anns,
)
)

return items
Expand Down
7 changes: 7 additions & 0 deletions tests/assets/kitti_dataset/kitti_3d/training/calib/000001.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
P0: 7.215377000000e+02 0.000000000000e+00 6.095593000000e+02 0.000000000000e+00 0.000000000000e+00 7.215377000000e+02 1.728540000000e+02 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 1.000000000000e+00 0.000000000000e+00
P1: 7.215377000000e+02 0.000000000000e+00 6.095593000000e+02 -3.875744000000e+02 0.000000000000e+00 7.215377000000e+02 1.728540000000e+02 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 1.000000000000e+00 0.000000000000e+00
P2: 7.215377000000e+02 0.000000000000e+00 6.095593000000e+02 4.485728000000e+01 0.000000000000e+00 7.215377000000e+02 1.728540000000e+02 2.163791000000e-01 0.000000000000e+00 0.000000000000e+00 1.000000000000e+00 2.745884000000e-03
P3: 7.215377000000e+02 0.000000000000e+00 6.095593000000e+02 -3.395242000000e+02 0.000000000000e+00 7.215377000000e+02 1.728540000000e+02 2.199936000000e+00 0.000000000000e+00 0.000000000000e+00 1.000000000000e+00 2.729905000000e-03
R0_rect: 9.999239000000e-01 9.837760000000e-03 -7.445048000000e-03 -9.869795000000e-03 9.999421000000e-01 -4.278459000000e-03 7.402527000000e-03 4.351614000000e-03 9.999631000000e-01
Tr_velo_to_cam: 7.533745000000e-03 -9.999714000000e-01 -6.166020000000e-04 -4.069766000000e-03 1.480249000000e-02 7.280733000000e-04 -9.998902000000e-01 -7.631618000000e-02 9.998621000000e-01 7.523790000000e-03 1.480755000000e-02 -2.717806000000e-01
Tr_imu_to_velo: 9.999976000000e-01 7.553071000000e-04 -2.035826000000e-03 -8.086759000000e-01 -7.854027000000e-04 9.998898000000e-01 -1.482298000000e-02 3.195559000000e-01 2.024406000000e-03 1.482454000000e-02 9.998881000000e-01 -7.997231000000e-01
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Truck 0.00 0 -1.57 600 150 630 190 2.85 2.63 12.34 0.47 1.49 69.44 -1.56
Car 0.00 3 -1.65 650 160 700 200 1.86 0.60 2.02 4.59 1.32 45.84 -1.55
DontCare -1 -1 -10 500 170 590 190 -1 -1 -1 -1000 -1000 -1000 -10
Binary file not shown.
Binary file not shown.
Binary file not shown.
116 changes: 116 additions & 0 deletions tests/unit/test_kitti_3d_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os.path as osp
from unittest import TestCase

from datumaro.components.annotation import AnnotationType, Bbox, LabelCategories
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.environment import Environment
from datumaro.components.media import Image, PointCloud
from datumaro.components.project import Dataset
from datumaro.plugins.data_formats.kitti_3d.importer import Kitti3dImporter

from tests.requirements import Requirements, mark_requirement
from tests.utils.assets import get_test_asset_path
from tests.utils.test_utils import compare_datasets_3d

DUMMY_DATASET_DIR = get_test_asset_path("kitti_dataset", "kitti_3d", "training")


class Kitti3DImporterTest(TestCase):
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_detect(self):
detected_formats = Environment().detect_dataset(DUMMY_DATASET_DIR)
self.assertEqual([Kitti3dImporter.NAME], detected_formats)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_load(self):
"""
<b>Description:</b>
Ensure that the dataset can be loaded correctly from the KITTI3D format.
<b>Expected results:</b>
The loaded dataset should have the same number of data items as the expected dataset.
The data items in the loaded dataset should have the same attributes and values as the expected data items.
The point clouds and images associated with the data items should be loaded correctly.
<b>Steps:</b>
1. Prepare an expected dataset with known data items, point clouds, images, and attributes.
2. Load the dataset from the KITTI3D format.
3. Compare the loaded dataset with the expected dataset.
"""
pcd1 = osp.join(DUMMY_DATASET_DIR, "velodyne", "000001.bin")

image1 = Image.from_file(path=osp.join(DUMMY_DATASET_DIR, "image_2", "000001.png"))

expected_label_cat = LabelCategories(
attributes={"occluded", "truncated", "alpha", "dimensions", "location", "rotation_y"}
)
expected_label_cat.add("Truck")
expected_label_cat.add("Car")
expected_label_cat.add("DontCare")
expected_dataset = Dataset.from_iterable(
[
DatasetItem(
id="000001",
annotations=[
Bbox(
600, # x1
150, # y1
30, # x2-x1
40, # y2-y1
label=0,
id=0,
attributes={
"truncated": 0.0,
"occluded": 0,
"alpha": -1.57,
"dimensions": [2.85, 2.63, 12.34],
"location": [0.47, 1.49, 69.44],
"rotation_y": -1.56,
},
z_order=0,
),
Bbox(
650, # x1
160, # y1
50, # x2-x1
40, # y2-y1
label=1,
id=1,
attributes={
"truncated": 0.0,
"occluded": 3,
"alpha": -1.65,
"dimensions": [1.86, 0.6, 2.02],
"location": [4.59, 1.32, 45.84],
"rotation_y": -1.55,
},
z_order=0,
),
Bbox(
500, # x1
170, # y1
90, # x2-x1
20, # y2-y1
label=2,
id=2,
attributes={
"truncated": -1.0,
"occluded": -1,
"alpha": -10.0,
"dimensions": [-1.0, -1.0, -1.0],
"location": [-1000.0, -1000.0, -1000.0],
"rotation_y": -10.0,
},
),
],
media=PointCloud.from_file(path=pcd1, extra_images=[image1]),
attributes={"calib_path": osp.join(DUMMY_DATASET_DIR, "calib", "000001.txt")},
),
],
categories={AnnotationType.label: expected_label_cat},
media_type=PointCloud,
)

parsed_dataset = Dataset.import_from(DUMMY_DATASET_DIR, "kitti3d")

compare_datasets_3d(self, expected_dataset, parsed_dataset, require_point_cloud=True)

0 comments on commit a84eaa6

Please sign in to comment.