Skip to content

Commit

Permalink
Add function to filter oob
Browse files Browse the repository at this point in the history
  • Loading branch information
gitttt-1234 committed Dec 20, 2024
1 parent 889fc5a commit 768ca90
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 18 deletions.
21 changes: 10 additions & 11 deletions sleap/nn/data/instance_cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional, List, Text
import sleap
from sleap.nn.config import InstanceCroppingConfig
from sleap.nn.data.utils import filter_oob_points


def find_instance_crop_size(
Expand Down Expand Up @@ -43,22 +44,20 @@ def find_instance_crop_size(
min_crop_size_no_pad = min_crop_size - padding
max_length = 0.0
for lf in labels:
for inst in lf.user_instances:
pts = inst.points_array
for inst in lf:
if isinstance(inst, sleap.PredictedInstance):
continue

pts[pts < 0] = np.NaN
height, width = lf.image.shape[:2]
pts[:, 0][pts[:, 0] > width - 1] = np.NaN
pts[:, 1][pts[:, 1] > height - 1] = np.NaN
pts = filter_oob_points(inst.numpy(), lf.image.shape[:2])

pts *= input_scaling
max_length = np.maximum(
max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])
max_length: float = np.nanmax(
[max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])]
)
max_length = np.maximum(
max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])
max_length: float = np.nanmax(
[max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])]
)
max_length = np.maximum(max_length, min_crop_size_no_pad)
max_length: float = np.nanmax([max_length, min_crop_size_no_pad])

max_length += float(padding)
crop_size = np.math.ceil(max_length / float(maximum_stride)) * maximum_stride
Expand Down
9 changes: 2 additions & 7 deletions sleap/nn/data/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Text, Optional, List, Sequence, Union, Tuple
import sleap
from sleap.instance import Instance
from sleap.nn.data.utils import filter_oob_points


@attr.s(auto_attribs=True)
Expand Down Expand Up @@ -199,8 +200,6 @@ def py_fetch_lf(ind):
raw_image = lf.image
raw_image_size = np.array(raw_image.shape).astype("int32")

height, width = raw_image_size[:2]

if self.user_instances_only:
insts = lf.user_instances
else:
Expand All @@ -211,11 +210,7 @@ def py_fetch_lf(ind):
for inst in insts:

# Filter OOB
pts = inst.numpy()
pts[pts < 0] = np.NaN

pts[:, 0][pts[:, 0] > width - 1] = np.NaN
pts[:, 1][pts[:, 1] > height - 1] = np.NaN
pts = filter_oob_points(inst.numpy(), raw_image_size[:2])

instance = Instance.from_numpy(pts, inst.skeleton, inst.track)

Expand Down
10 changes: 10 additions & 0 deletions sleap/nn/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
from typing import Any, List, Tuple, Dict, Text, Optional


def filter_oob_points(pts: np.ndarray, img_hw: tuple) -> np.ndarray:
"""Convert negative/ out-of-boundary pts to NaNs."""
pts[pts < 0] = np.NaN
height, width = img_hw
pts[:, 0][pts[:, 0] > width - 1] = np.NaN
pts[:, 1][pts[:, 1] > height - 1] = np.NaN

return pts


def ensure_list(x: Any) -> List[Any]:
"""Convert the input into a list if it is not already."""
if not isinstance(x, list):
Expand Down

0 comments on commit 768ca90

Please sign in to comment.