Dps management #248

Merged: 11 commits, Nov 7, 2024
102 changes: 82 additions & 20 deletions dwi_ml/data/dataset/streamline_containers.py
@@ -8,9 +8,10 @@
import h5py
from nibabel.streamlines import ArraySequence
import numpy as np
from collections import defaultdict


def _load_space_attributes_from_hdf(hdf_group: h5py.Group):
def _load_streamlines_attributes_from_hdf(hdf_group: h5py.Group):
a = np.array(hdf_group.attrs['affine'])
d = np.array(hdf_group.attrs['dimensions'])
vs = np.array(hdf_group.attrs['voxel_sizes'])
@@ -43,7 +44,13 @@ def _load_all_streamlines_from_hdf(hdf_group: h5py.Group):
streamlines._offsets = np.array(hdf_group['offsets'])
streamlines._lengths = np.array(hdf_group['lengths'])

return streamlines
# DPS
hdf_dps_group = hdf_group['data_per_streamline']
dps_dict = {}
for dps_key in hdf_dps_group.keys():
dps_dict[dps_key] = hdf_dps_group[dps_key][:]

return streamlines, dps_dict
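For context, a minimal sketch of the HDF5 layout this loader expects. The file name, group path and dps key below are hypothetical; the dataset names ('data', 'offsets', 'lengths') and the 'data_per_streamline' sub-group match the code above.

import h5py
import numpy as np

# Hypothetical example: two streamlines of 10 and 15 points each.
with h5py.File('example.hdf5', 'w') as f:
    grp = f.create_group('subj1/streamlines_group')         # path: hypothetical
    grp.create_dataset('data', data=np.random.rand(25, 3))  # all points, concatenated
    grp.create_dataset('offsets', data=np.array([0, 10]))   # start index of each streamline
    grp.create_dataset('lengths', data=np.array([10, 15]))  # nb of points per streamline
    dps = grp.create_group('data_per_streamline')
    dps.create_dataset('seed_id', data=np.arange(2))        # hypothetical key; one value per streamline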


def _load_connectivity_info(hdf_group: h5py.Group):
@@ -79,40 +86,85 @@ def _get_one_streamline(self, idx: int):

return data

def _assert_dps(self, dps_dict, n_streamlines):
for key, value in dps_dict.items():
# Check the type first: len() is only meaningful on an array.
if not isinstance(value, np.ndarray):
raise ValueError(
f"data_per_streamline '{key}' should be a numpy array, "
f"not a {type(value)}.")
elif len(value) != n_streamlines:
raise ValueError(
f"Length of data_per_streamline '{key}' is {len(value)} "
f"but should be {n_streamlines}.")

def get_array_sequence(self, item=None):
if item is None:
streamlines = _load_all_streamlines_from_hdf(self.hdf_group)
streamlines, data_per_streamline = _load_all_streamlines_from_hdf(
self.hdf_group)
else:
streamlines = ArraySequence()
data_per_streamline = defaultdict(list)

# If data_per_streamline is not in the hdf5, use an empty dict
# so that we don't add anything to the data_per_streamline in the
# following steps.
hdf_dps_group = self.hdf_group['data_per_streamline'] if \
'data_per_streamline' in self.hdf_group.keys() else {}

if isinstance(item, int):
streamlines.append(self._get_one_streamline(item))
data = self._get_one_streamline(item)
streamlines.append(data)

for dps_key in hdf_dps_group.keys():
data_per_streamline[dps_key].append(
hdf_dps_group[dps_key][item])

elif isinstance(item, list) or isinstance(item, np.ndarray):
# Getting a list of value from a hdf5: slow. Uses fancy indexing.
# But possible. See here:
# Getting a list of value from a hdf5: slow. Uses fancy
# indexing. But possible. See here:
# https://stackoverflow.com/questions/21766145/h5py-correct-way-to-slice-array-datasets
# Here, we loop and access the values ourselves.
# Another option would be to load the whole dataset and index it after.
# TODO: Test speed for the three options.
for i in item:
streamlines.append(self._get_one_streamline(i),
cache_build=True)
data = self._get_one_streamline(i)
streamlines.append(data, cache_build=True)

for dps_key in hdf_dps_group.keys():
data_per_streamline[dps_key].append(
hdf_dps_group[dps_key][item])

streamlines.finalize_append()

elif isinstance(item, slice):
offsets = self.hdf_group['offsets'][item]
lengths = self.hdf_group['lengths'][item]
for offset, length in zip(offsets, lengths):
# slice.indices() resolves None start/stop/step values safely.
indices = np.arange(*item.indices(len(self.hdf_group['lengths'])))
for offset, length, idx in zip(offsets, lengths, indices):
streamline = self.hdf_group['data'][offset:offset + length]
streamlines.append(streamline, cache_build=True)

for dps_key in hdf_dps_group.keys():
# Indexing with a list (e.g. [idx]) will preserve the
# shape of the array. Crucial for concatenation below.
dps_data = hdf_dps_group[dps_key][[idx]]
data_per_streamline[dps_key].append(dps_data)
streamlines.finalize_append()

else:
raise ValueError('Item should be an int, a list, an '
'np.ndarray or a slice, but we received {}'
.format(type(item)))
return streamlines

# The accumulated data_per_streamline is a dict of lists of numpy
# arrays. We need to merge each list into a single numpy array so
# that it can be reused in the StatefulTractogram.
for key in data_per_streamline.keys():
data_per_streamline[key] = \
np.concatenate(data_per_streamline[key])

self._assert_dps(data_per_streamline, len(streamlines))
return streamlines, data_per_streamline
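A quick note on the list-indexing trick used in the slice branch above; a toy numpy check (values are illustrative):

import numpy as np

a = np.arange(6).reshape(3, 2)  # e.g. 3 streamlines, 2 dps features
a[1].shape                      # (2,): integer indexing drops the first axis
a[[1]].shape                    # (1, 2): list indexing keeps it, so np.concatenate works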

@property
def lengths(self):
@@ -160,6 +212,7 @@ class SFTDataAbstract(object):
all information necessary to work with streamlines: the data itself and
_offsets, _lengths, space attributes, etc.
"""

def __init__(self, space_attributes: Tuple, space: Space, origin: Origin,
contains_connectivity: bool,
connectivity_nb_blocs: List = None,
@@ -256,17 +309,19 @@ def as_sft(self,
streamline_ids: Union[List[int], int, slice, None]
List of chosen ids. If None, use all streamlines.
"""
streamlines = self._get_streamlines_as_list(streamline_ids)
streamlines, dps = self._get_streamlines_as_list(streamline_ids)

sft = StatefulTractogram(streamlines, self.space_attributes,
self.space, self.origin)
self.space, self.origin,
data_per_streamline=dps)

return sft
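A hypothetical usage sketch (the variable name and ids are illustrative), assuming an SFTData or LazySFTData instance has been loaded from the hdf5:

# The dps stored in the hdf5 now follows the selected streamlines.
sft = subj_sft_data.as_sft(streamline_ids=[0, 5, 12])
print(list(sft.data_per_streamline.keys()))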


class SFTData(SFTDataAbstract):
def __init__(self, streamlines: ArraySequence,
lengths_mm: List, connectivity_matrix: np.ndarray,
data_per_streamline: dict = None,
**kwargs):
"""
streamlines: ArraySequence or LazyStreamlinesGetter
@@ -279,6 +334,7 @@ def __init__(self, streamlines: ArraySequence,
self._lengths_mm = lengths_mm
self._connectivity_matrix = connectivity_matrix
self.is_lazy = False
self.data_per_streamline = data_per_streamline

def __len__(self):
return len(self.streamlines)
@@ -306,7 +362,7 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
Creating a class instance from the hdf5 group in cases where the data
is not loaded yet. Non-lazy version: the data is loaded here.
"""
streamlines = _load_all_streamlines_from_hdf(hdf_group)
streamlines, dps_dict = _load_all_streamlines_from_hdf(hdf_group)
# Adding non-hidden parameters for nicer later access
lengths_mm = hdf_group['euclidean_lengths']

@@ -318,7 +374,8 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
else:
connectivity_matrix = None

space_attributes, space, origin = _load_space_attributes_from_hdf(hdf_group)
space_attributes, space, origin = _load_streamlines_attributes_from_hdf(
hdf_group)

# Return an instance of SubjectMRIData instantiated through __init__
# with this loaded data:
@@ -328,13 +385,18 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
space=space, origin=origin,
contains_connectivity=contains_connectivity,
connectivity_nb_blocs=connectivity_nb_blocs,
connectivity_labels=connectivity_labels)
connectivity_labels=connectivity_labels,
data_per_streamline=dps_dict)

def _get_streamlines_as_list(self, streamline_ids):
if streamline_ids is not None:
return self.streamlines.__getitem__(streamline_ids)
dps_indexed = {}
for key, value in self.data_per_streamline.items():
dps_indexed[key] = value[streamline_ids]

return self.streamlines.__getitem__(streamline_ids), dps_indexed
else:
return self.streamlines
return self.streamlines, self.data_per_streamline


class LazySFTData(SFTDataAbstract):
@@ -368,7 +430,7 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None):

@classmethod
def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
space_attributes, space, origin = _load_space_attributes_from_hdf(
space_attributes, space, origin = _load_streamlines_attributes_from_hdf(
hdf_group)

contains_connectivity, connectivity_nb_blocs, connectivity_labels = \
@@ -384,6 +446,6 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
connectivity_labels=connectivity_labels)

def _get_streamlines_as_list(self, streamline_ids):
streamlines = self.streamlines_getter.get_array_sequence(
streamlines, dps = self.streamlines_getter.get_array_sequence(
streamline_ids)
return streamlines
return streamlines, dps
49 changes: 32 additions & 17 deletions dwi_ml/data/hdf5/hdf5_creation.py
@@ -136,7 +136,6 @@ class HDF5Creator:
def __init__(self, root_folder: Path, out_hdf_filename: Path,
training_subjs: List[str], validation_subjs: List[str],
testing_subjs: List[str], groups_config: dict,
dps_keys: List[str] = [],
step_size: float = None,
nb_points: int = None,
compress: float = None,
@@ -160,8 +159,6 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
List of subject names for each data set.
groups_config: dict
Information from json file loaded as a dict.
dps_keys: List[str]
List of keys to keep in data_per_streamline. Default: [].
step_size: float
Step size to resample streamlines. Default: None.
nb_points: int
@@ -186,7 +183,6 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
self.validation_subjs = validation_subjs
self.testing_subjs = testing_subjs
self.groups_config = groups_config
self.dps_keys = dps_keys
self.step_size = step_size
self.nb_points = nb_points
self.compress = compress
@@ -596,7 +592,7 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
"in the config_file. If all files are .trk, we can use "
"ref 'same' but if some files were .tck, we need a ref!"
"Hint: Create a volume group 'ref' in the config file.")
sft, lengths, connectivity_matrix, conn_info = (
sft, lengths, connectivity_matrix, conn_info, dps_keys = (
self._process_one_streamline_group(
subj_input_dir, group, subj_id, ref))

@@ -614,6 +610,8 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
streamlines_group.attrs['dimensions'] = d
streamlines_group.attrs['voxel_sizes'] = vs
streamlines_group.attrs['voxel_order'] = vo

# This streamline's group connectivity info
if connectivity_matrix is not None:
streamlines_group.attrs[
'connectivity_matrix_type'] = conn_info[0]
@@ -626,18 +624,17 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
streamlines_group.attrs['connectivity_nb_blocs'] = \
conn_info[1]

# DPP not managed yet!
if len(sft.data_per_point) > 0:
logging.debug('sft contained data_per_point. Data not kept.')
logging.debug(" Including dps \"{}\" in the HDF5."
.format(dps_keys))

for dps_key in self.dps_keys:
if dps_key not in sft.data_per_streamline:
raise ValueError(
"The data_per_streamline key '{}' was not found in "
"the sft. Check your tractogram file.".format(dps_key))

logging.debug(" Include dps \"{}\" in the HDF5.".format(dps_key))
streamlines_group.create_dataset('dps_' + dps_key,
data=sft.data_per_streamline[dps_key])
# This streamline's group dps info
dps_group = streamlines_group.create_group('data_per_streamline')
for dps_key in dps_keys:
dps_group.create_dataset(
dps_key, data=sft.data_per_streamline[dps_key])

# Accessing private Dipy values, but necessary.
# We need to deconstruct the streamlines into arrays with
@@ -678,6 +675,11 @@ def _process_one_streamline_group(
The Euclidean length of each streamline
"""
tractograms = self.groups_config[group]['files']
dps_keys = []
if 'dps_keys' in self.groups_config[group]:
dps_keys = self.groups_config[group]['dps_keys']
if isinstance(dps_keys, str):
dps_keys = [dps_keys]
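With this change, dps keys are declared per streamline group in the config file rather than through the removed --dps_keys argument. A hypothetical entry in groups_config (the group name, file and key are illustrative; only 'files' and 'dps_keys' appear in the code above, other fields follow the usual config format):

groups_config = {
    "streamlines": {
        "type": "streamlines",
        "files": ["tractograms/bundle.trk"],
        "dps_keys": ["seed_id"]   # a single str is also accepted, per the parsing above
    }
}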

# Silencing SFT's logger if our logging is in DEBUG mode, because it
# typically produces a lot of outputs!
@@ -690,7 +692,7 @@ def _process_one_streamline_group(
tractograms = format_filelist(tractograms, self.enforce_files_presence,
folder=subj_dir)
for tractogram_file in tractograms:
sft = self._load_and_process_sft(tractogram_file, header)
sft = self._load_and_process_sft(tractogram_file, header, dps_keys)

if sft is not None:
# Compute euclidean lengths (rasmm space)
@@ -750,9 +752,9 @@ def _process_one_streamline_group(
conn_matrix = np.load(conn_file)
conn_matrix = conn_matrix > 0

return final_sft, output_lengths, conn_matrix, conn_info
return final_sft, output_lengths, conn_matrix, conn_info, dps_keys

def _load_and_process_sft(self, tractogram_file, header):
def _load_and_process_sft(self, tractogram_file, header, dps_keys):
# Check file extension
_, file_extension = os.path.splitext(str(tractogram_file))
if file_extension not in ['.trk', '.tck']:
@@ -773,6 +775,19 @@ def _load_and_process_sft(self, tractogram_file, header):
.format(os.path.basename(tractogram_file)))
sft = load_tractogram(str(tractogram_file), header)

# Check for required dps_keys
for dps_key in dps_keys:
if dps_key not in sft.data_per_streamline.keys():
raise ValueError("DPS key {} is not present in file {}. Only "
"found the following keys: {}"
.format(dps_key, tractogram_file,
list(sft.data_per_streamline.keys())))

# Remove non-required dps_keys
for dps_key in list(sft.data_per_streamline.keys()):
if dps_key not in dps_keys:
del sft.data_per_streamline[dps_key]

# Resample or compress streamlines
sft = resample_or_compress(sft, self.step_size,
self.nb_points,
3 changes: 0 additions & 3 deletions dwi_ml/data/hdf5/utils.py
@@ -78,9 +78,6 @@ def add_hdf5_creation_args(p: ArgumentParser):
"(Final concatenated standardized volumes and \n"
"final concatenated resampled/compressed "
"streamlines.)")
p.add_argument('--dps_keys', type=str, nargs='+', default=[],
help="List of keys to keep in data_per_streamline. "
"Default: Empty.")


def add_streamline_processing_args(p: ArgumentParser):
19 changes: 19 additions & 0 deletions dwi_ml/data/processing/streamlines/data_augmentation.py
@@ -81,6 +81,12 @@ def split_streamlines(sft: StatefulTractogram, rng: np.random.RandomState,
for i in range(len(sft.streamlines)):
old_streamline = sft.streamlines[i]
old_dpp = sft.data_per_point[i]

# Note: this getter returns lists of numpy arrays of shape
# (n_features, n_streamlines), which is why we need to transpose the
# all_dps arrays at the end. The reason is unclear, since the arrays
# are stored within the PerArrayDict as numpy arrays of shape
# (n_streamlines, n_features).
old_dps = sft.data_per_streamline[i]

# Cut if at least min_nb_points
@@ -103,6 +109,19 @@
all_dpp = _extend_dict(all_dpp, old_dpp)
all_dps = _extend_dict(all_dps, old_dps)

# Since _extend_dict appends many numpy arrays into a list,
# we need to merge them into a single array that can be fed
# to StatefulTractogram as data_per_streamline.
#
# Note: at this point, all_dps is a dict of lists of numpy arrays
# of shape (n_features, n_streamlines), while StatefulTractogram
# expects a dict of numpy arrays of shape (n_streamlines, n_features).
# We concatenate along the second axis, then transpose, to get the
# correct shape.

for key in sft.data_per_streamline.keys():
all_dps[key] = np.concatenate(all_dps[key], axis=1).transpose()

new_sft = StatefulTractogram.from_sft(all_streamlines, sft,
data_per_point=all_dpp,
data_per_streamline=all_dps)
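A toy check of the shape handling described in the comment above (feature and streamline counts are illustrative):

import numpy as np

# Two splits, each contributing a dps array of shape (n_features, n_streamlines):
chunks = [np.zeros((2, 3)), np.zeros((2, 5))]
merged = np.concatenate(chunks, axis=1).transpose()
print(merged.shape)  # (8, 2), i.e. (n_streamlines, n_features), as StatefulTractogram expects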