Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compound extractors #371

Open
wants to merge 30 commits into
base: tools
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
510b882
renamed multiimagingextractor --> FrameConcatenatedImagingExtractor
pauladkisson Oct 7, 2024
9523b5f
renamed multiimagingextractor --> FrameConcatenatedImagingExtractor
pauladkisson Oct 7, 2024
d69d968
renamed multiimagingextractor --> FrameConcatenatedImagingExtractor
pauladkisson Oct 7, 2024
0a0b68b
updated check_consistency fn
pauladkisson Oct 7, 2024
5365e44
simplified range calls in _get_times
pauladkisson Oct 7, 2024
3e7c3fd
refactored get_frames
pauladkisson Oct 7, 2024
32e10f4
refactored get_video
pauladkisson Oct 7, 2024
1c21007
added tests for frame_concatenated imaging_extractor
pauladkisson Oct 7, 2024
288ad12
refactored into a mixin
pauladkisson Oct 7, 2024
0de656e
refactored into a mixin
pauladkisson Oct 7, 2024
4e63092
added tests for concatenation points
pauladkisson Oct 7, 2024
cc86076
added tests for alternative inits
pauladkisson Oct 7, 2024
2b4db1f
made check a staticmethod
pauladkisson Oct 8, 2024
123d63d
updated consistency check to remove channel
pauladkisson Oct 8, 2024
588fa7d
updated get_video to use _validate_get_video_arguments
pauladkisson Oct 8, 2024
63550ea
updated get_frames
pauladkisson Oct 8, 2024
1ef6de2
removed channel fns
pauladkisson Oct 8, 2024
721bff7
reverted frame_slice since its compatible now
pauladkisson Oct 8, 2024
1bcaedb
added tests (failing)
pauladkisson Oct 8, 2024
1522ade
added test for check_consistency for frame_concatenated_ie
pauladkisson Oct 8, 2024
db198d2
added frame slice to volumetric
pauladkisson Oct 8, 2024
3d23dba
image_size --> 2D rather than 3D
pauladkisson Oct 8, 2024
9cd53ac
fixed get_frames
pauladkisson Oct 8, 2024
1de12b1
added frame_slice tests
pauladkisson Oct 8, 2024
9efa3ce
added depth_slice tests
pauladkisson Oct 8, 2024
fce6986
added depth_slice_not_implemented test for frame_slice
pauladkisson Oct 8, 2024
8066ae3
renamed multisegmentationextractor --> volumetricsegmentationextractor
pauladkisson Oct 8, 2024
89f97ff
reverted renaming
pauladkisson Oct 8, 2024
b0f4366
added notimplementedError
pauladkisson Oct 8, 2024
73609f6
fixed file paths for segmentation_extractor2
pauladkisson Oct 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions old_tests/test_internals/test_multiimagingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numpy.testing import assert_array_equal
from parameterized import parameterized, param

from roiextractors.multiimagingextractor import MultiImagingExtractor
from roiextractors.frameconcatenatedimagingextractor import FrameConcatenatedImagingExtractor
from roiextractors.testing import generate_dummy_imaging_extractor


Expand All @@ -19,7 +19,7 @@ def setUpClass(cls):
generate_dummy_imaging_extractor(num_frames=10, num_rows=3, num_columns=4, sampling_frequency=20.0)
for _ in range(3)
]
cls.multi_imaging_extractor = MultiImagingExtractor(imaging_extractors=cls.extractors)
cls.multi_imaging_extractor = FrameConcatenatedImagingExtractor(imaging_extractors=cls.extractors)

def test_get_image_size(self):
assert self.multi_imaging_extractor.get_image_size() == self.extractors[0].get_image_size()
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_set_incorrect_times(self):

def test_set_times(self):
self.extractors[1].set_times(np.arange(0, 10) / 30.0)
multi_imaging_extractor = MultiImagingExtractor(imaging_extractors=self.extractors)
multi_imaging_extractor = FrameConcatenatedImagingExtractor(imaging_extractors=self.extractors)

