From 66d96ced50153e98b952176fbdb73043e105e397 Mon Sep 17 00:00:00 2001 From: Elizabeth <106755962+eberrigan@users.noreply.github.com> Date: Thu, 19 Dec 2024 10:26:54 -0800 Subject: [PATCH] Allow max tracking args for Kalman filter (#1986) * and note where `target_instance_count` is initialized * `target_instance_count` is not available in the GUI but `max_tracks` is * add note where `target_instance_count` is initialized * add note since neither `target_instance_count` nor `pre_cull_to_target` are options in the GUI * accept either max_tracks or target_instance_count for compatibility with both CLI and GUI * TypeError: track() got an unexpected keyword argument 'img_hw' since `init_tracker` has `img_hw` * useful print statements * black * np.bool is deprecated * debug * add params for testing kalman filter * remove params because this function isn't used * debugging * test kalman filter tracking * add documentation * kalman filter needs node indices, simple tracking and similarity anything besides normalized * add tests for every combination related to kalman args * add example to documentation * delete debug scripts * delete print statements * black * add test for connect single breaks --- docs/guides/cli.md | 8 +- sleap/nn/inference.py | 12 +- sleap/nn/tracker/kalman.py | 2 +- sleap/nn/tracking.py | 54 +++++-- tests/nn/test_tracking_integration.py | 194 +++++++++++++++++++++++++- 5 files changed, 256 insertions(+), 14 deletions(-) diff --git a/docs/guides/cli.md b/docs/guides/cli.md index 134461c60..339c5405b 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -230,7 +230,7 @@ optional arguments: --tracking.kf_node_indices TRACKING.KF_NODE_INDICES For Kalman filter: Indices of nodes to track. (default: ) --tracking.kf_init_frame_count TRACKING.KF_INIT_FRAME_COUNT - For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used. (default: 0) + For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used. (default: 0) Kalman filters require TRACKING.KF_NODE_INDICES, TRACKING.MAX_TRACKING and TRACKING.MAX_TRACKS or TRACKING.TARGET_INSTANCE_COUNT, TRACKING.TRACKER to be simple or simplemaxtracks, and TRACKING.SIMILARITY to not be normalized_instance. ``` #### Examples: @@ -285,6 +285,12 @@ sleap-track --gpu 1 ... sleap-track -m "models/my_model" --frames 1000-2000 "input_video.mp4" ``` +**9. Use Kalman tracker (not recommended since flow is preferred):** + +```none +sleap-track -m "models/my_model" --tracking.similarity instance --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 --tracking.kf_init_frame_count 10 --tracking.kf_node_indices 0,1 -o "output_predictions.slp" "input_video.mp4" +``` + ## Dataset files (sleap-convert)= diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 3f01a1c3c..c27382e52 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -1129,9 +1129,11 @@ def export_model( info["predicted_tensors"] = tensors full_model = tf.function( - lambda x: sleap.nn.data.utils.unrag_example(model(x), numpy=False) - if unrag_outputs - else model(x) + lambda x: ( + sleap.nn.data.utils.unrag_example(model(x), numpy=False) + if unrag_outputs + else model(x) + ) ) full_model = full_model.get_concrete_function( @@ -5717,3 +5719,7 @@ def main(args: Optional[list] = None): "To retrack on predictions, must specify tracker. " "Use \"sleap-track --tracking.tracker ...' to specify tracker to use." ) + + +if __name__ == "__main__": + main() diff --git a/sleap/nn/tracker/kalman.py b/sleap/nn/tracker/kalman.py index 2b0343927..774a4634e 100644 --- a/sleap/nn/tracker/kalman.py +++ b/sleap/nn/tracker/kalman.py @@ -608,7 +608,7 @@ def remove_second_bests_from_cost_matrix( cost matrix with invalid matches set to specified invalid value. """ - valid_match_mask = np.full_like(cost_matrix, True, dtype=np.bool) + valid_match_mask = np.full_like(cost_matrix, True, dtype=bool) rows, columns = cost_matrix.shape diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 558aa9309..231b004f5 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -574,7 +574,7 @@ class Tracker(BaseTracker): max_tracking: bool = False # To enable maximum tracking. cleaner: Optional[Callable] = None # TODO: deprecate - target_instance_count: int = 0 + target_instance_count: int = 0 # TODO: deprecate pre_cull_function: Optional[Callable] = None post_connect_single_breaks: bool = False robust_best_instance: float = 1.0 @@ -824,8 +824,15 @@ def final_pass(self, frames: List[LabeledFrame]): # "tracking." # ) self.cleaner.run(frames) - elif self.target_instance_count and self.post_connect_single_breaks: + elif ( + self.target_instance_count or self.max_tracks + ) and self.post_connect_single_breaks: + if not self.target_instance_count: + # If target_instance_count is not set, use max_tracks instead + # target_instance_count not available in the GUI + self.target_instance_count = self.max_tracks connect_single_track_breaks(frames, self.target_instance_count) + print("Connecting single track breaks.") def get_name(self): tracker_name = self.candidate_maker.__class__.__name__ @@ -850,7 +857,7 @@ def make_tracker_by_name( of_max_levels: int = 3, save_shifted_instances: bool = False, # Pre-tracking options to cull instances - target_instance_count: int = 0, + target_instance_count: int = 0, # TODO: deprecate target_instance_count pre_cull_to_target: bool = False, pre_cull_iou_threshold: Optional[float] = None, # Post-tracking options to connect broken tracks @@ -921,6 +928,7 @@ def make_tracker_by_name( pre_cull_function = None if target_instance_count and pre_cull_to_target: + # Right now this is not accessible from the GUI def pre_cull_function(inst_list): cull_frame_instances( @@ -940,11 +948,34 @@ def pre_cull_function(inst_list): pre_cull_function=pre_cull_function, max_tracking=max_tracking, max_tracks=max_tracks, - target_instance_count=target_instance_count, + target_instance_count=target_instance_count, # TODO: deprecate target_instance_count post_connect_single_breaks=post_connect_single_breaks, ) - if target_instance_count and kf_init_frame_count: + # Kalman filter requires deprecated target_instance_count + if (max_tracks or target_instance_count) and kf_init_frame_count: + if not kf_node_indices: + raise ValueError( + "Kalman filter requires node indices for instance tracking." + ) + + if tracker == "flow" or tracker == "flowmaxtracks": + # Tracking with Kalman filter requires initial tracker object to be simple + raise ValueError( + "Kalman filter requires simple tracker for initial tracking." + ) + + if similarity == "normalized_instance": + # Kalman filter doesnot support normalized_instance_similarity + raise ValueError( + "Kalman filter does not support normalized_instance_similarity." + ) + + if not target_instance_count: + # If target_instance_count is not set, use max_tracks instead + # target_instance_count not available in the GUI + target_instance_count = max_tracks + kalman_obj = KalmanTracker.make_tracker( init_tracker=tracker_obj, init_frame_count=kf_init_frame_count, @@ -954,8 +985,10 @@ def pre_cull_function(inst_list): ) return kalman_obj - elif kf_init_frame_count and not target_instance_count: - raise ValueError("Kalman filter requires target instance count.") + elif kf_init_frame_count and not (max_tracks or target_instance_count): + raise ValueError( + "Kalman filter requires max tracks or target instance count." + ) else: return tracker_obj @@ -1369,6 +1402,10 @@ def cull_function(inst_list): if init_tracker.pre_cull_function is None: init_tracker.pre_cull_function = cull_function + print( + f"Using {init_tracker.get_name()} to track {init_frame_count} frames for Kalman filters." + ) + return cls( init_tracker=init_tracker, kalman_tracker=kalman_tracker, @@ -1386,6 +1423,7 @@ def track( untracked_instances: List[InstanceType], img: Optional[np.ndarray] = None, t: int = None, + **kwargs, ) -> List[InstanceType]: """Tracks individual frame, using Kalman filters if possible.""" @@ -1420,7 +1458,7 @@ def track( # Initialize the Kalman filters self.kalman_tracker.init_filters(self.init_set.instances) - # print(f"Kalman filters initialized (frame {t})") + print(f"Kalman filters initialized (frame {t})") # Clear the data used to init filters, so that if the filters # stop tracking and we need to re-init, we won't re-use the diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 625302fd0..4a601ac00 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -2,13 +2,205 @@ import operator import os import time - +import pytest import sleap from sleap.nn.inference import main as inference_cli import sleap.nn.tracker.components from sleap.io.dataset import Labels, LabeledFrame +similarity_args = [ + "instance", + "normalized_instance", + "object_keypoint", + "centroid", + "iou", +] +match_args = ["hungarian", "greedy"] + + +@pytest.mark.parametrize( + "tracker_name", ["simple", "simplemaxtracks", "flow", "flowmaxtracks"] +) +@pytest.mark.parametrize("similarity", similarity_args) +@pytest.mark.parametrize("match", match_args) +def test_kalman_tracker( + tmpdir, centered_pair_predictions_slp_path, tracker_name, similarity, match +): + + if tracker_name == "flow" or tracker_name == "flowmaxtracks": + # Expecting ValueError for "flow" or "flowmaxtracks" due to Kalman filter requiring a simple tracker + with pytest.raises( + ValueError, + match="Kalman filter requires simple tracker for initial tracking.", + ): + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + f"-o {tmpdir}/{tracker_name}.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + else: + # For simple or simplemaxtracks, continue with other tests + # Check for ValueError when similarity is "normalized_instance" + if similarity == "normalized_instance": + with pytest.raises( + ValueError, + match="Kalman filter does not support normalized_instance_similarity.", + ): + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + f"-o {tmpdir}/{tracker_name}.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + return + + # Check for ValueError when kf_node_indices is None which is the default + with pytest.raises( + ValueError, + match="Kalman filter requires node indices for instance tracking.", + ): + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + f"-o {tmpdir}/{tracker_name}.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + # Test for missing max_tracks and target_instance_count with kf_init_frame_count + with pytest.raises( + ValueError, + match="Kalman filter requires max tracks or target instance count.", + ): + cli = ( + f"--tracking.tracker {tracker_name} " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + f"-o {tmpdir}/{tracker_name}.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + # Test with target_instance_count and without max_tracks + cli = ( + f"--tracking.tracker {tracker_name} " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + f"-o {tmpdir}/{tracker_name}_target_instance_count.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + labels = sleap.load_file(f"{tmpdir}/{tracker_name}_target_instance_count.slp") + assert len(labels.tracks) == 2 + + # Test with target_instance_count and with max_tracks + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + labels = sleap.load_file( + f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count.slp" + ) + assert len(labels.tracks) == 2 + + # Test with "--tracking.pre_cull_iou_threshold", "0.8" + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + "--tracking.pre_cull_iou_threshold 0.8 " + f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_iou.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + labels = sleap.load_file( + f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_iou.slp" + ) + assert len(labels.tracks) == 2 + + # Test with "--tracking.pre_cull_to_target", "1" + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + "--tracking.pre_cull_to_target 1 " + f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_to_target.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + labels = sleap.load_file( + f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_to_target.slp" + ) + assert len(labels.tracks) == 2 + + # Test with 'tracking.post_connect_single_breaks': 0 + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + "--tracking.post_connect_single_breaks 0 " + f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_single_breaks.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + labels = sleap.load_file( + f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_single_breaks.slp" + ) + assert len(labels.tracks) == 2 + + def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path): cli = ( "--tracking.tracker simple "