Fix evaluate tests
sfmig committed Nov 7, 2024
1 parent 6d39074 commit a570f69
Showing 1 changed file with 45 additions and 22 deletions.
67 changes: 45 additions & 22 deletions tests/test_unit/test_evaluate_tracker.py
@@ -7,33 +7,40 @@


 @pytest.fixture
-def evaluation():
-    test_csv_file = Path(__file__).parents[1] / "data" / "gt_test.csv"
+def tracker_evaluate_interface():
+    annotations_file_csv = Path(__file__).parents[1] / "data" / "gt_test.csv"
     return TrackerEvaluate(
-        test_csv_file,
-        predicted_boxes_id=[],
+        annotations_file_csv,
+        predicted_boxes_dict={},
         iou_threshold=0.1,
         tracking_output_dir="/path/output",
     )
 
 
-def test_get_ground_truth_data(evaluation):
-    ground_truth_dict = evaluation.get_ground_truth_data()
+def test_get_ground_truth_data_structure(tracker_evaluate_interface):
+    """Test the loaded ground truth data has the expected structure."""
+    # Get ground truth data dict
+    ground_truth_dict = tracker_evaluate_interface.get_ground_truth_data()
 
+    # check type
     assert isinstance(ground_truth_dict, dict)
+    # check it is a nested dictionary
     assert all(
         isinstance(frame_data, dict)
         for frame_data in ground_truth_dict.values()
     )
 
+    # check data types for values in nested dictionary
     for frame_number, data in ground_truth_dict.items():
         assert isinstance(frame_number, int)
         assert isinstance(data["bbox"], np.ndarray)
         assert isinstance(data["id"], np.ndarray)
         assert data["bbox"].shape[1] == 4
 
 
-def test_ground_truth_data_from_csv(evaluation):
+def test_ground_truth_data_values(tracker_evaluate_interface):
+    """Test ground truth data holds expected values."""
+    # Define expected ground truth data
     expected_data = {
         11: {
             "bbox": np.array(
@@ -53,25 +60,33 @@ def test_ground_truth_data_from_csv(evaluation):
         },
     }
 
-    ground_truth_dict = evaluation.get_ground_truth_data()
+    # Get ground truth data dict
+    ground_truth_dict = tracker_evaluate_interface.get_ground_truth_data()
 
-    for frame_number, expected_frame_data in expected_data.items():
-        assert frame_number in ground_truth_dict
+    # Check if ground truth data matches expected values
+    for expected_frame_number, expected_frame_data in expected_data.items():
+        # check expected key is present
+        assert expected_frame_number in ground_truth_dict
 
-        assert len(ground_truth_dict[frame_number]["bbox"]) == len(
+        # check n of bounding boxes per frame matches the expected value
+        assert len(ground_truth_dict[expected_frame_number]["bbox"]) == len(
             expected_frame_data["bbox"]
         )
 
+        # check bbox arrays match the expected values
         for bbox, expected_bbox in zip(
-            ground_truth_dict[frame_number]["bbox"],
+            ground_truth_dict[expected_frame_number]["bbox"],
             expected_frame_data["bbox"],
         ):
             assert np.allclose(
                 bbox, expected_bbox
-            ), f"Frame {frame_number}, bbox mismatch"
+            ), f"Frame {expected_frame_number}, bbox mismatch"
 
+        # check id arrays match the expected values
         assert np.array_equal(
-            ground_truth_dict[frame_number]["id"], expected_frame_data["id"]
-        ), f"Frame {frame_number}, id mismatch"
+            ground_truth_dict[expected_frame_number]["id"],
+            expected_frame_data["id"],
+        ), f"Frame {expected_frame_number}, id mismatch"
 
 
 @pytest.mark.parametrize(
@@ -223,11 +238,19 @@ def test_ground_truth_data_from_csv(evaluation):
     ],
 )
 def test_count_identity_switches(
-    evaluation, prev_frame_id_map, current_frame_id_map, expected_output
+    tracker_evaluate_interface,
+    prev_frame_id_map,
+    current_frame_id_map,
+    expected_output,
 ):
-    evaluation.last_known_predicted_ids = {1: 11, 2: 12, 3: 13, 4: 14}
+    tracker_evaluate_interface.last_known_predicted_ids = {
+        1: 11,
+        2: 12,
+        3: 13,
+        4: 14,
+    }
     assert (
-        evaluation.count_identity_switches(
+        tracker_evaluate_interface.count_identity_switches(
             prev_frame_id_map, current_frame_id_map
         )
         == expected_output
@@ -243,11 +266,11 @@ def test_count_identity_switches(
         ([0, 0, 10, 10], [5, 15, 15, 25], 0.0),
     ],
 )
-def test_calculate_iou(box1, box2, expected_iou, evaluation):
+def test_calculate_iou(box1, box2, expected_iou, tracker_evaluate_interface):
     box1 = np.array(box1)
     box2 = np.array(box2)
 
-    iou = evaluation.calculate_iou(box1, box2)
+    iou = tracker_evaluate_interface.calculate_iou(box1, box2)
 
     # Check if IoU matches expected value
     assert iou == pytest.approx(expected_iou, abs=1e-2)
@@ -484,7 +507,7 @@ def test_evaluate_mota(
     pred_data,
     prev_frame_id_map,
     expected_output,
-    evaluation,
+    tracker_evaluate_interface,
 ):
     (
         mota,
@@ -494,7 +517,7 @@ def test_evaluate_mota(
         num_switches,
         total_gt,
         _,
-    ) = evaluation.evaluate_mota(
+    ) = tracker_evaluate_interface.evaluate_mota(
         gt_data,
         pred_data,
         0.1,  # iou_threshold
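For reference, below is a minimal sketch of how a test consumes the renamed fixture after this change: pytest injects the TrackerEvaluate instance by the fixture's function name, tracker_evaluate_interface. The test itself is hypothetical (not part of this commit); it only calls calculate_iou, which the diff above exercises, and assumes the [xmin, ymin, xmax, ymax] box format implied by the parametrized cases, under which two identical boxes have an IoU of exactly 1.0.

import numpy as np
import pytest


def test_calculate_iou_identical_boxes(tracker_evaluate_interface):
    # Hypothetical example meant to live alongside the tests above.
    # Two identical boxes overlap fully, so IoU should be 1.0
    # regardless of the exact box-coordinate convention.
    box = np.array([0, 0, 10, 10])
    iou = tracker_evaluate_interface.calculate_iou(box, box)
    assert iou == pytest.approx(1.0, abs=1e-2)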
