Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Base extractor #367

Open
wants to merge 5 commits into
base: segmentation
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions src/roiextractors/baseextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from abc import ABC, abstractmethod
from typing import Union, Tuple
from copy import deepcopy
import numpy as np
from .extraction_tools import ArrayType, FloatType


class BaseExtractor(ABC):

def __init__(self):
self._times = None

@abstractmethod
def get_image_size(self) -> Tuple[int, int]:
"""Get the size of each image in the recording (num_rows, num_columns).

Returns
-------
image_size: tuple
Size of each image (num_rows, num_columns).
"""
pass

@abstractmethod
def get_num_frames(self) -> int:
"""Get the number of frames in the recording.

Returns
-------
num_frames: int
Number of frames in the recording.
"""
pass

@abstractmethod
def get_sampling_frequency(self) -> float:
"""Get the sampling frequency of the recording in Hz.

Returns
-------
sampling_frequency: float
Sampling frequency of the recording in Hz.
"""
pass

def frame_to_time(self, frames: ArrayType) -> Union[FloatType, np.ndarray]:
"""Convert user-inputted frame indices to times with units of seconds.

Parameters
----------
frames: array-like
The frame or frames to be converted to times.

Returns
-------
times: float or array-like
The corresponding times in seconds.
"""
# Default implementation
frames = np.asarray(frames)
if self._times is None:
return frames / self.get_sampling_frequency()
else:
return self._times[frames]

def time_to_frame(self, times: ArrayType) -> Union[FloatType, np.ndarray]:
"""Convert a user-inputted times (in seconds) to a frame indices.

Parameters
----------
times: array-like
The times (in seconds) to be converted to frame indices.

Returns
-------
frames: float or array-like
The corresponding frame indices.
"""
# Default implementation
times = np.asarray(times)
if self._times is None:
return np.round(times * self.get_sampling_frequency()).astype("int64")
else:
return np.searchsorted(self._times, times).astype("int64")

def set_times(self, times: ArrayType) -> None:
"""Set the recording times (in seconds) for each frame.

Parameters
----------
times: array-like
The times in seconds for each frame
"""
assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!"
self._times = np.array(times).astype("float64")

def has_time_vector(self) -> bool:
"""Detect if the ImagingExtractor has a time vector set or not.

Returns
-------
has_times: bool
True if the ImagingExtractor has a time vector set, otherwise False.
"""
return self._times is not None

def copy_times(self, extractor) -> None:
"""Copy times from another extractor.

Parameters
----------
extractor
The extractor from which the times will be copied.
"""
if extractor._times is not None:
self.set_times(deepcopy(extractor._times))
116 changes: 3 additions & 113 deletions src/roiextractors/imagingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,57 +8,19 @@
Class to get a lazy frame slice.
"""

from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import Union, Optional, Tuple, get_args
from copy import deepcopy

import numpy as np

from .baseextractor import BaseExtractor
from .extraction_tools import ArrayType, PathType, DtypeType, FloatType, IntType


class ImagingExtractor(ABC):
class ImagingExtractor(BaseExtractor):
"""Abstract class that contains all the meta-data and input data from the imaging data."""

def __init__(self, *args, **kwargs) -> None:
"""Initialize the ImagingExtractor object."""
self._args = args
self._kwargs = kwargs
self._times = None

@abstractmethod
def get_image_size(self) -> Tuple[int, int]:
"""Get the size of the video (num_rows, num_columns).

Returns
-------
image_size: tuple
Size of the video (num_rows, num_columns).
"""
pass

@abstractmethod
def get_num_frames(self) -> int:
"""Get the number of frames in the video.

Returns
-------
num_frames: int
Number of frames in the video.
"""
pass

@abstractmethod
def get_sampling_frequency(self) -> float:
"""Get the sampling frequency in Hz.

