Merge pull request #248 from EmmaRenauld/dps_in_config_file
Dps management
  • Loading branch information
EmmaRenauld authored Nov 7, 2024
2 parents ce0a2d7 + 8173510 commit 7f0f6d9
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 50 deletions.
102 changes: 82 additions & 20 deletions dwi_ml/data/dataset/streamline_containers.py
@@ -8,9 +8,10 @@
import h5py
from nibabel.streamlines import ArraySequence
import numpy as np
from collections import defaultdict


def _load_space_attributes_from_hdf(hdf_group: h5py.Group):
def _load_streamlines_attributes_from_hdf(hdf_group: h5py.Group):
a = np.array(hdf_group.attrs['affine'])
d = np.array(hdf_group.attrs['dimensions'])
vs = np.array(hdf_group.attrs['voxel_sizes'])
@@ -43,7 +44,13 @@ def _load_all_streamlines_from_hdf(hdf_group: h5py.Group):
streamlines._offsets = np.array(hdf_group['offsets'])
streamlines._lengths = np.array(hdf_group['lengths'])

return streamlines
# DPS
hdf_dps_group = hdf_group['data_per_streamline']
dps_dict = {}
for dps_key in hdf_dps_group.keys():
dps_dict[dps_key] = hdf_dps_group[dps_key][:]

return streamlines, dps_dict
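
Note: a minimal sketch of the HDF5 layout this loader expects. The 'data', 'offsets', 'lengths' and 'data_per_streamline' names follow the code above; the file name and the 'seeds' key are hypothetical.

    import h5py
    import numpy as np

    # Writing a toy group (hypothetical file/key names):
    with h5py.File('subj1.hdf5', 'w') as f:
        grp = f.create_group('streamlines')
        grp.create_dataset('data', data=np.random.rand(10, 3))  # all points, concatenated
        grp.create_dataset('offsets', data=np.array([0, 4]))    # first point of each streamline
        grp.create_dataset('lengths', data=np.array([4, 6]))    # nb of points per streamline
        dps = grp.create_group('data_per_streamline')
        dps.create_dataset('seeds', data=np.random.rand(2, 3))  # one row per streamline

    # Reading it back, as in _load_all_streamlines_from_hdf:
    with h5py.File('subj1.hdf5', 'r') as f:
        hdf_dps_group = f['streamlines']['data_per_streamline']
        dps_dict = {key: hdf_dps_group[key][:] for key in hdf_dps_group.keys()}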


def _load_connectivity_info(hdf_group: h5py.Group):
@@ -79,40 +86,85 @@ def _get_one_streamline(self, idx: int):

return data

def _assert_dps(self, dps_dict, n_streamlines):
for key, value in dps_dict.items():
if len(value) != n_streamlines:
raise ValueError(
f"Length of data_per_streamline {key} is {len(value)} "
f"but should be {n_streamlines}.")
elif not isinstance(value, np.ndarray):
raise ValueError(
f"Data_per_streamline {key} should be a numpy array, "
f"not a {type(value)}.")

def get_array_sequence(self, item=None):
if item is None:
streamlines = _load_all_streamlines_from_hdf(self.hdf_group)
streamlines, data_per_streamline = _load_all_streamlines_from_hdf(
self.hdf_group)
else:
streamlines = ArraySequence()
data_per_streamline = defaultdict(list)

# If data_per_streamline is not in the hdf5, use an empty dict
# so that we don't add anything to the data_per_streamline in the
# following steps.
hdf_dps_group = self.hdf_group['data_per_streamline'] if \
'data_per_streamline' in self.hdf_group.keys() else {}

if isinstance(item, int):
streamlines.append(self._get_one_streamline(item))
data = self._get_one_streamline(item)
streamlines.append(data)

for dps_key in hdf_dps_group.keys():
data_per_streamline[dps_key].append(
hdf_dps_group[dps_key][item])

elif isinstance(item, list) or isinstance(item, np.ndarray):
# Getting a list of value from a hdf5: slow. Uses fancy indexing.
# But possible. See here:
# Getting a list of values from a hdf5: slow. Uses fancy
# indexing. But possible. See here:
# https://stackoverflow.com/questions/21766145/h5py-correct-way-to-slice-array-datasets
# Here: looping and accessing the items ourselves.
# Another good option: load the whole data and access the
# indexes afterward.
# ToDo: Test speed for the three options.
for i in item:
streamlines.append(self._get_one_streamline(i),
cache_build=True)
data = self._get_one_streamline(i)
streamlines.append(data, cache_build=True)

for dps_key in hdf_dps_group.keys():
data_per_streamline[dps_key].append(
hdf_dps_group[dps_key][item])

streamlines.finalize_append()

elif isinstance(item, slice):
offsets = self.hdf_group['offsets'][item]
lengths = self.hdf_group['lengths'][item]
for offset, length in zip(offsets, lengths):
# slice.indices() resolves None start/stop/step values.
indices = np.arange(*item.indices(len(self.hdf_group['lengths'])))
for offset, length, idx in zip(offsets, lengths, indices):
streamline = self.hdf_group['data'][offset:offset + length]
streamlines.append(streamline, cache_build=True)

for dps_key in hdf_dps_group.keys():
# Indexing with a list (e.g. [idx]) will preserve the
# shape of the array. Crucial for concatenation below.
dps_data = hdf_dps_group[dps_key][[idx]]
data_per_streamline[dps_key].append(dps_data)
streamlines.finalize_append()

else:
raise ValueError('Item should be either an int, list, '
'np.ndarray or slice, but we received {}'
.format(type(item)))
return streamlines

# The accumulated data_per_streamline is a list of numpy arrays.
# We need to merge them into a single numpy array so it can be
# reused in the StatefulTractogram.
for key in data_per_streamline.keys():
data_per_streamline[key] = \
np.concatenate(data_per_streamline[key])

self._assert_dps(data_per_streamline, len(streamlines))
return streamlines, data_per_streamline
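
Note on the slice branch above: Python's slice.indices(n) resolves None start/stop/step values against a total count, which keeps np.arange safe for open slices such as [:10]. A quick sketch:

    print(slice(None, 5, None).indices(20))              # (0, 5, 1)
    print(list(range(*slice(2, None, 3).indices(10))))   # [2, 5, 8]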

@property
def lengths(self):
@@ -160,6 +212,7 @@ class SFTDataAbstract(object):
all information necessary to work with streamlines: the data itself and
_offsets, _lengths, space attributes, etc.
"""

def __init__(self, space_attributes: Tuple, space: Space, origin: Origin,
contains_connectivity: bool,
connectivity_nb_blocs: List = None,
@@ -256,17 +309,19 @@ def as_sft(self,
streamline_ids: Union[List[int], int, slice, None]
List of chosen ids. If None, use all streamlines.
"""
streamlines = self._get_streamlines_as_list(streamline_ids)
streamlines, dps = self._get_streamlines_as_list(streamline_ids)

sft = StatefulTractogram(streamlines, self.space_attributes,
self.space, self.origin)
self.space, self.origin,
data_per_streamline=dps)

return sft


class SFTData(SFTDataAbstract):
def __init__(self, streamlines: ArraySequence,
lengths_mm: List, connectivity_matrix: np.ndarray,
data_per_streamline: dict = None,
**kwargs):
"""
streamlines: ArraySequence or LazyStreamlinesGetter
@@ -279,6 +334,7 @@ def __init__(self, streamlines: ArraySequence,
self._lengths_mm = lengths_mm
self._connectivity_matrix = connectivity_matrix
self.is_lazy = False
self.data_per_streamline = data_per_streamline

def __len__(self):
return len(self.streamlines)
@@ -306,7 +362,7 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
Creating class instance from the hdf in cases where data is not
loaded yet. Non-lazy = loading the data here.
"""
streamlines = _load_all_streamlines_from_hdf(hdf_group)
streamlines, dps_dict = _load_all_streamlines_from_hdf(hdf_group)
# Adding non-hidden parameters for nicer later access
lengths_mm = hdf_group['euclidean_lengths']

@@ -318,7 +374,8 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
else:
connectivity_matrix = None

space_attributes, space, origin = _load_space_attributes_from_hdf(hdf_group)
space_attributes, space, origin = _load_streamlines_attributes_from_hdf(
hdf_group)

# Return an instance of SubjectMRIData instantiated through __init__
# with this loaded data:
@@ -328,13 +385,18 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
space=space, origin=origin,
contains_connectivity=contains_connectivity,
connectivity_nb_blocs=connectivity_nb_blocs,
connectivity_labels=connectivity_labels)
connectivity_labels=connectivity_labels,
data_per_streamline=dps_dict)

def _get_streamlines_as_list(self, streamline_ids):
if streamline_ids is not None:
return self.streamlines.__getitem__(streamline_ids)
dps_indexed = {}
for key, value in self.data_per_streamline.items():
dps_indexed[key] = value[streamline_ids]

return self.streamlines.__getitem__(streamline_ids), dps_indexed
else:
return self.streamlines
return self.streamlines, self.data_per_streamline


class LazySFTData(SFTDataAbstract):
@@ -368,7 +430,7 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None):

@classmethod
def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
space_attributes, space, origin = _load_space_attributes_from_hdf(
space_attributes, space, origin = _load_streamlines_attributes_from_hdf(
hdf_group)

contains_connectivity, connectivity_nb_blocs, connectivity_labels = \
@@ -384,6 +446,6 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
connectivity_labels=connectivity_labels)

def _get_streamlines_as_list(self, streamline_ids):
streamlines = self.streamlines_getter.get_array_sequence(
streamlines, dps = self.streamlines_getter.get_array_sequence(
streamline_ids)
return streamlines
return streamlines, dps
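
Note: with these changes, dps follow the streamlines whenever a subset is materialized as a StatefulTractogram. A hedged usage sketch, where sft_data stands for an SFTData or LazySFTData built through init_sft_data_from_hdf_info as above, and 'seeds' is a hypothetical dps key:

    sft = sft_data.as_sft(streamline_ids=[0, 5, 12])
    print(sft.data_per_streamline.keys())   # e.g., dict_keys(['seeds'])
    # Each dps entry keeps exactly one row per selected streamline:
    assert all(len(v) == len(sft) for v in sft.data_per_streamline.values())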
49 changes: 32 additions & 17 deletions dwi_ml/data/hdf5/hdf5_creation.py
@@ -136,7 +136,6 @@ class HDF5Creator:
def __init__(self, root_folder: Path, out_hdf_filename: Path,
training_subjs: List[str], validation_subjs: List[str],
testing_subjs: List[str], groups_config: dict,
dps_keys: List[str] = [],
step_size: float = None,
nb_points: int = None,
compress: float = None,
@@ -160,8 +159,6 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
List of subject names for each data set.
groups_config: dict
Information from json file loaded as a dict.
dps_keys: List[str]
List of keys to keep in data_per_streamline. Default: [].
step_size: float
Step size to resample streamlines. Default: None.
nb_points: int
@@ -186,7 +183,6 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
self.validation_subjs = validation_subjs
self.testing_subjs = testing_subjs
self.groups_config = groups_config
self.dps_keys = dps_keys
self.step_size = step_size
self.nb_points = nb_points
self.compress = compress
@@ -596,7 +592,7 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
"in the config_file. If all files are .trk, we can use "
"ref 'same' but if some files were .tck, we need a ref!"
"Hint: Create a volume group 'ref' in the config file.")
sft, lengths, connectivity_matrix, conn_info = (
sft, lengths, connectivity_matrix, conn_info, dps_keys = (
self._process_one_streamline_group(
subj_input_dir, group, subj_id, ref))

@@ -614,6 +610,8 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
streamlines_group.attrs['dimensions'] = d
streamlines_group.attrs['voxel_sizes'] = vs
streamlines_group.attrs['voxel_order'] = vo

# Connectivity info for this streamlines group
if connectivity_matrix is not None:
streamlines_group.attrs[
'connectivity_matrix_type'] = conn_info[0]
@@ -626,18 +624,17 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
streamlines_group.attrs['connectivity_nb_blocs'] = \
conn_info[1]

# DPP not managed yet!
if len(sft.data_per_point) > 0:
logging.debug('sft contained data_per_point. Data not kept.')
logging.debug(" Including dps \"{}\" in the HDF5."
.format(dps_keys))

for dps_key in self.dps_keys:
if dps_key not in sft.data_per_streamline:
raise ValueError(
"The data_per_streamline key '{}' was not found in "
"the sft. Check your tractogram file.".format(dps_key))

logging.debug(" Include dps \"{}\" in the HDF5.".format(dps_key))
streamlines_group.create_dataset('dps_' + dps_key,
data=sft.data_per_streamline[dps_key])
# DPS info for this streamlines group
dps_group = streamlines_group.create_group('data_per_streamline')
for dps_key in dps_keys:
dps_group.create_dataset(
dps_key, data=sft.data_per_streamline[dps_key])

# Accessing private Dipy values, but necessary.
# We need to deconstruct the streamlines into arrays with
@@ -678,6 +675,11 @@ def _process_one_streamline_group(
The Euclidean length of each streamline
"""
tractograms = self.groups_config[group]['files']
dps_keys = []
if 'dps_keys' in self.groups_config[group]:
dps_keys = self.groups_config[group]['dps_keys']
if isinstance(dps_keys, str):
dps_keys = [dps_keys]

# Silencing SFT's logger if our logging is in DEBUG mode, because it
# typically produces a lot of outputs!
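
Note: dps_keys are now declared per streamline group in the json config file rather than on the command line. A hedged sketch of such an entry, as loaded into groups_config (group and key names hypothetical; only the 'files' and 'dps_keys' fields appear in the code above):

    groups_config = {
        'streamlines': {
            'type': 'streamlines',   # other fields as in existing config files
            'files': ['tractograms/bundle1.trk', 'tractograms/bundle2.trk'],
            'dps_keys': ['seeds', 'scores'],   # a single str is also accepted, per the code above
        }
    }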
@@ -690,7 +692,7 @@ def _process_one_streamline_group(
tractograms = format_filelist(tractograms, self.enforce_files_presence,
folder=subj_dir)
for tractogram_file in tractograms:
sft = self._load_and_process_sft(tractogram_file, header)
sft = self._load_and_process_sft(tractogram_file, header, dps_keys)

if sft is not None:
# Compute euclidean lengths (rasmm space)
@@ -750,9 +752,9 @@ def _process_one_streamline_group(
conn_matrix = np.load(conn_file)
conn_matrix = conn_matrix > 0

return final_sft, output_lengths, conn_matrix, conn_info
return final_sft, output_lengths, conn_matrix, conn_info, dps_keys

def _load_and_process_sft(self, tractogram_file, header):
def _load_and_process_sft(self, tractogram_file, header, dps_keys):
# Check file extension
_, file_extension = os.path.splitext(str(tractogram_file))
if file_extension not in ['.trk', '.tck']:
@@ -773,6 +775,19 @@ def _load_and_process_sft(self, tractogram_file, header):
.format(os.path.basename(tractogram_file)))
sft = load_tractogram(str(tractogram_file), header)

# Check for required dps_keys
for dps_key in dps_keys:
if dps_key not in sft.data_per_streamline.keys():
raise ValueError("DPS key {} is not present in file {}. Only "
"found the following keys: {}"
.format(dps_key, tractogram_file,
list(sft.data_per_streamline.keys())))

# Remove non-required dps_keys
for dps_key in list(sft.data_per_streamline.keys()):
if dps_key not in dps_keys:
del sft.data_per_streamline[dps_key]

# Resample or compress streamlines
sft = resample_or_compress(sft, self.step_size,
self.nb_points,
3 changes: 0 additions & 3 deletions dwi_ml/data/hdf5/utils.py
@@ -78,9 +78,6 @@ def add_hdf5_creation_args(p: ArgumentParser):
"(Final concatenated standardized volumes and \n"
"final concatenated resampled/compressed "
"streamlines.)")
p.add_argument('--dps_keys', type=str, nargs='+', default=[],
help="List of keys to keep in data_per_streamline. "
"Default: Empty.")


def add_streamline_processing_args(p: ArgumentParser):
Expand Down
19 changes: 19 additions & 0 deletions dwi_ml/data/processing/streamlines/data_augmentation.py
@@ -81,6 +81,12 @@ def split_streamlines(sft: StatefulTractogram, rng: np.random.RandomState,
for i in range(len(sft.streamlines)):
old_streamline = sft.streamlines[i]
old_dpp = sft.data_per_point[i]

# Note: This getter gets lists of numpy arrays of shape
# (n_features, n_streamlines) for some reason. This is why we need to
# transpose the all_dps arrays at the end. Not sure why this is
# happening as the arrays are stored within the PerArrayDict as
# numpy arrays of shape (n_streamlines, n_features).
old_dps = sft.data_per_streamline[i]

# Cut if at least min_nb_points
@@ -103,6 +109,19 @@ def split_streamlines(sft: StatefulTractogram, rng: np.random.RandomState,
all_dpp = _extend_dict(all_dpp, old_dpp)
all_dps = _extend_dict(all_dps, old_dps)

# Since _extend_dict appends many numpy arrays into a list,
# we need to merge them into a single array that can be fed
# to StatefulTractogram as data_per_streamline.
#
# Note: at this point, all_dps is a dict of lists of numpy arrays
# of shape (n_features, n_streamlines). The StatefulTractogram
# expects a dict of numpy arrays of shape (n_streamlines, n_features).
# We need to concat along the second axis and transpose to get the
# correct shape.

for key in sft.data_per_streamline.keys():
all_dps[key] = np.concatenate(all_dps[key], axis=1).transpose()

new_sft = StatefulTractogram.from_sft(all_streamlines, sft,
data_per_point=all_dpp,
data_per_streamline=all_dps)
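
Note: a quick numpy sketch of the shape bookkeeping described in the comments above (values hypothetical). Each accumulated chunk has shape (n_features, n_split_streamlines); concatenating on axis 1 and transposing yields the (n_streamlines, n_features) layout that StatefulTractogram expects.

    import numpy as np

    # Two parents with 2 dps features; the second was split into two pieces.
    chunks = [np.array([[1.], [10.]]), np.array([[2., 2.], [20., 20.]])]
    merged = np.concatenate(chunks, axis=1).transpose()
    print(merged.shape)   # (3, 2): one row per output streamline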
