Skip to content

Commit

Permalink
Allow max tracking args for Kalman filter (#1986)
Browse files Browse the repository at this point in the history
* and note where `target_instance_count` is initialized

* `target_instance_count` is not available in the GUI but `max_tracks` is

* add note where `target_instance_count` is initialized

* add note since neither `target_instance_count` nor `pre_cull_to_target` are options in the GUI

* accept either max_tracks or target_instance_count for compatibility with both CLI and GUI

* TypeError: track() got an unexpected keyword argument 'img_hw' since `init_tracker` has `img_hw`

* useful print statements

* black

* np.bool is deprecated

* debug

* add params for testing kalman filter

* remove params because this function isn't used

* debugging

* test kalman filter tracking

* add documentation

* kalman filter needs node indices, simple tracking and similarity anything besides normalized

* add tests for every combination related to kalman args

* add example to documentation

* delete debug scripts

* delete print statements

* black

* add test for connect single breaks
  • Loading branch information
eberrigan authored Dec 19, 2024
1 parent 369b772 commit 66d96ce
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 14 deletions.
8 changes: 7 additions & 1 deletion docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ optional arguments:
--tracking.kf_node_indices TRACKING.KF_NODE_INDICES
For Kalman filter: Indices of nodes to track. (default: )
--tracking.kf_init_frame_count TRACKING.KF_INIT_FRAME_COUNT
For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used. (default: 0)
For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used. (default: 0) Kalman filters require TRACKING.KF_NODE_INDICES, TRACKING.MAX_TRACKING and TRACKING.MAX_TRACKS or TRACKING.TARGET_INSTANCE_COUNT, TRACKING.TRACKER to be simple or simplemaxtracks, and TRACKING.SIMILARITY to not be normalized_instance.
```

#### Examples:
Expand Down Expand Up @@ -285,6 +285,12 @@ sleap-track --gpu 1 ...
sleap-track -m "models/my_model" --frames 1000-2000 "input_video.mp4"
```

**9. Use Kalman tracker (not recommended since flow is preferred):**

```none
sleap-track -m "models/my_model" --tracking.similarity instance --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 --tracking.kf_init_frame_count 10 --tracking.kf_node_indices 0,1 -o "output_predictions.slp" "input_video.mp4"
```

## Dataset files

(sleap-convert)=
Expand Down
12 changes: 9 additions & 3 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,9 +1129,11 @@ def export_model(
info["predicted_tensors"] = tensors

full_model = tf.function(
lambda x: sleap.nn.data.utils.unrag_example(model(x), numpy=False)
if unrag_outputs
else model(x)
lambda x: (
sleap.nn.data.utils.unrag_example(model(x), numpy=False)
if unrag_outputs
else model(x)
)
)

full_model = full_model.get_concrete_function(
Expand Down Expand Up @@ -5717,3 +5719,7 @@ def main(args: Optional[list] = None):
"To retrack on predictions, must specify tracker. "
"Use \"sleap-track --tracking.tracker ...' to specify tracker to use."
)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion sleap/nn/tracker/kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def remove_second_bests_from_cost_matrix(
cost matrix with invalid matches set to specified invalid value.
"""

valid_match_mask = np.full_like(cost_matrix, True, dtype=np.bool)
valid_match_mask = np.full_like(cost_matrix, True, dtype=bool)

rows, columns = cost_matrix.shape

Expand Down
54 changes: 46 additions & 8 deletions sleap/nn/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ class Tracker(BaseTracker):
max_tracking: bool = False # To enable maximum tracking.

cleaner: Optional[Callable] = None # TODO: deprecate
target_instance_count: int = 0
target_instance_count: int = 0 # TODO: deprecate
pre_cull_function: Optional[Callable] = None
post_connect_single_breaks: bool = False
robust_best_instance: float = 1.0
Expand Down Expand Up @@ -824,8 +824,15 @@ def final_pass(self, frames: List[LabeledFrame]):
# "tracking."
# )
self.cleaner.run(frames)
elif self.target_instance_count and self.post_connect_single_breaks:
elif (
self.target_instance_count or self.max_tracks
) and self.post_connect_single_breaks:
if not self.target_instance_count:
# If target_instance_count is not set, use max_tracks instead
# target_instance_count not available in the GUI
self.target_instance_count = self.max_tracks
connect_single_track_breaks(frames, self.target_instance_count)
print("Connecting single track breaks.")

def get_name(self):
tracker_name = self.candidate_maker.__class__.__name__
Expand All @@ -850,7 +857,7 @@ def make_tracker_by_name(
of_max_levels: int = 3,
save_shifted_instances: bool = False,
# Pre-tracking options to cull instances
target_instance_count: int = 0,
target_instance_count: int = 0, # TODO: deprecate target_instance_count
pre_cull_to_target: bool = False,
pre_cull_iou_threshold: Optional[float] = None,
# Post-tracking options to connect broken tracks
Expand Down Expand Up @@ -921,6 +928,7 @@ def make_tracker_by_name(

pre_cull_function = None
if target_instance_count and pre_cull_to_target:
# Right now this is not accessible from the GUI

def pre_cull_function(inst_list):
cull_frame_instances(
Expand All @@ -940,11 +948,34 @@ def pre_cull_function(inst_list):
pre_cull_function=pre_cull_function,
max_tracking=max_tracking,
max_tracks=max_tracks,
target_instance_count=target_instance_count,
target_instance_count=target_instance_count, # TODO: deprecate target_instance_count
post_connect_single_breaks=post_connect_single_breaks,
)

if target_instance_count and kf_init_frame_count:
# Kalman filter requires deprecated target_instance_count
if (max_tracks or target_instance_count) and kf_init_frame_count:
if not kf_node_indices:
raise ValueError(
"Kalman filter requires node indices for instance tracking."
)

if tracker == "flow" or tracker == "flowmaxtracks":
# Tracking with Kalman filter requires initial tracker object to be simple
raise ValueError(
"Kalman filter requires simple tracker for initial tracking."
)

if similarity == "normalized_instance":
# Kalman filter doesnot support normalized_instance_similarity
raise ValueError(
"Kalman filter does not support normalized_instance_similarity."
)

if not target_instance_count:
# If target_instance_count is not set, use max_tracks instead
# target_instance_count not available in the GUI
target_instance_count = max_tracks

kalman_obj = KalmanTracker.make_tracker(
init_tracker=tracker_obj,
init_frame_count=kf_init_frame_count,
Expand All @@ -954,8 +985,10 @@ def pre_cull_function(inst_list):
)

return kalman_obj
elif kf_init_frame_count and not target_instance_count:
raise ValueError("Kalman filter requires target instance count.")
elif kf_init_frame_count and not (max_tracks or target_instance_count):
raise ValueError(
"Kalman filter requires max tracks or target instance count."
)
else:
return tracker_obj

Expand Down Expand Up @@ -1369,6 +1402,10 @@ def cull_function(inst_list):
if init_tracker.pre_cull_function is None:
init_tracker.pre_cull_function = cull_function

print(
f"Using {init_tracker.get_name()} to track {init_frame_count} frames for Kalman filters."
)

return cls(
init_tracker=init_tracker,
kalman_tracker=kalman_tracker,
Expand All @@ -1386,6 +1423,7 @@ def track(
untracked_instances: List[InstanceType],
img: Optional[np.ndarray] = None,
t: int = None,
**kwargs,
) -> List[InstanceType]:
"""Tracks individual frame, using Kalman filters if possible."""

Expand Down Expand Up @@ -1420,7 +1458,7 @@ def track(
# Initialize the Kalman filters
self.kalman_tracker.init_filters(self.init_set.instances)

# print(f"Kalman filters initialized (frame {t})")
print(f"Kalman filters initialized (frame {t})")

# Clear the data used to init filters, so that if the filters
# stop tracking and we need to re-init, we won't re-use the
Expand Down
194 changes: 193 additions & 1 deletion tests/nn/test_tracking_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,205 @@
import operator
import os
import time

import pytest
import sleap
from sleap.nn.inference import main as inference_cli
import sleap.nn.tracker.components
from sleap.io.dataset import Labels, LabeledFrame


similarity_args = [
"instance",
"normalized_instance",
"object_keypoint",
"centroid",
"iou",
]
match_args = ["hungarian", "greedy"]


@pytest.mark.parametrize(
"tracker_name", ["simple", "simplemaxtracks", "flow", "flowmaxtracks"]
)
@pytest.mark.parametrize("similarity", similarity_args)
@pytest.mark.parametrize("match", match_args)
def test_kalman_tracker(
tmpdir, centered_pair_predictions_slp_path, tracker_name, similarity, match
):

if tracker_name == "flow" or tracker_name == "flowmaxtracks":
# Expecting ValueError for "flow" or "flowmaxtracks" due to Kalman filter requiring a simple tracker
with pytest.raises(
ValueError,
match="Kalman filter requires simple tracker for initial tracking.",
):
cli = (
f"--tracking.tracker {tracker_name} "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
f"-o {tmpdir}/{tracker_name}.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))
else:
# For simple or simplemaxtracks, continue with other tests
# Check for ValueError when similarity is "normalized_instance"
if similarity == "normalized_instance":
with pytest.raises(
ValueError,
match="Kalman filter does not support normalized_instance_similarity.",
):
cli = (
f"--tracking.tracker {tracker_name} "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
f"-o {tmpdir}/{tracker_name}.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))
return

# Check for ValueError when kf_node_indices is None which is the default
with pytest.raises(
ValueError,
match="Kalman filter requires node indices for instance tracking.",
):
cli = (
f"--tracking.tracker {tracker_name} "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
f"-o {tmpdir}/{tracker_name}.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))

# Test for missing max_tracks and target_instance_count with kf_init_frame_count
with pytest.raises(
ValueError,
match="Kalman filter requires max tracks or target instance count.",
):
cli = (
f"--tracking.tracker {tracker_name} "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
f"-o {tmpdir}/{tracker_name}.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))

# Test with target_instance_count and without max_tracks
cli = (
f"--tracking.tracker {tracker_name} "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
"--tracking.target_instance_count 2 "
f"-o {tmpdir}/{tracker_name}_target_instance_count.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))

labels = sleap.load_file(f"{tmpdir}/{tracker_name}_target_instance_count.slp")
assert len(labels.tracks) == 2

# Test with target_instance_count and with max_tracks
cli = (
f"--tracking.tracker {tracker_name} "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
"--tracking.target_instance_count 2 "
f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))

labels = sleap.load_file(
f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count.slp"
)
assert len(labels.tracks) == 2

# Test with "--tracking.pre_cull_iou_threshold", "0.8"
cli = (
f"--tracking.tracker {tracker_name} "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
"--tracking.target_instance_count 2 "
"--tracking.pre_cull_iou_threshold 0.8 "
f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_iou.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))

labels = sleap.load_file(
f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_iou.slp"
)
assert len(labels.tracks) == 2

# Test with "--tracking.pre_cull_to_target", "1"
cli = (
f"--tracking.tracker {tracker_name} "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
"--tracking.target_instance_count 2 "
"--tracking.pre_cull_to_target 1 "
f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_to_target.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))
labels = sleap.load_file(
f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_to_target.slp"
)
assert len(labels.tracks) == 2

# Test with 'tracking.post_connect_single_breaks': 0
cli = (
f"--tracking.tracker {tracker_name} "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
"--tracking.target_instance_count 2 "
"--tracking.post_connect_single_breaks 0 "
f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_single_breaks.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))
labels = sleap.load_file(
f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_single_breaks.slp"
)
assert len(labels.tracks) == 2


def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path):
cli = (
"--tracking.tracker simple "
Expand Down

0 comments on commit 66d96ce

Please sign in to comment.