Skip to content

Commit

Permalink
Merge branch 'develop' into shrivaths/delete-predictions-beyond-frame…
Browse files Browse the repository at this point in the history
…-limit
  • Loading branch information
shrivaths16 authored Jun 26, 2024
2 parents 09ffa04 + ebfc47b commit 8a7839a
Show file tree
Hide file tree
Showing 2 changed files with 363 additions and 35 deletions.
167 changes: 139 additions & 28 deletions sleap/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ class Instance:
from_predicted: Optional["PredictedInstance"] = attr.ib(default=None)
_points: PointArray = attr.ib(default=None)
_nodes: List = attr.ib(default=None)
frame: Union["LabeledFrame", None] = attr.ib(default=None)
frame: Union["LabeledFrame", None] = attr.ib(default=None) # TODO(LM): Make private

# The underlying Point array type that this instances point array should be.
_point_array_type = PointArray
Expand Down Expand Up @@ -1214,6 +1214,9 @@ def unstructure_instance(x: Instance):

converter.register_unstructure_hook(Instance, unstructure_instance)
converter.register_unstructure_hook(PredictedInstance, unstructure_instance)
converter.register_unstructure_hook(
InstancesList, lambda x: [converter.unstructure(inst) for inst in x]
)

## STRUCTURE HOOKS

Expand Down Expand Up @@ -1247,6 +1250,7 @@ def structure_instances_list(x, type):
converter.register_structure_hook(
Union[List[Instance], List[PredictedInstance]], structure_instances_list
)
converter.register_structure_hook(InstancesList, structure_instances_list)

# Structure forward reference for PredictedInstance for the Instance.from_predicted
# attribute.
Expand Down Expand Up @@ -1278,6 +1282,127 @@ def structure_point_array(x, t):
return converter


class InstancesList(list):
"""A list of `Instance`s associated with a `LabeledFrame`.
This class should only be used for the `LabeledFrame.instances` attribute.
"""

def __init__(self, *args, labeled_frame: Optional["LabeledFrame"] = None):
super(InstancesList, self).__init__(*args)

# Set the labeled frame for each instance
self.labeled_frame = labeled_frame

@property
def labeled_frame(self) -> "LabeledFrame":
"""Return the `LabeledFrame` associated with this list of instances."""

return self._labeled_frame

@labeled_frame.setter
def labeled_frame(self, labeled_frame: "LabeledFrame"):
"""Set the `LabeledFrame` associated with this list of instances.
This updates the `frame` attribute on each instance.
Args:
labeled_frame: The `LabeledFrame` to associate with this list of instances.
"""

try:
# If the labeled frame is the same as the one we're setting, then skip
if self._labeled_frame == labeled_frame:
return
except AttributeError:
# Only happens on init and updates each instance.frame (even if None)
pass

# Otherwise, update the frame for each instance
self._labeled_frame = labeled_frame
for instance in self:
instance.frame = labeled_frame

def append(self, instance: Union[Instance, PredictedInstance]):
"""Append an `Instance` or `PredictedInstance` to the list, setting the frame.
Args:
item: The `Instance` or `PredictedInstance` to append to the list.
"""

if not isinstance(instance, (Instance, PredictedInstance)):
raise ValueError(
f"InstancesList can only contain Instance or PredictedInstance objects,"
f" but got {type(instance)}."
)
instance.frame = self.labeled_frame
super().append(instance)

def extend(self, instances: List[Union[PredictedInstance, Instance]]):
"""Extend the list with a list of `Instance`s or `PredictedInstance`s.
Args:
instances: A list of `Instance` or `PredictedInstance` objects to add to the
list.
Returns:
None
"""
for instance in instances:
self.append(instance)

def __delitem__(self, index):
"""Remove instance (by index), and set instance.frame to None."""

instance: Instance = self.__getitem__(index)
super().__delitem__(index)

# Modify the instance to remove reference to the frame
instance.frame = None

def insert(self, index: int, instance: Union[Instance, PredictedInstance]) -> None:
super().insert(index, instance)
instance.frame = self.labeled_frame

def __setitem__(self, index, instance: Union[Instance, PredictedInstance]):
"""Set nth instance in frame to the given instance.
Args:
index: The index of instance to replace with new instance.
value: The new instance to associate with frame.
Returns:
None.
"""
super().__setitem__(index, instance)
instance.frame = self.labeled_frame

def pop(self, index: int) -> Union[Instance, PredictedInstance]:
"""Remove and return instance at index, setting instance.frame to None."""

