Skip to content

Commit

Permalink
Merge pull request #214 from EmmaRenauld/hdf5_optim
Browse files Browse the repository at this point in the history
Clarify hdf5 use
  • Loading branch information
EmmaRenauld authored Nov 22, 2023
2 parents c684f87 + 5d6fddb commit 1db1836
Show file tree
Hide file tree
Showing 10 changed files with 178 additions and 327 deletions.
19 changes: 16 additions & 3 deletions dwi_ml/data/dataset/mri_data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def init_mri_data_from_hdf_info(cls, hdf_group: h5py.Group):
"""
raise NotImplementedError

def convert_to_tensor(self, device) -> Tensor:
def get_data_as_tensor(self, device) -> Tensor:
"""Returns the _data in the tensor format."""
raise NotImplementedError

Expand Down Expand Up @@ -85,7 +85,7 @@ def init_mri_data_from_hdf_info(cls, hdf_group: h5py.Group):

return cls(data, voxres, affine)

def convert_to_tensor(self, device):
def get_data_as_tensor(self, device):
# Data is already a torch tensor (not a np.array): simply move it to device.
return self._data.to(device=device)

Expand All @@ -103,6 +103,19 @@ class LazyMRIData(MRIDataAbstract):

def __init__(self, data: Union[h5py.Group, None], voxres: np.ndarray,
affine: np.ndarray):
"""
Here the data is an hdf5 group. Accessing it will load it.
It can be loaded entirely, or through indexing. Simple indexing is quite
fast (e.g. 0:300 or 0:300:2), but indexing with a list of indices is
slow. We suggest you load it all and index it afterwards. See here:
https://stackoverflow.com/questions/21766145/h5py-correct-way-to-slice-array-datasets
In our repo, we will use our MultiSubjectContainer's method:
get_volume_verify_cache. We always load the whole volume first.
This lazy version is still useful for a big database: We can then clear
the volume in memory before accessing another subject's.
"""
super().__init__(data, voxres, affine)

@classmethod
Expand All @@ -120,7 +133,7 @@ def init_mri_data_from_hdf_info(cls, hdf_group: h5py.Group):
# All three methods below load the data.
# Data is not loaded yet, but sending it to a np.array will load it.

def convert_to_tensor(self, device):
def get_data_as_tensor(self, device):
logger.debug("Loading from hdf5 now: {}".format(self._data))
return torch.as_tensor(np.array(self._data, dtype=np.float32),
dtype=torch.float, device=device)
Expand Down
6 changes: 3 additions & 3 deletions dwi_ml/data/dataset/multi_subject_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,10 @@ def get_volume_verify_cache(self, subj_idx: int, group_idx: int,

if not was_cached:
# Either non-lazy or if lazy, data was not cached.
# Non-lazy: direct access. Lazy: this loads the data.
# Non-lazy: direct access. Lazy: this loads the whole data.
logger.debug("Getting a new volume from the dataset.")
mri_data = self.get_mri_data(subj_idx, group_idx)
mri_data_tensor = mri_data.convert_to_tensor(device)
mri_data_tensor = mri_data.get_data_as_tensor(device)

# Add to cache the tensor (on correct device)
if self.cache_size:
Expand Down Expand Up @@ -283,7 +283,7 @@ def load(self, hdf_handle: h5py.File, subj_id=None):
logger.debug(" Counting streamlines")
for group in range(len(self.streamline_groups)):
subj_sft_data = subj_data.sft_data_list[group]
n_streamlines = len(subj_sft_data.streamlines)
n_streamlines = len(subj_sft_data)
self._add_streamlines_ids(n_streamlines, subj_idx, group)
lengths[group].append(subj_sft_data.lengths)
lengths_mm[group].append(subj_sft_data.lengths_mm)
Expand Down
7 changes: 4 additions & 3 deletions dwi_ml/data/dataset/single_subject_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import logging
from typing import List, Union

from dwi_ml.data.dataset.mri_data_containers import LazyMRIData, MRIData
from dwi_ml.data.dataset.mri_data_containers import (LazyMRIData, MRIData,
MRIDataAbstract)
from dwi_ml.data.dataset.streamline_containers import LazySFTData, SFTData
from dwi_ml.data.dataset.checks_for_groups import prepare_groups_info

Expand Down Expand Up @@ -38,7 +39,7 @@ def __init__(self, volume_groups: List[str], nb_features: List[int],
self.is_lazy = None

@property
def mri_data_list(self):
def mri_data_list(self) -> List[MRIDataAbstract]:
"""Returns a list of MRIData (lazy or not)."""
raise NotImplementedError

Expand Down Expand Up @@ -81,7 +82,7 @@ def __init__(self, subject_id: str, volume_groups: List[str],
self.is_lazy = False

@property
def mri_data_list(self):
def mri_data_list(self) -> List[MRIData]:
return self._mri_data_list

@property
Expand Down
Loading

0 comments on commit 1db1836

Please sign in to comment.