diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py index 319a2d007..b52a8601e 100644 --- a/sleap/nn/data/providers.py +++ b/sleap/nn/data/providers.py @@ -210,16 +210,16 @@ def py_fetch_lf(ind): for inst in insts: - if len(inst) > 0: + # Filter OOB + pts = inst.numpy() + pts[pts < 0] = np.NaN - # Filter OOB - pts = inst.numpy() - pts[pts < 0] = np.NaN + pts[:, 0][pts[:, 0] > width - 1] = np.NaN + pts[:, 1][pts[:, 1] > height - 1] = np.NaN - pts[:, 0][pts[:, 0] > width - 1] = np.NaN - pts[:, 1][pts[:, 1] > height - 1] = np.NaN + instance = Instance.from_numpy(pts, inst.skeleton, inst.track) - instance = Instance.from_numpy(pts, inst.skeleton, inst.track) + if len(instance) > 0: if self.with_track_only: if instance.track is not None: