Skip to content

Commit

Permalink
fix bug in generating statistics for Box2dDataProvider
Browse files Browse the repository at this point in the history
Summary: Fix bug in generating statistics for Box2dDataProvider. We were using sum instead of len.

Reviewed By: SeaOtocinclus

Differential Revision: D61223159

fbshipit-source-id: 9dc35e8c2b8cdecc397117f9c59b6a8cdcbecd07
  • Loading branch information
Prithviraj Banerjee authored and facebook-github-bot committed Aug 14, 2024
1 parent 7dcb1da commit a88cf76
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 4 deletions.
4 changes: 2 additions & 2 deletions hot3d/data_loaders/HandBox2dDataProvider.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ def stream_ids(self) -> List[StreamId]:

def get_data_statistics(self) -> Dict[str, Any]:
"""
Returns the stats of the trajectory
Returns the stats for Hand 2D bounding boxes
"""
stats = {}
stats["num_frames"] = {
k: sum(v) for k, v in self._sorted_timestamp_ns_list.items()
k: len(v) for k, v in self._sorted_timestamp_ns_list.items()
}
stats["stream_ids"] = [str(x) for x in self.stream_ids]
return stats
Expand Down
4 changes: 2 additions & 2 deletions hot3d/data_loaders/ObjectBox2dDataProvider.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ def object_uids(self) -> Set[str]:

def get_data_statistics(self) -> Dict[str, Any]:
"""
Returns the stats of the trajectory
Returns the stats for Object 2D bounding boxes
"""
stats = {}
stats["num_frames"] = {
k: sum(v) for k, v in self._sorted_timestamp_ns_list.items()
k: len(v) for k, v in self._sorted_timestamp_ns_list.items()
}
stats["stream_ids"] = [str(x) for x in self.stream_ids]
stats["num_objects"] = len(self.object_uids)
Expand Down
6 changes: 6 additions & 0 deletions hot3d/data_loaders/tests/test_HandBox2dDataProvider.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,9 @@ def test_provider_aria_recording(self) -> None:

self.assertIsNotNone(box2d_collection_with_dt)
self.assertIsNotNone(box2d_collection_with_dt.box2d_collection)

data_statistics = provider.get_data_statistics()
print(f"data_statistics: {data_statistics}")
self.assertEquals(len(data_statistics["num_frames"]), 3)
self.assertEquals(data_statistics["num_frames"]["214-1"], 34)
self.assertEquals(len(data_statistics["stream_ids"]), 3)
8 changes: 8 additions & 0 deletions hot3d/data_loaders/tests/test_ObjectBox2dDataProvider.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,11 @@ def test_provider_aria_recording(self) -> None:
box2d_collection_with_dt.box2d_collection.object_uid_list
)
self.assertGreater(len(object_uids_at_query_timestamp), 0)

data_statistics = provider.get_data_statistics()
print(f"data_statistics: {data_statistics}")
self.assertEquals(len(data_statistics["num_frames"]), 3)
self.assertEquals(data_statistics["num_frames"]["214-1"], 34)
self.assertEquals(data_statistics["num_objects"], 6)
self.assertEquals(len(data_statistics["stream_ids"]), 3)
self.assertEquals(len(data_statistics["object_uids"]), 6)

0 comments on commit a88cf76

Please sign in to comment.