dummy_times = np.arange(0, 30) / 20.0
to_replace = [*range(multi_imaging_extractor._start_frames[1], multi_imaging_extractor._end_frames[1])]
Expand Down Expand Up @@ -178,7 +178,7 @@ def test_inconsistent_property_assertion(self, rows, columns, sampling_frequency
exc_type=AssertionError,
exc_msg=expected_error_msg,
):
MultiImagingExtractor(imaging_extractors=inconsistent_extractors)
FrameConcatenatedImagingExtractor(imaging_extractors=inconsistent_extractors)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion src/roiextractors/extractorlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .extractors.memmapextractors import MemmapImagingExtractor
from .extractors.miniscopeimagingextractor import MiniscopeImagingExtractor
from .multisegmentationextractor import MultiSegmentationExtractor
from .multiimagingextractor import MultiImagingExtractor
from .frameconcatenatedimagingextractor import FrameConcatenatedImagingExtractor
from .volumetricimagingextractor import VolumetricImagingExtractor

imaging_extractor_full_list = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
import numpy as np

from ...imagingextractor import ImagingExtractor
from ...multiimagingextractor import MultiImagingExtractor
from ...frameconcatenatedimagingextractor import FrameConcatenatedImagingExtractor
from ...tools.typing import PathType, DtypeType
from ...tools.importing import get_package


class MiniscopeImagingExtractor(MultiImagingExtractor): # TODO: rename to MiniscopeMultiImagingExtractor
class MiniscopeImagingExtractor(FrameConcatenatedImagingExtractor): # TODO: rename to MiniscopeMultiImagingExtractor
"""An ImagingExtractor for the Miniscope video (.avi) format.

This format consists of video (.avi) file(s) and configuration files (.json).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import numpy as np

from ...multiimagingextractor import MultiImagingExtractor
from ...frameconcatenatedimagingextractor import FrameConcatenatedImagingExtractor
from ...imagingextractor import ImagingExtractor
from ...tools.typing import PathType, DtypeType, ArrayType
from ...tools.importing import get_package
Expand Down Expand Up @@ -119,7 +119,7 @@ def _parse_xml(folder_path: PathType) -> etree.Element:
return tree.getroot()


class BrukerTiffMultiPlaneImagingExtractor(MultiImagingExtractor):
class BrukerTiffMultiPlaneImagingExtractor(FrameConcatenatedImagingExtractor):
"""A MultiImagingExtractor for TIFF files produced by Bruke with multiple planes.