Returns
-------
sampling_frequency: float
Sampling frequency in Hz.
"""
pass

@abstractmethod
def get_dtype(self) -> DtypeType:
"""Get the data type of the video.
Expand Down Expand Up @@ -175,78 +137,6 @@ def _validate_get_frames_arguments(self, frame_idxs: ArrayType) -> Tuple[int, in

return start_frame, end_frame

def frame_to_time(self, frames: ArrayType) -> Union[FloatType, np.ndarray]:
"""Convert user-inputted frame indices to times with units of seconds.

Parameters
----------
frames: array-like
The frame or frames to be converted to times.

Returns
-------
times: float or array-like
The corresponding times in seconds.
"""
# Default implementation
frames = np.asarray(frames)
if self._times is None:
return frames / self.get_sampling_frequency()
else:
return self._times[frames]

def time_to_frame(self, times: ArrayType) -> Union[FloatType, np.ndarray]:
"""Convert a user-inputted times (in seconds) to a frame indices.

Parameters
----------
times: array-like
The times (in seconds) to be converted to frame indices.

Returns
-------
frames: float or array-like
The corresponding frame indices.
"""
# Default implementation
times = np.asarray(times)
if self._times is None:
return np.round(times * self.get_sampling_frequency()).astype("int64")
else:
return np.searchsorted(self._times, times).astype("int64")

def set_times(self, times: ArrayType) -> None:
"""Set the recording times (in seconds) for each frame.

Parameters
----------
times: array-like
The times in seconds for each frame
"""
assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!"
self._times = np.array(times).astype("float64")

def has_time_vector(self) -> bool:
"""Detect if the ImagingExtractor has a time vector set or not.

Returns
-------
has_times: bool
True if the ImagingExtractor has a time vector set, otherwise False.
"""
return self._times is not None

def copy_times(self, extractor) -> None:
"""Copy times from another extractor.

Parameters
----------
extractor
The extractor from which the times will be copied.
"""
if extractor._times is not None:
self.set_times(deepcopy(extractor._times))

def __eq__(self, imaging_extractor2):
image_size_equal = self.get_image_size() == imaging_extractor2.get_image_size()
num_frames_equal = self.get_num_frames() == imaging_extractor2.get_num_frames()
Expand Down
86 changes: 3 additions & 83 deletions src/roiextractors/segmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@
Class to get a lazy frame slice.
"""

from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import Union, Optional, Tuple, Iterable, List, get_args

import numpy as np
from numpy.typing import ArrayLike

from .baseextractor import BaseExtractor
from .extraction_tools import ArrayType, IntType, FloatType
from .extraction_tools import _pixel_mask_extractor


class SegmentationExtractor(ABC):
class SegmentationExtractor(BaseExtractor):
"""Abstract segmentation extractor class.

An abstract class that contains all the meta-data and output data from
Expand All @@ -31,43 +32,6 @@ class SegmentationExtractor(ABC):
format specific classes that inherit from this.
"""

def __init__(self):
"""Create a new SegmentationExtractor for a specific data format (unique to each child SegmentationExtractor)."""
self._times = None

@abstractmethod
def get_image_size(self) -> ArrayType:
"""Get frame size of movie (height, width).

Returns
-------
no_rois: array_like
2-D array: image height x image width
"""
pass

@abstractmethod
def get_num_frames(self) -> int:
"""Get the number of frames in the recording (duration of recording).

Returns
-------
num_frames: int
Number of frames in the recording.
"""
pass

@abstractmethod
def get_sampling_frequency(self) -> float:
"""Get the sampling frequency in Hz.

Returns
-------
sampling_frequency: float
Sampling frequency of the recording in Hz.
"""
pass

@abstractmethod
def get_roi_ids(self) -> list:
"""Get the list of ROI ids.
Expand Down Expand Up @@ -308,50 +272,6 @@ def get_summary_images(self, names: Optional[list[str]] = None) -> dict:
"""
pass

# TODO: Refactor _times methods from ImagingExtractor and SegmentationExtractor into a BaseExtractor class
def set_times(self, times: ArrayType):
"""Set the recording times in seconds for each frame.

Parameters
----------
times: array-like
The times in seconds for each frame

Notes
-----
Operates on _times attribute of the SegmentationExtractor object.
"""
assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!"
self._times = np.array(times, dtype=np.float64)

def has_time_vector(self) -> bool:
"""Detect if the SegmentationExtractor has a time vector set or not.

Returns
-------
has_time_vector: bool
True if the SegmentationExtractor has a time vector set, otherwise False.
"""
return self._times is not None

def frame_to_time(self, frames: Union[IntType, ArrayType]) -> Union[FloatType, ArrayType]:
"""Get the timing of frames in unit of seconds.

Parameters
----------
frames: int or array-like
The frame or frames to be converted to times

Returns
-------
times: float or array-like
The corresponding times in seconds
"""
if self._times is None:
return frames / self.get_sampling_frequency()
else:
return self._times[frames]

def frame_slice(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None):
"""Return a new ImagingExtractor ranging from the start_frame to the end_frame.

Expand Down
Loading
Loading