
Adapt tests
sfmig committed Nov 7, 2024
1 parent 3913cee · commit 8663593
Showing 4 changed files with 244 additions and 77 deletions.
tests/test_unit/test_evaluate_tracker.py (107 changes: 73 additions & 34 deletions)
@@ -7,30 +7,40 @@


 @pytest.fixture
-def evaluation():
-    test_csv_file = Path(__file__).parents[1] / "data" / "gt_test.csv"
+def tracker_evaluate_interface():
+    annotations_file_csv = Path(__file__).parents[1] / "data" / "gt_test.csv"
     return TrackerEvaluate(
-        test_csv_file, predicted_boxes_id=[], iou_threshold=0.1
+        annotations_file_csv,
+        predicted_boxes_dict={},
+        iou_threshold=0.1,
+        tracking_output_dir="/path/output",
     )


-def test_get_ground_truth_data(evaluation):
-    ground_truth_dict = evaluation.get_ground_truth_data()
+def test_get_ground_truth_data_structure(tracker_evaluate_interface):
+    """Test the loaded ground truth data has the expected structure."""
+    # Get ground truth data dict
+    ground_truth_dict = tracker_evaluate_interface.get_ground_truth_data()

+    # check type
     assert isinstance(ground_truth_dict, dict)
+    # check it is a nested dictionary
     assert all(
         isinstance(frame_data, dict)
         for frame_data in ground_truth_dict.values()
     )

+    # check data types for values in nested dictionary
     for frame_number, data in ground_truth_dict.items():
         assert isinstance(frame_number, int)
         assert isinstance(data["bbox"], np.ndarray)
         assert isinstance(data["id"], np.ndarray)
         assert data["bbox"].shape[1] == 4


-def test_ground_truth_data_from_csv(evaluation):
+def test_ground_truth_data_values(tracker_evaluate_interface):
+    """Test ground truth data holds expected values."""
+    # Define expected ground truth data
     expected_data = {
         11: {
             "bbox": np.array(
@@ -50,25 +60,33 @@ def test_ground_truth_data_from_csv(evaluation):
         },
     }

-    ground_truth_dict = evaluation.get_ground_truth_data()
+    # Get ground truth data dict
+    ground_truth_dict = tracker_evaluate_interface.get_ground_truth_data()

-    for frame_number, expected_frame_data in expected_data.items():
-        assert frame_number in ground_truth_dict
+    # Check if ground truth data matches expected values
+    for expected_frame_number, expected_frame_data in expected_data.items():
+        # check expected key is present
+        assert expected_frame_number in ground_truth_dict

-        assert len(ground_truth_dict[frame_number]["bbox"]) == len(
+        # check n of bounding boxes per frame matches the expected value
+        assert len(ground_truth_dict[expected_frame_number]["bbox"]) == len(
             expected_frame_data["bbox"]
         )

+        # check bbox arrays match the expected values
         for bbox, expected_bbox in zip(
-            ground_truth_dict[frame_number]["bbox"],
+            ground_truth_dict[expected_frame_number]["bbox"],
             expected_frame_data["bbox"],
         ):
             assert np.allclose(
                 bbox, expected_bbox
-            ), f"Frame {frame_number}, bbox mismatch"
+            ), f"Frame {expected_frame_number}, bbox mismatch"

+        # check id arrays match the expected values
         assert np.array_equal(
-            ground_truth_dict[frame_number]["id"], expected_frame_data["id"]
-        ), f"Frame {frame_number}, id mismatch"
+            ground_truth_dict[expected_frame_number]["id"],
+            expected_frame_data["id"],
+        ), f"Frame {expected_frame_number}, id mismatch"


 @pytest.mark.parametrize(
@@ -220,11 +238,19 @@ def test_ground_truth_data_from_csv(evaluation):
     ],
 )
 def test_count_identity_switches(
-    evaluation, prev_frame_id_map, current_frame_id_map, expected_output
+    tracker_evaluate_interface,
+    prev_frame_id_map,
+    current_frame_id_map,
+    expected_output,
 ):
-    evaluation.last_known_predicted_ids = {1: 11, 2: 12, 3: 13, 4: 14}
+    tracker_evaluate_interface.last_known_predicted_ids = {
+        1: 11,
+        2: 12,
+        3: 13,
+        4: 14,
+    }
     assert (
-        evaluation.count_identity_switches(
+        tracker_evaluate_interface.count_identity_switches(
             prev_frame_id_map, current_frame_id_map
         )
         == expected_output
@@ -240,18 +266,18 @@ def test_count_identity_switches(
         ([0, 0, 10, 10], [5, 15, 15, 25], 0.0),
     ],
 )
-def test_calculate_iou(box1, box2, expected_iou, evaluation):
+def test_calculate_iou(box1, box2, expected_iou, tracker_evaluate_interface):
     box1 = np.array(box1)
     box2 = np.array(box2)

-    iou = evaluation.calculate_iou(box1, box2)
+    iou = tracker_evaluate_interface.calculate_iou(box1, box2)

     # Check if IoU matches expected value
     assert iou == pytest.approx(expected_iou, abs=1e-2)


 @pytest.mark.parametrize(
-    "gt_data, pred_data, prev_frame_id_map, expected_mota",
+    "gt_data, pred_data, prev_frame_id_map, expected_output",
     [
         # perfect tracking
         (
@@ -276,7 +302,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([11, 12, 13]),
             },
             {1: 11, 2: 12, 3: 13},
-            1.0,
+            [1.0, 3, 0, 0, 0],
         ),
         (
             {
@@ -300,7 +326,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([11, 12, 13]),
             },
             {1: 11, 12: 2, 3: np.nan},
-            1.0,
+            [1.0, 3, 0, 0, 0],
         ),
         # ID switch
         (
@@ -325,7 +351,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([11, 12, 14]),
             },
             {1: 11, 2: 12, 3: 13},
-            2 / 3,
+            [2 / 3, 3, 0, 0, 1],
         ),
         # missed detection
         (
@@ -346,7 +372,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([11, 12]),
             },
             {1: 11, 2: 12, 3: 13},
-            2 / 3,
+            [2 / 3, 2, 1, 0, 0],
         ),
         # false positive
         (
@@ -372,7 +398,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([11, 12, 13, 14]),
             },
             {1: 11, 2: 12, 3: 13},
-            2 / 3,
+            [2 / 3, 3, 0, 1, 0],
         ),
         # low IOU and ID switch
         (
@@ -397,7 +423,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([11, 12, 14]),
             },
             {1: 11, 2: 12, 3: 13},
-            0,
+            [0, 2, 1, 1, 1],
         ),
         # low IOU and ID switch on same box
         (
@@ -422,7 +448,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([11, 14, 13]),
             },
             {1: 11, 2: 12, 3: 13},
-            1 / 3,
+            [1 / 3, 2, 1, 1, 0],
         ),
         # current tracked id = prev tracked id, but prev_gt_id != current gt id
         (
@@ -447,7 +473,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([11, 12, 13]),
             },
             {1: 11, 2: 12, 3: 13},
-            2 / 3,
+            [2 / 3, 3, 0, 0, 1],
         ),
         # ID swapped
         (
@@ -472,21 +498,34 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([11, 13, 12]),
             },
             {1: 11, 2: 12, 3: 13},
-            1 / 3,
+            [1 / 3, 3, 0, 0, 2],
         ),
     ],
 )
-def test_evaluate_mota(
+def test_compute_mota_one_frame(
     gt_data,
     pred_data,
     prev_frame_id_map,
-    expected_mota,
-    evaluation,
+    expected_output,
+    tracker_evaluate_interface,
 ):
-    mota, _ = evaluation.evaluate_mota(
+    (
+        mota,
+        true_positives,
+        missed_detections,
+        false_positives,
+        num_switches,
+        total_gt,
+        _,
+    ) = tracker_evaluate_interface.compute_mota_one_frame(
         gt_data,
         pred_data,
         0.1,  # iou_threshold
         prev_frame_id_map,
     )
-    assert mota == pytest.approx(expected_mota)
+    assert mota == pytest.approx(expected_output[0])
+    assert true_positives == expected_output[1]
+    assert missed_detections == expected_output[2]
+    assert false_positives == expected_output[3]
+    assert num_switches == expected_output[4]
+    assert total_gt == (true_positives + missed_detections)
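
A note on the new expected_output format: each parametrized case above bundles [mota, true_positives, missed_detections, false_positives, num_switches]. Assuming compute_mota_one_frame follows the standard MOTA definition (the implementation is not shown in this diff), MOTA = 1 - (FN + FP + IDSW) / GT with GT = TP + FN, and the listed values are mutually consistent. A minimal sketch checking two of the cases; the helper below is illustrative and not part of the crabs package:

import pytest


def mota_from_counts(true_positives, missed_detections, false_positives, num_switches):
    """Standard MOTA: 1 - (FN + FP + IDSW) / GT, with GT = TP + FN."""
    total_gt = true_positives + missed_detections
    return 1 - (missed_detections + false_positives + num_switches) / total_gt


# "ID swapped" case: 3 true positives, 0 missed, 0 false positives, 2 switches
assert mota_from_counts(3, 0, 0, 2) == pytest.approx(1 / 3)

# "low IOU and ID switch" case: 2 true positives, 1 missed, 1 false positive, 1 switch
assert mota_from_counts(2, 1, 1, 1) == pytest.approx(0.0)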
tests/test_unit/test_track_video.py (49 changes: 37 additions & 12 deletions)
@@ -10,16 +10,17 @@

 @pytest.fixture
 def mock_args():
-    temp_dir = tempfile.mkdtemp()
+    tmp_dir = tempfile.mkdtemp()

     return Namespace(
         config_file="/path/to/config.yaml",
         video_path="/path/to/video.mp4",
-        trained_model_path="/path/to/model.ckpt",
-        output_dir=temp_dir,
+        trained_model_path="path/to/model.ckpt",
+        output_dir=tmp_dir,
         accelerator="gpu",
         annotations_file=None,
         save_video=None,
+        save_frames=None,
     )


@@ -28,33 +29,57 @@ def mock_args():
     new_callable=mock_open,
     read_data="max_age: 10\nmin_hits: 3\niou_threshold: 0.1",
 )
-@patch("yaml.safe_load")
 @patch("cv2.VideoCapture")
-@patch("crabs.tracker.track_video.FasterRCNN.load_from_checkpoint")
-@patch("crabs.tracker.track_video.Sort")
-def test_tracking_setup(
-    mock_sort,
-    mock_load_from_checkpoint,
-    mock_videocapture,
+@patch("crabs.tracker.utils.io.get_video_parameters")
+@patch("crabs.tracker.track_video.get_config_from_ckpt")
+@patch("crabs.tracker.track_video.get_mlflow_parameters_from_ckpt")
+# we patch where the function is looked at, see
+# https://docs.python.org/3/library/unittest.mock.html#where-to-patch
+@patch("yaml.safe_load")
+def test_tracking_constructor(
     mock_yaml_load,
+    mock_get_mlflow_parameters_from_ckpt,
+    mock_get_config_from_ckpt,
+    mock_get_video_parameters,
+    mock_videocapture,
     mock_open,
     mock_args,
 ):
+    # mock reading tracking config from file
     mock_yaml_load.return_value = {
         "max_age": 10,
         "min_hits": 3,
         "iou_threshold": 0.1,
     }

-    mock_model = MagicMock()
-    mock_load_from_checkpoint.return_value = mock_model
+    # mock getting mlflow parameters from checkpoint
+    mock_get_mlflow_parameters_from_ckpt.return_value = {
+        "run_name": "trained_model_run_name",
+        "cli_args/experiment_name": "trained_model_expt_name",
+    }

+    # mock getting trained model's config
+    mock_get_config_from_ckpt.return_value = {}

+    # mock getting video parameters
+    mock_get_video_parameters.return_value = {
+        "total_frames": 614,
+        "frame_width": 1920,
+        "frame_height": 1080,
+        "fps": 60,
+    }

+    # mock input video as if opened correctly
     mock_video_capture = MagicMock()
     mock_video_capture.isOpened.return_value = True
     mock_videocapture.return_value = mock_video_capture

+    # instantiate tracking interface
     tracker = Tracking(mock_args)

     # check output dir is created correctly
+    # TODO: add asserts for other attributes assigned in constructor
     assert tracker.args.output_dir == mock_args.output_dir

+    # delete output dir
     Path(mock_args.output_dir).rmdir()
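
Two unittest.mock conventions explain the reshuffled decorators and parameters above. First, as the in-diff comment notes, each patch target names the module where the function is looked up at call time, not where it is defined (hence "crabs.tracker.track_video.get_config_from_ckpt"). Second, stacked @patch decorators are applied bottom-up, which is why mock_yaml_load, the bottom-most patch, is now the first test parameter. A minimal, self-contained sketch of the ordering rule, using standard-library targets purely for illustration:

import os
from unittest.mock import patch


@patch("os.listdir")  # outermost decorator -> injected last
@patch("os.getcwd")  # innermost (bottom-most) decorator -> injected first
def check_patch_order(mock_getcwd, mock_listdir):
    mock_getcwd.return_value = "/fake/dir"
    mock_listdir.return_value = ["a.txt"]
    assert os.getcwd() == "/fake/dir"
    assert os.listdir(".") == ["a.txt"]


check_patch_order()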
[Diff for the remaining 2 changed files did not load in this view.]