This format consists of multiple TIF image files (.ome.tif) and configuration files (.xml, .env).
Expand Down Expand Up @@ -312,7 +312,7 @@ def get_video(
return video


class BrukerTiffSinglePlaneImagingExtractor(MultiImagingExtractor):
class BrukerTiffSinglePlaneImagingExtractor(FrameConcatenatedImagingExtractor):
"""A MultiImagingExtractor for TIFF files produced by Bruker with only 1 plane."""

extractor_name = "BrukerTiffSinglePlaneImaging"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ...imagingextractor import ImagingExtractor
from ...tools.typing import PathType, DtypeType
from ...tools.importing import get_package
from ...multiimagingextractor import MultiImagingExtractor
from ...frameconcatenatedimagingextractor import FrameConcatenatedImagingExtractor


def filter_tiff_tag_warnings(record):
Expand All @@ -37,7 +37,7 @@ def _get_tiff_reader() -> ModuleType:
return get_package(package_name="tifffile", installation_instructions="pip install tifffile")


class MicroManagerTiffImagingExtractor(MultiImagingExtractor):
class MicroManagerTiffImagingExtractor(FrameConcatenatedImagingExtractor):
"""Specialized extractor for reading TIFF files produced via Micro-Manager.

The image file stacks are saved into multipage TIF files in OME-TIFF format (.ome.tif files),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ...tools.typing import PathType, FloatType, ArrayType, DtypeType
from ...imagingextractor import ImagingExtractor
from ...volumetricimagingextractor import VolumetricImagingExtractor
from ...multiimagingextractor import MultiImagingExtractor
from ...frameconcatenatedimagingextractor import FrameConcatenatedImagingExtractor
from .scanimagetiff_utils import (
extract_extra_metadata,
parse_metadata,
Expand All @@ -23,7 +23,7 @@
)


class ScanImageTiffMultiPlaneMultiFileImagingExtractor(MultiImagingExtractor):
class ScanImageTiffMultiPlaneMultiFileImagingExtractor(FrameConcatenatedImagingExtractor):
"""Specialized extractor for reading multi-file (buffered) TIFF files produced via ScanImage."""

extractor_name = "ScanImageTiffMultiPlaneMultiFileImaging"
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(
super().__init__(imaging_extractors=imaging_extractors)


class ScanImageTiffSinglePlaneMultiFileImagingExtractor(MultiImagingExtractor):
class ScanImageTiffSinglePlaneMultiFileImagingExtractor(FrameConcatenatedImagingExtractor):
"""Specialized extractor for reading multi-file (buffered) TIFF files produced via ScanImage."""

extractor_name = "ScanImageTiffSinglePlaneMultiFileImaging"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Defines the MultiImagingExtractor class.
"""Defines the FrameConcatenatedImagingExtractor class.

Classes
-------
MultiImagingExtractor
FrameConcatenatedImagingExtractor
This class is used to combine multiple ImagingExtractor objects by frames.
"""

Expand All @@ -15,15 +15,15 @@
from .imagingextractor import ImagingExtractor


class MultiImagingExtractor(ImagingExtractor):
class FrameConcatenatedImagingExtractor(ImagingExtractor):
"""Class to combine multiple ImagingExtractor objects by frames."""

extractor_name = "MultiImagingExtractor"
extractor_name = "FrameConcatenatedImagingExtractor"
installed = True
installation_mesg = ""

def __init__(self, imaging_extractors: List[ImagingExtractor]):
"""Initialize a MultiImagingExtractor object from a list of ImagingExtractors.
"""Initialize a FrameConcatenatedImagingExtractor object from a list of ImagingExtractors.

Parameters
----------
Expand All @@ -33,26 +33,32 @@ def __init__(self, imaging_extractors: List[ImagingExtractor]):
super().__init__()
assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument"
assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors)
self._check_consistency_between_imaging_extractors(imaging_extractors=imaging_extractors)
self._imaging_extractors = imaging_extractors

# Checks that properties are consistent between extractors
self._check_consistency_between_imaging_extractors()

self._start_frames, self._end_frames = [], []
num_frames = 0
for imaging_extractor in self._imaging_extractors:
self._start_frames.append(num_frames)
num_frames = num_frames + imaging_extractor.get_num_frames()
self._end_frames.append(num_frames)
self._start_frames = np.array(self._start_frames)
self._end_frames = np.array(self._end_frames)
self._num_frames = num_frames

if any((getattr(imaging_extractor, "_times") is not None for imaging_extractor in self._imaging_extractors)):
times = self._get_times()
self.set_times(times=times)

def _check_consistency_between_imaging_extractors(self):
@staticmethod
def _check_consistency_between_imaging_extractors(imaging_extractors: List[ImagingExtractor]):
"""Check that essential properties are consistent between extractors so that they can be combined appropriately.

Parameters
----------
imaging_extractors: list of ImagingExtractor
list of imaging extractor objects

Raises
------
AssertionError
Expand All @@ -63,19 +69,15 @@ def _check_consistency_between_imaging_extractors(self):
This method checks the following properties:
- sampling frequency
- image size
- number of channels
- channel names
- data type
"""
properties_to_check = dict(
get_sampling_frequency="The sampling frequency",
get_image_size="The size of a frame",
get_num_channels="The number of channels",
get_channel_names="The name of the channels",
get_dtype="The data type.",
)
for method, property_message in properties_to_check.items():
values = [getattr(extractor, method)() for extractor in self._imaging_extractors]
values = [getattr(extractor, method)() for extractor in imaging_extractors]
unique_values = set(tuple(v) if isinstance(v, Iterable) else v for v in values)
assert (
len(unique_values) == 1
Expand All @@ -89,114 +91,82 @@ def _get_times(self) -> np.ndarray:
times: numpy.ndarray
Array of times.
"""
frame_indices = np.array([*range(self._start_frames[0], self._end_frames[-1])])
frame_indices = np.arange(self._num_frames)
times = self.frame_to_time(frames=frame_indices)

for extractor_index, extractor in enumerate(self._imaging_extractors):
if getattr(extractor, "_times") is not None:
to_replace = [*range(self._start_frames[extractor_index], self._end_frames[extractor_index])]
to_replace = np.arange(self._start_frames[extractor_index], self._end_frames[extractor_index])
times[to_replace] = extractor._times

return times

def _get_frames_from_an_imaging_extractor(self, extractor_index: int, frame_idxs: ArrayType) -> np.ndarray:
"""Get frames from a single imaging extractor.

Parameters
----------
extractor_index: int
Index of the imaging extractor to use.
frame_idxs: array_like
Indices of the frames to get.

Returns
-------
frames: numpy.ndarray
Array of frames.
"""
imaging_extractor = self._imaging_extractors[extractor_index]
frames = imaging_extractor.get_frames(frame_idxs=frame_idxs)
return frames

def get_dtype(self):
return self._imaging_extractors[0].get_dtype()

def get_frames(self, frame_idxs: ArrayType, channel: Optional[int] = 0) -> np.ndarray:
if isinstance(frame_idxs, (int, np.integer)):
frame_idxs = [frame_idxs]
frame_idxs = np.array(frame_idxs)
assert np.all(frame_idxs < self.get_num_frames()), "'frame_idxs' exceed number of frames"
def get_frames(self, frame_idxs: ArrayType) -> np.ndarray:
self._validate_get_frames_arguments(frame_idxs=frame_idxs)
extractor_indices = np.searchsorted(self._end_frames, frame_idxs, side="right")
relative_frame_indices = frame_idxs - np.array(self._start_frames)[extractor_indices]
relative_frame_indices = frame_idxs - self._start_frames[extractor_indices]

# Match frame_idxs to imaging extractors
extractors_dict = defaultdict(list)
extractor_index_to_relative_frame_indices = defaultdict(list)
for extractor_index, frame_index in zip(extractor_indices, relative_frame_indices):
extractors_dict[extractor_index].append(frame_index)
extractor_index_to_relative_frame_indices[extractor_index].append(frame_index)

frames_to_concatenate = []
# Extract frames for each extractor and concatenate
for extractor_index, frame_indices in extractors_dict.items():
frames_for_each_extractor = self._get_frames_from_an_imaging_extractor(
extractor_index=extractor_index,
frame_idxs=frame_indices,
)
if len(frame_indices) == 1:
frames_for_each_extractor = frames_for_each_extractor[np.newaxis, ...]
frames_to_concatenate.append(frames_for_each_extractor)
for extractor_index, frame_indices in extractor_index_to_relative_frame_indices.items():
imaging_extractor = self._imaging_extractors[extractor_index]
frames = imaging_extractor.get_frames(frame_idxs=frame_indices)
frames_to_concatenate.append(frames)

frames = np.concatenate(frames_to_concatenate, axis=0)
return frames

def get_video(
self, start_frame: Optional[int] = None, end_frame: Optional[int] = None, channel: int = 0
) -> np.ndarray:
if channel != 0:
raise NotImplementedError(
f"MultiImagingExtractors for multiple channels have not yet been implemented! (Received '{channel}'."
)

start = start_frame if start_frame is not None else 0
stop = end_frame if end_frame is not None else self.get_num_frames()
extractors_range = np.searchsorted(self._end_frames, (start, stop - 1), side="right")
def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray:
start_frame, end_frame = self._validate_get_video_arguments(start_frame=start_frame, end_frame=end_frame)
extractors_range = np.searchsorted(self._end_frames, (start_frame, end_frame - 1), side="right")
extractors_spanned = list(
range(extractors_range[0], min(extractors_range[-1] + 1, len(self._imaging_extractors)))
)

# Early return with simple relative indexing; preserves native return class of that extractor
if len(extractors_spanned) == 1:
relative_start = start - self._start_frames[extractors_spanned[0]]
relative_stop = stop - start + relative_start
extractor_index = extractors_spanned[0]
relative_start = start_frame - self._start_frames[extractor_index]
relative_stop = end_frame - start_frame + relative_start

return self._imaging_extractors[extractors_spanned[0]].get_video(
start_frame=relative_start, end_frame=relative_stop
)

video_shape = (stop - start,) + self._imaging_extractors[0].get_image_size()
video_shape = (end_frame - start_frame, *self._imaging_extractors[0].get_image_size())
video = np.empty(shape=video_shape, dtype=self.get_dtype())
current_frame = 0

# Left endpoint; since more than one extractor is spanned, only care about indexing first start frame
relative_start = start - self._start_frames[extractors_spanned[0]]
relative_span = self._end_frames[extractors_spanned[0]] - start
extractor_index = extractors_spanned[0]
relative_start = start_frame - self._start_frames[extractor_index]
relative_span = self._end_frames[extractor_index] - start_frame
array_frame_slice = slice(current_frame, relative_span)
video[array_frame_slice, ...] = self._imaging_extractors[extractors_spanned[0]].get_video(
start_frame=relative_start
)
imaging_extractor = self._imaging_extractors[extractor_index]
video[array_frame_slice, ...] = imaging_extractor.get_video(start_frame=relative_start)
current_frame += relative_span

# All inner spans can be written knowing only how long each section is
for extractor_index in extractors_spanned[1:-1]:
relative_span = self._end_frames[extractor_index] - self._start_frames[extractor_index]
array_frame_slice = slice(current_frame, current_frame + relative_span)
video[array_frame_slice, ...] = self._imaging_extractors[extractor_index].get_video()
imaging_extractor = self._imaging_extractors[extractor_index]
video[array_frame_slice, ...] = imaging_extractor.get_video()
current_frame += relative_span

# Right endpoint; since more than one extractor is spanned, only care about indexing final end frame
relative_stop = stop - self._start_frames[extractors_spanned[-1]]
relative_stop = end_frame - self._start_frames[extractors_spanned[-1]]
array_frame_slice = slice(current_frame, None)
video[array_frame_slice, ...] = self._imaging_extractors[extractors_spanned[-1]].get_video(
end_frame=relative_stop
)
imaging_extractor = self._imaging_extractors[extractors_spanned[-1]]
video[array_frame_slice, ...] = imaging_extractor.get_video(end_frame=relative_stop)

return video

Expand Down
3 changes: 3 additions & 0 deletions src/roiextractors/imagingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,6 @@ def get_sampling_frequency(self) -> float:

def get_dtype(self) -> DtypeType:
return self._parent_imaging.get_dtype()

def depth_slice(self, start_plane: Optional[int] = None, end_plane: Optional[int] = None):
raise NotImplementedError("Depth slicing is not supported for FrameSliceImagingExtractor.")
4 changes: 4 additions & 0 deletions src/roiextractors/multisegmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def __init__(self, segmentatation_extractors_list, plane_names=None): # TODO: H
plane_names: list
list of strings of names for the plane. Defaults to 'Plane0', 'Plane1' ...
"""
raise NotImplementedError(
"Multi-Plane Segmentations are currently not supported. Please raise an issue to request this feature: "
"https://github.com/catalystneuro/roiextractors/issues "
)
SegmentationExtractor.__init__(self)
if not isinstance(segmentatation_extractors_list, list):
raise Exception("Enter a list of segmentation extractor objects as argument")
Expand Down
Loading
Loading