From 8663593f9344fb6c13bca5d859b69126258ef0ce Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 7 Nov 2024 18:22:21 +0000 Subject: [PATCH] Adapt tests --- tests/test_unit/test_evaluate_tracker.py | 107 ++++++++++++++++------- tests/test_unit/test_track_video.py | 49 ++++++++--- tests/test_unit/test_tracking_io.py | 88 +++++++++++++++++++ tests/test_unit/test_tracking_utils.py | 77 +++++++++------- 4 files changed, 244 insertions(+), 77 deletions(-) create mode 100644 tests/test_unit/test_tracking_io.py diff --git a/tests/test_unit/test_evaluate_tracker.py b/tests/test_unit/test_evaluate_tracker.py index a08ec8d0..416ea3be 100644 --- a/tests/test_unit/test_evaluate_tracker.py +++ b/tests/test_unit/test_evaluate_tracker.py @@ -7,22 +7,30 @@ @pytest.fixture -def evaluation(): - test_csv_file = Path(__file__).parents[1] / "data" / "gt_test.csv" +def tracker_evaluate_interface(): + annotations_file_csv = Path(__file__).parents[1] / "data" / "gt_test.csv" return TrackerEvaluate( - test_csv_file, predicted_boxes_id=[], iou_threshold=0.1 + annotations_file_csv, + predicted_boxes_dict={}, + iou_threshold=0.1, + tracking_output_dir="/path/output", ) -def test_get_ground_truth_data(evaluation): - ground_truth_dict = evaluation.get_ground_truth_data() +def test_get_ground_truth_data_structure(tracker_evaluate_interface): + """Test the loaded ground truth data has the expected structure.""" + # Get ground truth data dict + ground_truth_dict = tracker_evaluate_interface.get_ground_truth_data() + # check type assert isinstance(ground_truth_dict, dict) + # check it is a nested dictionary assert all( isinstance(frame_data, dict) for frame_data in ground_truth_dict.values() ) + # check data types for values in nested dictionary for frame_number, data in ground_truth_dict.items(): assert isinstance(frame_number, int) assert isinstance(data["bbox"], np.ndarray) @@ -30,7 +38,9 @@ def test_get_ground_truth_data(evaluation): assert data["bbox"].shape[1] == 4 -def test_ground_truth_data_from_csv(evaluation): +def test_ground_truth_data_values(tracker_evaluate_interface): + """Test ground truth data holds expected values.""" + # Define expected ground truth data expected_data = { 11: { "bbox": np.array( @@ -50,25 +60,33 @@ def test_ground_truth_data_from_csv(evaluation): }, } - ground_truth_dict = evaluation.get_ground_truth_data() + # Get ground truth data dict + ground_truth_dict = tracker_evaluate_interface.get_ground_truth_data() - for frame_number, expected_frame_data in expected_data.items(): - assert frame_number in ground_truth_dict + # Check if ground truth data matches expected values + for expected_frame_number, expected_frame_data in expected_data.items(): + # check expected key is present + assert expected_frame_number in ground_truth_dict - assert len(ground_truth_dict[frame_number]["bbox"]) == len( + # check n of bounding boxes per frame matches the expected value + assert len(ground_truth_dict[expected_frame_number]["bbox"]) == len( expected_frame_data["bbox"] ) + + # check bbox arrays match the expected values for bbox, expected_bbox in zip( - ground_truth_dict[frame_number]["bbox"], + ground_truth_dict[expected_frame_number]["bbox"], expected_frame_data["bbox"], ): assert np.allclose( bbox, expected_bbox - ), f"Frame {frame_number}, bbox mismatch" + ), f"Frame {expected_frame_number}, bbox mismatch" + # check id arrays match the expected values assert np.array_equal( - ground_truth_dict[frame_number]["id"], expected_frame_data["id"] - ), f"Frame {frame_number}, id mismatch" + ground_truth_dict[expected_frame_number]["id"], + expected_frame_data["id"], + ), f"Frame {expected_frame_number}, id mismatch" @pytest.mark.parametrize( @@ -220,11 +238,19 @@ def test_ground_truth_data_from_csv(evaluation): ], ) def test_count_identity_switches( - evaluation, prev_frame_id_map, current_frame_id_map, expected_output + tracker_evaluate_interface, + prev_frame_id_map, + current_frame_id_map, + expected_output, ): - evaluation.last_known_predicted_ids = {1: 11, 2: 12, 3: 13, 4: 14} + tracker_evaluate_interface.last_known_predicted_ids = { + 1: 11, + 2: 12, + 3: 13, + 4: 14, + } assert ( - evaluation.count_identity_switches( + tracker_evaluate_interface.count_identity_switches( prev_frame_id_map, current_frame_id_map ) == expected_output @@ -240,18 +266,18 @@ def test_count_identity_switches( ([0, 0, 10, 10], [5, 15, 15, 25], 0.0), ], ) -def test_calculate_iou(box1, box2, expected_iou, evaluation): +def test_calculate_iou(box1, box2, expected_iou, tracker_evaluate_interface): box1 = np.array(box1) box2 = np.array(box2) - iou = evaluation.calculate_iou(box1, box2) + iou = tracker_evaluate_interface.calculate_iou(box1, box2) # Check if IoU matches expected value assert iou == pytest.approx(expected_iou, abs=1e-2) @pytest.mark.parametrize( - "gt_data, pred_data, prev_frame_id_map, expected_mota", + "gt_data, pred_data, prev_frame_id_map, expected_output", [ # perfect tracking ( @@ -276,7 +302,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation): "id": np.array([11, 12, 13]), }, {1: 11, 2: 12, 3: 13}, - 1.0, + [1.0, 3, 0, 0, 0], ), ( { @@ -300,7 +326,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation): "id": np.array([11, 12, 13]), }, {1: 11, 12: 2, 3: np.nan}, - 1.0, + [1.0, 3, 0, 0, 0], ), # ID switch ( @@ -325,7 +351,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation): "id": np.array([11, 12, 14]), }, {1: 11, 2: 12, 3: 13}, - 2 / 3, + [2 / 3, 3, 0, 0, 1], ), # missed detection ( @@ -346,7 +372,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation): "id": np.array([11, 12]), }, {1: 11, 2: 12, 3: 13}, - 2 / 3, + [2 / 3, 2, 1, 0, 0], ), # false positive ( @@ -372,7 +398,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation): "id": np.array([11, 12, 13, 14]), }, {1: 11, 2: 12, 3: 13}, - 2 / 3, + [2 / 3, 3, 0, 1, 0], ), # low IOU and ID switch ( @@ -397,7 +423,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation): "id": np.array([11, 12, 14]), }, {1: 11, 2: 12, 3: 13}, - 0, + [0, 2, 1, 1, 1], ), # low IOU and ID switch on same box ( @@ -422,7 +448,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation): "id": np.array([11, 14, 13]), }, {1: 11, 2: 12, 3: 13}, - 1 / 3, + [1 / 3, 2, 1, 1, 0], ), # current tracked id = prev tracked id, but prev_gt_id != current gt id ( @@ -447,7 +473,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation): "id": np.array([11, 12, 13]), }, {1: 11, 2: 12, 3: 13}, - 2 / 3, + [2 / 3, 3, 0, 0, 1], ), # ID swapped ( @@ -472,21 +498,34 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation): "id": np.array([11, 13, 12]), }, {1: 11, 2: 12, 3: 13}, - 1 / 3, + [1 / 3, 3, 0, 0, 2], ), ], ) -def test_evaluate_mota( +def test_compute_mota_one_frame( gt_data, pred_data, prev_frame_id_map, - expected_mota, - evaluation, + expected_output, + tracker_evaluate_interface, ): - mota, _ = evaluation.evaluate_mota( + ( + mota, + true_positives, + missed_detections, + false_positives, + num_switches, + total_gt, + _, + ) = tracker_evaluate_interface.compute_mota_one_frame( gt_data, pred_data, 0.1, # iou_threshold prev_frame_id_map, ) - assert mota == pytest.approx(expected_mota) + assert mota == pytest.approx(expected_output[0]) + assert true_positives == expected_output[1] + assert missed_detections == expected_output[2] + assert false_positives == expected_output[3] + assert num_switches == expected_output[4] + assert total_gt == (true_positives + missed_detections) diff --git a/tests/test_unit/test_track_video.py b/tests/test_unit/test_track_video.py index 7d7ffa89..3614d6bf 100644 --- a/tests/test_unit/test_track_video.py +++ b/tests/test_unit/test_track_video.py @@ -10,16 +10,17 @@ @pytest.fixture def mock_args(): - temp_dir = tempfile.mkdtemp() + tmp_dir = tempfile.mkdtemp() return Namespace( config_file="/path/to/config.yaml", video_path="/path/to/video.mp4", - trained_model_path="/path/to/model.ckpt", - output_dir=temp_dir, + trained_model_path="path/to/model.ckpt", + output_dir=tmp_dir, accelerator="gpu", annotations_file=None, save_video=None, + save_frames=None, ) @@ -28,33 +29,57 @@ def mock_args(): new_callable=mock_open, read_data="max_age: 10\nmin_hits: 3\niou_threshold: 0.1", ) -@patch("yaml.safe_load") @patch("cv2.VideoCapture") -@patch("crabs.tracker.track_video.FasterRCNN.load_from_checkpoint") -@patch("crabs.tracker.track_video.Sort") -def test_tracking_setup( - mock_sort, - mock_load_from_checkpoint, - mock_videocapture, +@patch("crabs.tracker.utils.io.get_video_parameters") +@patch("crabs.tracker.track_video.get_config_from_ckpt") +@patch("crabs.tracker.track_video.get_mlflow_parameters_from_ckpt") +# we patch where the function is looked at, see +# https://docs.python.org/3/library/unittest.mock.html#where-to-patch +@patch("yaml.safe_load") +def test_tracking_constructor( mock_yaml_load, + mock_get_mlflow_parameters_from_ckpt, + mock_get_config_from_ckpt, + mock_get_video_parameters, + mock_videocapture, mock_open, mock_args, ): + # mock reading tracking config from file mock_yaml_load.return_value = { "max_age": 10, "min_hits": 3, "iou_threshold": 0.1, } - mock_model = MagicMock() - mock_load_from_checkpoint.return_value = mock_model + # mock getting mlflow parameters from checkpoint + mock_get_mlflow_parameters_from_ckpt.return_value = { + "run_name": "trained_model_run_name", + "cli_args/experiment_name": "trained_model_expt_name", + } + + # mock getting trained model's config + mock_get_config_from_ckpt.return_value = {} + + # mock getting video parameters + mock_get_video_parameters.return_value = { + "total_frames": 614, + "frame_width": 1920, + "frame_height": 1080, + "fps": 60, + } + # mock input video as if opened correctly mock_video_capture = MagicMock() mock_video_capture.isOpened.return_value = True mock_videocapture.return_value = mock_video_capture + # instantiate tracking interface tracker = Tracking(mock_args) + # check output dir is created correctly + # TODO: add asserts for other attributes assigned in constructor assert tracker.args.output_dir == mock_args.output_dir + # delete output dir Path(mock_args.output_dir).rmdir() diff --git a/tests/test_unit/test_tracking_io.py b/tests/test_unit/test_tracking_io.py new file mode 100644 index 00000000..65c0df0f --- /dev/null +++ b/tests/test_unit/test_tracking_io.py @@ -0,0 +1,88 @@ +import csv + +import numpy as np + +from crabs.tracker.utils.io import write_tracked_detections_to_csv + + +def test_write_tracked_detections_to_csv(tmp_path): + # Create test data + csv_file_path = tmp_path / "test_output.csv" + + # Create dictionary with tracked bounding boxes for 2 frames + tracked_bboxes_dict = {} + # frame_idx = 0 + tracked_bboxes_dict[0] = { + "bboxes_tracked": np.array([[10, 20, 30, 40, 1], [50, 60, 70, 80, 2]]), + "bboxes_scores": np.array([0.9, 0.8]), + } + # frame_idx = 1 + tracked_bboxes_dict[1] = { + "bboxes_tracked": np.array([[15, 25, 35, 45, 1]]), + "bboxes_scores": np.array([0.85]), + } + frame_name_regexp = "frame_{frame_idx:08d}.png" + all_frames_size = 8888 + + # Call function + write_tracked_detections_to_csv( + csv_file_path, + tracked_bboxes_dict, + frame_name_regexp, + all_frames_size, + ) + + # Read csv file + with open(csv_file_path, newline="") as csvfile: + csv_reader = csv.reader(csvfile) + rows = list(csv_reader) + + # Expected header + expected_header = [ + "filename", + "file_size", + "file_attributes", + "region_count", + "region_id", + "region_shape_attributes", + "region_attributes", + ] + + # Expected rows + expected_rows = [ + expected_header, + [ + "frame_00000000.png", + "8888", + '{"clip":123}', + "1", + "0", + '{"name":"rect","x":10,"y":20,"width":20,"height":20}', + '{"track":"1", "confidence":"0.9"}', + ], + [ + "frame_00000000.png", + "8888", + '{"clip":123}', + "1", + "0", + '{"name":"rect","x":50,"y":60,"width":20,"height":20}', + '{"track":"2", "confidence":"0.8"}', + ], + [ + "frame_00000001.png", + "8888", + '{"clip":123}', + "1", + "0", + '{"name":"rect","x":15,"y":25,"width":20,"height":20}', + '{"track":"1", "confidence":"0.85"}', + ], + ] + + # Assert the header + assert rows[0] == expected_header + + # Assert the rows + for i, expected_row in enumerate(expected_rows[1:], start=1): + assert rows[i] == expected_row diff --git a/tests/test_unit/test_tracking_utils.py b/tests/test_unit/test_tracking_utils.py index 3550f134..41a948fa 100644 --- a/tests/test_unit/test_tracking_utils.py +++ b/tests/test_unit/test_tracking_utils.py @@ -1,12 +1,10 @@ -import csv -import io - import numpy as np import pytest +import torch from crabs.tracker.utils.tracking import ( extract_bounding_box_info, - write_tracked_bbox_to_csv, + format_bbox_predictions_for_sort, ) @@ -35,32 +33,49 @@ def test_extract_bounding_box_info(): assert result == expected_result -@pytest.fixture -def csv_output(): - return io.StringIO() - - -@pytest.fixture -def csv_writer(csv_output): - return csv.writer(csv_output) - - -def test_write_tracked_bbox_to_csv(csv_writer, csv_output): - bbox = np.array([10, 20, 50, 80, 1]) - frame = np.zeros((100, 100, 3), dtype=np.uint8) - frame_name = "frame_0001.png" - pred_score = 0.900 +@pytest.mark.parametrize( + "score_threshold, expected_output", + [ + ( + 0.5, + torch.tensor( + [ + [10, 20, 30, 40, 0.9], + [50, 60, 70, 80, 0.85], + [15, 25, 35, 45, 0.8], + ] + ), + ), + ( + 0.85, + torch.tensor( + [ + [10, 20, 30, 40, 0.9], + [50, 60, 70, 80, 0.85], + ] + ), + ), + ( + 0.95, + torch.tensor([]), + ), + ], +) +def test_format_bbox_predictions_for_sort(score_threshold, expected_output): + # Define the test data + prediction = [ + { + "boxes": torch.tensor( + [[10, 20, 30, 40], [50, 60, 70, 80], [15, 25, 35, 45]] + ), + "scores": torch.tensor([0.9, 0.85, 0.8]), + } + ] - write_tracked_bbox_to_csv(bbox, frame, frame_name, csv_writer, pred_score) + # Call the function + result = format_bbox_predictions_for_sort(prediction, score_threshold) - expected_row = ( - "frame_0001.png", - 30000, - '"{""clip"":123}"', - 1, - 0, - '"{""name"":""rect"",""x"":10,""y"":20,""width"":40,""height"":60}"', - '"{""track"":""1"", ""confidence"":""0.9""}"', - ) - expected_row_str = ",".join(map(str, expected_row)) - assert csv_output.getvalue().strip() == expected_row_str + # Assert the result + assert np.array_equal( + result, expected_output + ), f"Expected {expected_output}, but got {result}"