From a88cf76251e2f2e36050d20f8f209fb340bc3985 Mon Sep 17 00:00:00 2001 From: Prithviraj Banerjee Date: Wed, 14 Aug 2024 16:36:51 -0700 Subject: [PATCH] fix bug in generating statistics for Box2dDataProvider Summary: Fix bug in generating statistics for Box2dDataProvider. We were using sum instead of len. Reviewed By: SeaOtocinclus Differential Revision: D61223159 fbshipit-source-id: 9dc35e8c2b8cdecc397117f9c59b6a8cdcbecd07 --- hot3d/data_loaders/HandBox2dDataProvider.py | 4 ++-- hot3d/data_loaders/ObjectBox2dDataProvider.py | 4 ++-- hot3d/data_loaders/tests/test_HandBox2dDataProvider.py | 6 ++++++ hot3d/data_loaders/tests/test_ObjectBox2dDataProvider.py | 8 ++++++++ 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/hot3d/data_loaders/HandBox2dDataProvider.py b/hot3d/data_loaders/HandBox2dDataProvider.py index 027eed5..6aae7e4 100644 --- a/hot3d/data_loaders/HandBox2dDataProvider.py +++ b/hot3d/data_loaders/HandBox2dDataProvider.py @@ -82,11 +82,11 @@ def stream_ids(self) -> List[StreamId]: def get_data_statistics(self) -> Dict[str, Any]: """ - Returns the stats of the trajectory + Returns the stats for Hand 2D bounding boxes """ stats = {} stats["num_frames"] = { - k: sum(v) for k, v in self._sorted_timestamp_ns_list.items() + k: len(v) for k, v in self._sorted_timestamp_ns_list.items() } stats["stream_ids"] = [str(x) for x in self.stream_ids] return stats diff --git a/hot3d/data_loaders/ObjectBox2dDataProvider.py b/hot3d/data_loaders/ObjectBox2dDataProvider.py index d8b705c..490280f 100644 --- a/hot3d/data_loaders/ObjectBox2dDataProvider.py +++ b/hot3d/data_loaders/ObjectBox2dDataProvider.py @@ -99,11 +99,11 @@ def object_uids(self) -> Set[str]: def get_data_statistics(self) -> Dict[str, Any]: """ - Returns the stats of the trajectory + Returns the stats for Object 2D bounding boxes """ stats = {} stats["num_frames"] = { - k: sum(v) for k, v in self._sorted_timestamp_ns_list.items() + k: len(v) for k, v in self._sorted_timestamp_ns_list.items() } stats["stream_ids"] = [str(x) for x in self.stream_ids] stats["num_objects"] = len(self.object_uids) diff --git a/hot3d/data_loaders/tests/test_HandBox2dDataProvider.py b/hot3d/data_loaders/tests/test_HandBox2dDataProvider.py index 6cfbe2b..33d6f33 100644 --- a/hot3d/data_loaders/tests/test_HandBox2dDataProvider.py +++ b/hot3d/data_loaders/tests/test_HandBox2dDataProvider.py @@ -63,3 +63,9 @@ def test_provider_aria_recording(self) -> None: self.assertIsNotNone(box2d_collection_with_dt) self.assertIsNotNone(box2d_collection_with_dt.box2d_collection) + + data_statistics = provider.get_data_statistics() + print(f"data_statistics: {data_statistics}") + self.assertEquals(len(data_statistics["num_frames"]), 3) + self.assertEquals(data_statistics["num_frames"]["214-1"], 34) + self.assertEquals(len(data_statistics["stream_ids"]), 3) diff --git a/hot3d/data_loaders/tests/test_ObjectBox2dDataProvider.py b/hot3d/data_loaders/tests/test_ObjectBox2dDataProvider.py index 422164c..93a8a5e 100644 --- a/hot3d/data_loaders/tests/test_ObjectBox2dDataProvider.py +++ b/hot3d/data_loaders/tests/test_ObjectBox2dDataProvider.py @@ -69,3 +69,11 @@ def test_provider_aria_recording(self) -> None: box2d_collection_with_dt.box2d_collection.object_uid_list ) self.assertGreater(len(object_uids_at_query_timestamp), 0) + + data_statistics = provider.get_data_statistics() + print(f"data_statistics: {data_statistics}") + self.assertEquals(len(data_statistics["num_frames"]), 3) + self.assertEquals(data_statistics["num_frames"]["214-1"], 34) + self.assertEquals(data_statistics["num_objects"], 6) + self.assertEquals(len(data_statistics["stream_ids"]), 3) + self.assertEquals(len(data_statistics["object_uids"]), 6)