instance = super().pop(index)
instance.frame = None
return instance

def remove(self, instance: Union[Instance, PredictedInstance]) -> None:
"""Remove instance from list, setting instance.frame to None."""
super().remove(instance)
instance.frame = None

def clear(self) -> None:
"""Remove all instances from list, setting instance.frame to None."""
for instance in self:
instance.frame = None
super().clear()

def copy(self) -> list:
"""Return a shallow copy of the list of instances as a list.
Note: This will not return an `InstancesList` object, but a normal list.
"""
return list(self)


@attr.s(auto_attribs=True, eq=False, repr=False, str=False)
class LabeledFrame:
"""Holds labeled data for a single frame of a video.
Expand All @@ -1290,9 +1415,7 @@ class LabeledFrame:

video: Video = attr.ib()
frame_idx: int = attr.ib(converter=int)
_instances: Union[List[Instance], List[PredictedInstance]] = attr.ib(
default=attr.Factory(list)
)
_instances: InstancesList = attr.ib(default=attr.Factory(InstancesList))

def __attrs_post_init__(self):
"""Called by attrs.
Expand All @@ -1302,8 +1425,7 @@ def __attrs_post_init__(self):
"""

# Make sure all instances have a reference to this frame
for instance in self.instances:
instance.frame = self
self.instances = self._instances

def __len__(self) -> int:
"""Return number of instances associated with frame."""
Expand All @@ -1319,13 +1441,8 @@ def index(self, value: Instance) -> int:

def __delitem__(self, index):
"""Remove instance (by index) from frame."""
value = self.instances.__getitem__(index)

self.instances.__delitem__(index)

# Modify the instance to remove reference to this frame
value.frame = None

def __repr__(self) -> str:
"""Return a readable representation of the LabeledFrame."""
return (
Expand All @@ -1348,9 +1465,6 @@ def insert(self, index: int, value: Instance):
"""
self.instances.insert(index, value)

# Modify the instance to have a reference back to this frame
value.frame = self

def __setitem__(self, index, value: Instance):
"""Set nth instance in frame to the given instance.
Expand All @@ -1363,9 +1477,6 @@ def __setitem__(self, index, value: Instance):
"""
self.instances.__setitem__(index, value)

# Modify the instance to have a reference back to this frame
value.frame = self

def find(
self, track: Optional[Union[Track, int]] = -1, user: bool = False
) -> List[Instance]:
Expand Down Expand Up @@ -1393,7 +1504,7 @@ def instances(self) -> List[Instance]:
return self._instances

@instances.setter
def instances(self, instances: List[Instance]):
def instances(self, instances: Union[InstancesList, List[Instance]]):
"""Set the list of instances associated with this frame.
Updates the `frame` attribute on each instance to the
Expand All @@ -1408,9 +1519,11 @@ def instances(self, instances: List[Instance]):
None
"""

# Make sure to set the frame for each instance to this LabeledFrame
for instance in instances:
instance.frame = self
# Make sure to set the LabeledFrame for each instance to this frame
if isinstance(instances, InstancesList):
instances.labeled_frame = self
else:
instances = InstancesList(instances, labeled_frame=self)

self._instances = instances

Expand Down Expand Up @@ -1685,22 +1798,20 @@ def complex_frame_merge(
* list of conflicting instances from base
* list of conflicting instances from new
"""
merged_instances = []
redundant_instances = []
extra_base_instances = copy(base_frame.instances)
extra_new_instances = []
merged_instances: List[Instance] = [] # Only used for informing user
redundant_instances: List[Instance] = []
extra_base_instances: List[Instance] = list(base_frame.instances)
extra_new_instances: List[Instance] = []

for new_inst in new_frame:
redundant = False
for base_inst in base_frame.instances:
if new_inst.matches(base_inst):
base_inst.frame = None
extra_base_instances.remove(base_inst)
redundant_instances.append(base_inst)
redundant = True
continue
if not redundant:
new_inst.frame = None
extra_new_instances.append(new_inst)

conflict = False
Expand Down Expand Up @@ -1732,7 +1843,7 @@ def complex_frame_merge(
else:
# No conflict, so include all instances in base
base_frame.instances.extend(extra_new_instances)
merged_instances = copy(extra_new_instances)
merged_instances: List[Instance] = copy(extra_new_instances)
extra_base_instances = []
extra_new_instances = []

Expand Down
Loading

0 comments on commit 8a7839a

Please sign in to comment.