
Commit

Merge branch 'develop' into talmo/check-for-tracks-before-training-id-models
talmo authored Dec 16, 2024
2 parents d233bea + 0042cc2 commit e154749
Showing 15 changed files with 191 additions and 16 deletions.
Binary file added docs/_static/bonsai-connection.jpg
Binary file added docs/_static/bonsai-filecapture.jpg
Binary file added docs/_static/bonsai-predictcentroids.jpg
Binary file added docs/_static/bonsai-predictposeidentities.jpg
Binary file added docs/_static/bonsai-predictposes.jpg
Binary file added docs/_static/bonsai-workflow.jpg
75 changes: 75 additions & 0 deletions docs/guides/bonsai.md
@@ -0,0 +1,75 @@
(bonsai)=

# Using Bonsai with SLEAP

Bonsai is a visual language for reactive programming that currently supports running trained SLEAP models.

:::{note}
Currently, Bonsai supports only single instance, top-down, and top-down-id SLEAP models.
:::

### Exporting a trained SLEAP model

Before we can import a trained model into Bonsai, we need to use the {code}`sleap-export` command to convert the model to a format supported by Bonsai. For example, to export a top-down-id model, the command is as follows:

```bash
sleap-export -m centroid/model/folder/path -m top_down_id/model/folder/path -e exported/model/path
```

Please refer to the {ref}`sleap-export` docs for more details on using the command.

This will generate the necessary `.pb` file and other information files required by Bonsai. In this example, these files are saved to the specified `exported/model/path` folder.

The `exported/model/path` folder will have a structure like the following:

```plaintext
exported/model/path
├── centroid_config.json
├── confmap_config.json
├── frozen_graph.pb
└── info.json
```

### Installing Bonsai and necessary packages

1. Install Bonsai. See the [Bonsai installation instructions](https://bonsai-rx.org/docs/articles/installation.html).

2. Download and add the necessary packages for Bonsai to run with SLEAP. See the official [Bonsai SLEAP documentation](https://github.com/bonsai-rx/sleap?tab=readme-ov-file#bonsai---sleap) for more information.

### Using Bonsai SLEAP modules

Once you have Bonsai installed with the required packages, you should be able to open the Bonsai application. The workflow must have a `FileCapture` source module, which can be found via the toolbox search in the workflow editor. In the module's `FileName` field, provide the path to the video that was used to train the SLEAP model.

![Bonsai FileCapture module](../_static/bonsai-filecapture.jpg)

#### Top-down model
The top-down model requires both the `PredictCentroids` and the `PredictPoses` modules.

The `PredictCentroids` module will predict the centroids of detections. There are two fields inside the `PredictCentroids` module: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the centroid model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.

![Bonsai PredictCentroids module](../_static/bonsai-predictcentroids.jpg)

The `PredictPoses` module will predict the poses of the detected instances. Similar to the `PredictCentroids` module, there are two fields inside the `PredictPoses` module: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the centered instance model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.

![Bonsai PredictPoses module](../_static/bonsai-predictposes.jpg)

#### Top-down-ID model
The `PredictPoseIdentities` module will predict the instances with identities. This module has two fields: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the top-down-id model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.

![Bonsai PredictPoseIdentities module](../_static/bonsai-predictposeidentities.jpg)

#### Single instance model
The `PredictSinglePose` module will predict the poses for single instance models. This module also has two fields: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the single instance model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.

### Connecting the modules
Right-click on the `FileCapture` module and select **Create Connection**. Now click on the required SLEAP module to complete the connection.

![Bonsai module connection](../_static/bonsai-connection.jpg)

Once it is done, the workflow in Bonsai will look something like the following:

![Bonsai.SLEAP workflow](../_static/bonsai-workflow.jpg)

Now you can click the green start button to run the workflow, and add more modules to analyze and visualize the results in Bonsai.

For more documentation on various modules and workflows, please refer to the [official Bonsai docs](https://bonsai-rx.org/docs/articles/editor.html).
5 changes: 5 additions & 0 deletions docs/guides/index.md
@@ -30,6 +30,10 @@

{ref}`remote-inference` when you trained models and you want to run inference on a different machine using a **command-line interface**.

## SLEAP with Bonsai

{ref}`bonsai` when you want to run your trained SLEAP models in Bonsai to visualize poses, centroids, and identities for further analysis.

```{toctree}
:hidden: true
:maxdepth: 2
@@ -44,4 +48,5 @@ proofreading
colab
custom-training
remote
bonsai
```
6 changes: 5 additions & 1 deletion sleap/gui/app.py
@@ -873,6 +873,8 @@ def new_instance_menu_action():
"Point Displacement (max)",
"Primary Point Displacement (sum)",
"Primary Point Displacement (max)",
"Tracking Score (mean)",
"Tracking Score (min)",
"Instance Score (sum)",
"Instance Score (min)",
"Point Score (sum)",
@@ -1406,6 +1408,8 @@ def _set_seekbar_header(self, graph_name: str):
"Point Displacement (max)": data_obj.get_point_displacement_series,
"Primary Point Displacement (sum)": data_obj.get_primary_point_displacement_series,
"Primary Point Displacement (max)": data_obj.get_primary_point_displacement_series,
"Tracking Score (mean)": data_obj.get_tracking_score_series,
"Tracking Score (min)": data_obj.get_tracking_score_series,
"Instance Score (sum)": data_obj.get_instance_score_series,
"Instance Score (min)": data_obj.get_instance_score_series,
"Point Score (sum)": data_obj.get_point_score_series,
@@ -1419,7 +1423,7 @@ def _set_seekbar_header(self, graph_name: str):
else:
if graph_name in header_functions:
kwargs = dict(video=self.state["video"])
reduction_name = re.search("\\((sum|max|min)\\)", graph_name)
reduction_name = re.search("\\((sum|max|min|mean)\\)", graph_name)
if reduction_name is not None:
kwargs["reduction"] = reduction_name.group(1)
series = header_functions[graph_name](**kwargs)
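A quick standalone sketch of how the widened pattern in the hunk above picks up the new reduction suffix; the `graph_name` label below is just an illustrative menu entry, not part of the GUI code.

```python
import re

# The seekbar header labels end with a reduction suffix; adding "mean" to the
# pattern lets "Tracking Score (mean)" resolve to a reduction instead of
# falling through with no reduction keyword.
graph_name = "Tracking Score (mean)"  # illustrative label
match = re.search("\\((sum|max|min|mean)\\)", graph_name)
reduction = match.group(1) if match is not None else None
print(reduction)  # -> "mean"
```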
44 changes: 38 additions & 6 deletions sleap/info/summary.py
@@ -21,7 +21,7 @@ class StatisticSeries:
are frame index and value are some numerical value for the frame.
Args:
labels: The :class:`Labels` for which to calculate series.
labels: The `Labels` for which to calculate series.
"""

labels: Labels
@@ -41,7 +41,7 @@ def get_point_score_series(
"""Get series with statistic of point scores in each frame.
Args:
video: The :class:`Video` for which to calculate statistic.
video: The `Video` for which to calculate statistic.
reduction: name of function applied to scores:
* sum
* min
@@ -67,7 +67,7 @@ def get_instance_score_series(self, video, reduction="sum") -> Dict[int, float]:
"""Get series with statistic of instance scores in each frame.
Args:
video: The :class:`Video` for which to calculate statistic.
video: The `Video` for which to calculate statistic.
reduction: name of function applied to scores:
* sum
* min
@@ -93,7 +93,7 @@ def get_point_displacement_series(self, video, reduction="sum") -> Dict[int, float]:
same track) from the closest earlier labeled frame.
Args:
video: The :class:`Video` for which to calculate statistic.
video: The `Video` for which to calculate statistic.
reduction: name of function applied to point scores:
* sum
* mean
@@ -121,7 +121,7 @@ def get_primary_point_displacement_series(
Get sum of displacement for single node of each instance per frame.
Args:
video: The :class:`Video` for which to calculate statistic.
video: The `Video` for which to calculate statistic.
reduction: name of function applied to point scores:
* sum
* mean
@@ -226,7 +226,7 @@ def _calculate_frame_velocity(
Calculate total point displacement between two given frames.
Args:
lf: The :class:`LabeledFrame` for which we want velocity
lf: The `LabeledFrame` for which we want velocity
last_lf: The frame from which to calculate displacement.
reduce_function: Numpy function (e.g., np.sum, np.nanmean)
is applied to *point* displacement, and then those
@@ -246,3 +246,35 @@ def _calculate_frame_velocity(
inst_dist = reduce_function(point_dist)
val += inst_dist if not np.isnan(inst_dist) else 0
return val

def get_tracking_score_series(
self, video: Video, reduction: str = "min"
) -> Dict[int, float]:
"""Get series with statistic of tracking scores in each frame.
Args:
video: The `Video` for which to calculate statistic.
reduction: name of function applied to scores:
* mean
* min
Returns:
The series dictionary (see class docs for details)
"""
reduce_fn = {
"min": np.nanmin,
"mean": np.nanmean,
}[reduction]

series = dict()

for lf in self.labels.find(video):
vals = [
inst.tracking_score for inst in lf if hasattr(inst, "tracking_score")
]
if vals:
val = reduce_fn(vals)
if not np.isnan(val):
series[lf.frame_idx] = val

return series
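A minimal usage sketch for the new series, assuming a predictions file whose instances carry tracking scores; the file path and reduction choice below are placeholders, not part of this commit.

```python
import sleap
from sleap.info.summary import StatisticSeries

# Load predictions (placeholder path) and compute the per-frame tracking
# score series that backs the new seekbar header options.
labels = sleap.load_file("clip.predictions.slp")
series = StatisticSeries(labels)
scores = series.get_tracking_score_series(labels.videos[0], reduction="mean")
# scores: {frame_idx: mean tracking score across instances in that frame}
```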
27 changes: 24 additions & 3 deletions sleap/instance.py
@@ -1049,7 +1049,9 @@ def scores(self) -> np.ndarray:
return self.points_and_scores_array[:, 2]

@classmethod
def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance":
def from_instance(
cls, instance: Instance, score: float, tracking_score: float = 0.0
) -> "PredictedInstance":
"""Create a `PredictedInstance` from an `Instance`.
The fields are copied in a shallow manner with the exception of points. For each
@@ -1059,6 +1061,7 @@ def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance":
Args:
instance: The `Instance` object to shallow copy data from.
score: The score for this instance.
tracking_score: The tracking score for this instance.
Returns:
A `PredictedInstance` for the given `Instance`.
@@ -1070,6 +1073,7 @@ def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance":
)
kw_args["points"] = PredictedPointArray.from_array(instance._points)
kw_args["score"] = score
kw_args["tracking_score"] = tracking_score
return cls(**kw_args)

@classmethod
@@ -1080,6 +1084,7 @@ def from_arrays(
instance_score: float,
skeleton: Skeleton,
track: Optional[Track] = None,
tracking_score: float = 0.0,
) -> "PredictedInstance":
"""Create a predicted instance from data arrays.
@@ -1094,6 +1099,7 @@
skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the
predicted instance.
track: Optional `sleap.Track` to associate with the instance.
tracking_score: Optional float representing the track matching score.
Returns:
A new `PredictedInstance`.
Expand All @@ -1114,6 +1120,7 @@ def from_arrays(
skeleton=skeleton,
score=instance_score,
track=track,
tracking_score=tracking_score,
)

@classmethod
Expand All @@ -1124,6 +1131,7 @@ def from_pointsarray(
instance_score: float,
skeleton: Skeleton,
track: Optional[Track] = None,
tracking_score: float = 0.0,
) -> "PredictedInstance":
"""Create a predicted instance from data arrays.
@@ -1138,12 +1146,18 @@
skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the
predicted instance.
track: Optional `sleap.Track` to associate with the instance.
tracking_score: Optional float representing the track matching score.
Returns:
A new `PredictedInstance`.
"""
return cls.from_arrays(
points, point_confidences, instance_score, skeleton, track=track
points,
point_confidences,
instance_score,
skeleton,
track=track,
tracking_score=tracking_score,
)

@classmethod
Expand All @@ -1154,6 +1168,7 @@ def from_numpy(
instance_score: float,
skeleton: Skeleton,
track: Optional[Track] = None,
tracking_score: float = 0.0,
) -> "PredictedInstance":
"""Create a predicted instance from data arrays.
@@ -1168,12 +1183,18 @@
skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the
predicted instance.
track: Optional `sleap.Track` to associate with the instance.
tracking_score: Optional float representing the track matching score.
Returns:
A new `PredictedInstance`.
"""
return cls.from_arrays(
points, point_confidences, instance_score, skeleton, track=track
points,
point_confidences,
instance_score,
skeleton,
track=track,
tracking_score=tracking_score,
)


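A hedged sketch of the new `tracking_score` keyword threading through `PredictedInstance.from_numpy`; the skeleton, coordinates, track, and score values below are made up for illustration.

```python
import numpy as np
from sleap import Skeleton, Track
from sleap.instance import PredictedInstance

# Two-node skeleton and made-up coordinates purely for illustration.
skeleton = Skeleton()
skeleton.add_node("head")
skeleton.add_node("thorax")

inst = PredictedInstance.from_numpy(
    points=np.array([[10.0, 20.0], [15.0, 25.0]]),
    point_confidences=np.array([0.9, 0.8]),
    instance_score=0.85,
    skeleton=skeleton,
    track=Track(spawned_on=0, name="track_0"),
    tracking_score=0.95,  # new keyword; defaults to 0.0 when omitted
)
```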
21 changes: 16 additions & 5 deletions sleap/nn/inference.py
@@ -3778,9 +3778,10 @@ def _object_builder():
PredictedInstance.from_numpy(
points=pts,
point_confidences=confs,
instance_score=np.nanmean(score),
instance_score=np.nanmean(confs),
skeleton=skeleton,
track=track,
tracking_score=np.nanmean(score),
)
)

@@ -4452,18 +4453,27 @@ def _object_builder():
break

# Loop over frames.
for image, video_ind, frame_ind, points, confidences, scores in zip(
for (
image,
video_ind,
frame_ind,
centroid_vals,
points,
confidences,
scores,
) in zip(
ex["image"],
ex["video_ind"],
ex["frame_ind"],
ex["centroid_vals"],
ex["instance_peaks"],
ex["instance_peak_vals"],
ex["instance_scores"],
):
# Loop over instances.
predicted_instances = []
for i, (pts, confs, score) in enumerate(
zip(points, confidences, scores)
for i, (pts, centroid_val, confs, score) in enumerate(
zip(points, centroid_vals, confidences, scores)
):
if np.isnan(pts).all():
continue
Expand All @@ -4474,9 +4484,10 @@ def _object_builder():
PredictedInstance.from_numpy(
points=pts,
point_confidences=confs,
instance_score=np.nanmean(score),
instance_score=centroid_val,
skeleton=skeleton,
track=track,
tracking_score=score,
)
)

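After the second hunk above, top-down-ID predictions store the centroid confidence as the instance score and pass the per-instance score through as `tracking_score`. A quick way to inspect both on a saved predictions file; the path below is a placeholder and the slice is only to keep output short.

```python
import sleap

labels = sleap.load_file("clip.predictions.slp")  # placeholder path
for lf in labels.labeled_frames[:5]:
    for inst in lf.instances:
        # PredictedInstance exposes the overall score and, after this
        # change, a populated tracking score.
        print(lf.frame_idx, inst.score, inst.tracking_score)
```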
Binary file added tests/data/tracks/clip.predictions.slp
14 changes: 14 additions & 0 deletions tests/fixtures/datasets.py
@@ -97,6 +97,20 @@ def min_tracks_2node_labels():
)


@pytest.fixture
def min_tracks_2node_predictions():
"""
Generated with:
```
sleap-track -m "tests/data/models/min_tracks_2node.UNet.bottomup_multiclass" "tests/data/tracks/clip.mp4"
```
"""
return Labels.load_file(
"tests/data/tracks/clip.predictions.slp",
video_search=["tests/data/tracks/clip.mp4"],
)


@pytest.fixture
def min_tracks_13node_labels():
return Labels.load_file(
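Not part of this commit, but as a usage illustration, a test built on the new fixture might look like the following hypothetical sketch.

```python
def test_min_tracks_2node_predictions(min_tracks_2node_predictions):
    labels = min_tracks_2node_predictions
    assert len(labels) > 0
    # Every predicted instance in the clip should expose a tracking score.
    for lf in labels:
        for inst in lf.instances:
            assert hasattr(inst, "tracking_score")
```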