diff --git a/dwi_ml/data/dataset/streamline_containers.py b/dwi_ml/data/dataset/streamline_containers.py index d8c370ac..c72a89bc 100644 --- a/dwi_ml/data/dataset/streamline_containers.py +++ b/dwi_ml/data/dataset/streamline_containers.py @@ -8,9 +8,10 @@ import h5py from nibabel.streamlines import ArraySequence import numpy as np +from collections import defaultdict -def _load_space_attributes_from_hdf(hdf_group: h5py.Group): +def _load_streamlines_attributes_from_hdf(hdf_group: h5py.Group): a = np.array(hdf_group.attrs['affine']) d = np.array(hdf_group.attrs['dimensions']) vs = np.array(hdf_group.attrs['voxel_sizes']) @@ -43,7 +44,13 @@ def _load_all_streamlines_from_hdf(hdf_group: h5py.Group): streamlines._offsets = np.array(hdf_group['offsets']) streamlines._lengths = np.array(hdf_group['lengths']) - return streamlines + # DPS + hdf_dps_group = hdf_group['data_per_streamline'] + dps_dict = {} + for dps_key in hdf_dps_group.keys(): + dps_dict[dps_key] = hdf_dps_group[dps_key][:] + + return streamlines, dps_dict def _load_connectivity_info(hdf_group: h5py.Group): @@ -79,40 +86,85 @@ def _get_one_streamline(self, idx: int): return data + def _assert_dps(self, dps_dict, n_streamlines): + for key, value in dps_dict.items(): + if len(value) != n_streamlines: + raise ValueError( + f"Length of data_per_streamline {key} is {len(value)} " + f"but should be {n_streamlines}.") + elif not isinstance(value, np.ndarray): + raise ValueError( + f"Data_per_streamline {key} should be a numpy array, " + f"not a {type(value)}.") + def get_array_sequence(self, item=None): if item is None: - streamlines = _load_all_streamlines_from_hdf(self.hdf_group) + streamlines, data_per_streamline = _load_all_streamlines_from_hdf( + self.hdf_group) else: streamlines = ArraySequence() + data_per_streamline = defaultdict(list) + + # If data_per_streamline is not in the hdf5, use an empty dict + # so that we don't add anything to the data_per_streamline in the + # following steps. + hdf_dps_group = self.hdf_group['data_per_streamline'] if \ + 'data_per_streamline' in self.hdf_group.keys() else {} if isinstance(item, int): - streamlines.append(self._get_one_streamline(item)) + data = self._get_one_streamline(item) + streamlines.append(data) + + for dps_key in hdf_dps_group.keys(): + data_per_streamline[dps_key].append( + hdf_dps_group[dps_key][item]) elif isinstance(item, list) or isinstance(item, np.ndarray): - # Getting a list of value from a hdf5: slow. Uses fancy indexing. - # But possible. See here: + # Getting a list of value from a hdf5: slow. Uses fancy + # indexing. But possible. See here: # https://stackoverflow.com/questions/21766145/h5py-correct-way-to-slice-array-datasets # Looping and accessing ourselves. # Good also load the whole data and access the indexes after. # toDo Test speed for the three options. 
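[Illustrative aside, not part of the patch] The per-index path of get_array_sequence above boils down to the short sketch below. It assumes the HDF5 layout this PR writes ('data', 'offsets' and 'lengths' datasets plus a 'data_per_streamline' subgroup of 2D arrays); the function name and signature are invented for the example:

    import h5py
    import numpy as np

    def gather_streamlines_and_dps(hdf_group: h5py.Group, indices):
        """Sketch only: similar in spirit to the int/list branches above."""
        dps_group = (hdf_group['data_per_streamline']
                     if 'data_per_streamline' in hdf_group else {})
        streamlines = []
        dps = {key: [] for key in dps_group}
        for i in indices:
            offset = hdf_group['offsets'][i]
            length = hdf_group['lengths'][i]
            streamlines.append(np.asarray(hdf_group['data'][offset:offset + length]))
            for key in dps:
                # Index with a one-element list ([[i]]) to keep a 2D
                # (1, n_features) shape, so the chunks concatenate row-wise.
                dps[key].append(dps_group[key][[i]])
        # One (n_streamlines, n_features) array per key, as StatefulTractogram expects.
        dps = {key: np.concatenate(chunks) for key, chunks in dps.items()}
        return streamlines, dps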
for i in item: - streamlines.append(self._get_one_streamline(i), - cache_build=True) + data = self._get_one_streamline(i) + streamlines.append(data, cache_build=True) + + for dps_key in hdf_dps_group.keys(): + data_per_streamline[dps_key].append( + hdf_dps_group[dps_key][item]) + streamlines.finalize_append() elif isinstance(item, slice): offsets = self.hdf_group['offsets'][item] lengths = self.hdf_group['lengths'][item] - for offset, length in zip(offsets, lengths): + indices = np.arange(item.start, item.stop, item.step) + for offset, length, idx in zip(offsets, lengths, indices): streamline = self.hdf_group['data'][offset:offset + length] streamlines.append(streamline, cache_build=True) + + for dps_key in hdf_dps_group.keys(): + # Indexing with a list (e.g. [idx]) will preserve the + # shape of the array. Crucial for concatenation below. + dps_data = hdf_dps_group[dps_key][[idx]] + data_per_streamline[dps_key].append(dps_data) streamlines.finalize_append() else: raise ValueError('Item should be either a int, list, ' 'np.ndarray or slice but we received {}' .format(type(item))) - return streamlines + + # The accumulated data_per_streamline is a list of numpy arrays. + # We need to merge them into a single numpy array so it can be + # reused in the StatefulTractogram. + for key in data_per_streamline.keys(): + data_per_streamline[key] = \ + np.concatenate(data_per_streamline[key]) + + self._assert_dps(data_per_streamline, len(streamlines)) + return streamlines, data_per_streamline @property def lengths(self): @@ -160,6 +212,7 @@ class SFTDataAbstract(object): all information necessary to treat with streamlines: the data itself and _offset, _lengths, space attributes, etc. """ + def __init__(self, space_attributes: Tuple, space: Space, origin: Origin, contains_connectivity: bool, connectivity_nb_blocs: List = None, @@ -256,10 +309,11 @@ def as_sft(self, streamline_ids: Union[List[int], int, slice, None] List of chosen ids. If None, use all streamlines. """ - streamlines = self._get_streamlines_as_list(streamline_ids) + streamlines, dps = self._get_streamlines_as_list(streamline_ids) sft = StatefulTractogram(streamlines, self.space_attributes, - self.space, self.origin) + self.space, self.origin, + data_per_streamline=dps) return sft @@ -267,6 +321,7 @@ def as_sft(self, class SFTData(SFTDataAbstract): def __init__(self, streamlines: ArraySequence, lengths_mm: List, connectivity_matrix: np.ndarray, + data_per_streamline: np.ndarray = None, **kwargs): """ streamlines: ArraySequence or LazyStreamlinesGetter @@ -279,6 +334,7 @@ def __init__(self, streamlines: ArraySequence, self._lengths_mm = lengths_mm self._connectivity_matrix = connectivity_matrix self.is_lazy = False + self.data_per_streamline = data_per_streamline def __len__(self): return len(self.streamlines) @@ -306,7 +362,7 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): Creating class instance from the hdf in cases where data is not loaded yet. Non-lazy = loading the data here. 
""" - streamlines = _load_all_streamlines_from_hdf(hdf_group) + streamlines, dps_dict = _load_all_streamlines_from_hdf(hdf_group) # Adding non-hidden parameters for nicer later access lengths_mm = hdf_group['euclidean_lengths'] @@ -318,7 +374,8 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): else: connectivity_matrix = None - space_attributes, space, origin = _load_space_attributes_from_hdf(hdf_group) + space_attributes, space, origin = _load_streamlines_attributes_from_hdf( + hdf_group) # Return an instance of SubjectMRIData instantiated through __init__ # with this loaded data: @@ -328,13 +385,18 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): space=space, origin=origin, contains_connectivity=contains_connectivity, connectivity_nb_blocs=connectivity_nb_blocs, - connectivity_labels=connectivity_labels) + connectivity_labels=connectivity_labels, + data_per_streamline=dps_dict) def _get_streamlines_as_list(self, streamline_ids): if streamline_ids is not None: - return self.streamlines.__getitem__(streamline_ids) + dps_indexed = {} + for key, value in self.data_per_streamline.items(): + dps_indexed[key] = value[streamline_ids] + + return self.streamlines.__getitem__(streamline_ids), dps_indexed else: - return self.streamlines + return self.streamlines, self.data_per_streamline class LazySFTData(SFTDataAbstract): @@ -368,7 +430,7 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None): @classmethod def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): - space_attributes, space, origin = _load_space_attributes_from_hdf( + space_attributes, space, origin = _load_streamlines_attributes_from_hdf( hdf_group) contains_connectivity, connectivity_nb_blocs, connectivity_labels = \ @@ -384,6 +446,6 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): connectivity_labels=connectivity_labels) def _get_streamlines_as_list(self, streamline_ids): - streamlines = self.streamlines_getter.get_array_sequence( + streamlines, dps = self.streamlines_getter.get_array_sequence( streamline_ids) - return streamlines + return streamlines, dps diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index cd789e42..006b16fe 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -136,7 +136,6 @@ class HDF5Creator: def __init__(self, root_folder: Path, out_hdf_filename: Path, training_subjs: List[str], validation_subjs: List[str], testing_subjs: List[str], groups_config: dict, - dps_keys: List[str] = [], step_size: float = None, nb_points: int = None, compress: float = None, @@ -160,8 +159,6 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path, List of subject names for each data set. groups_config: dict Information from json file loaded as a dict. - dps_keys: List[str] - List of keys to keep in data_per_streamline. Default: []. step_size: float Step size to resample streamlines. Default: None. nb_points: int @@ -186,7 +183,6 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path, self.validation_subjs = validation_subjs self.testing_subjs = testing_subjs self.groups_config = groups_config - self.dps_keys = dps_keys self.step_size = step_size self.nb_points = nb_points self.compress = compress @@ -596,7 +592,7 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, "in the config_file. If all files are .trk, we can use " "ref 'same' but if some files were .tck, we need a ref!" 
"Hint: Create a volume group 'ref' in the config file.") - sft, lengths, connectivity_matrix, conn_info = ( + sft, lengths, connectivity_matrix, conn_info, dps_keys = ( self._process_one_streamline_group( subj_input_dir, group, subj_id, ref)) @@ -614,6 +610,8 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, streamlines_group.attrs['dimensions'] = d streamlines_group.attrs['voxel_sizes'] = vs streamlines_group.attrs['voxel_order'] = vo + + # This streamline's group connectivity info if connectivity_matrix is not None: streamlines_group.attrs[ 'connectivity_matrix_type'] = conn_info[0] @@ -626,18 +624,17 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, streamlines_group.attrs['connectivity_nb_blocs'] = \ conn_info[1] + # DPP not managed yet! if len(sft.data_per_point) > 0: logging.debug('sft contained data_per_point. Data not kept.') + logging.debug(" Including dps \"{}\" in the HDF5." + .format(dps_keys)) - for dps_key in self.dps_keys: - if dps_key not in sft.data_per_streamline: - raise ValueError( - "The data_per_streamline key '{}' was not found in " - "the sft. Check your tractogram file.".format(dps_key)) - - logging.debug(" Include dps \"{}\" in the HDF5.".format(dps_key)) - streamlines_group.create_dataset('dps_' + dps_key, - data=sft.data_per_streamline[dps_key]) + # This streamline's group dps info + dps_group = streamlines_group.create_group('data_per_streamline') + for dps_key in dps_keys: + dps_group.create_dataset( + dps_key, data=sft.data_per_streamline[dps_key]) # Accessing private Dipy values, but necessary. # We need to deconstruct the streamlines into arrays with @@ -678,6 +675,11 @@ def _process_one_streamline_group( The Euclidean length of each streamline """ tractograms = self.groups_config[group]['files'] + dps_keys = [] + if 'dps_keys' in self.groups_config[group]: + dps_keys = self.groups_config[group]['dps_keys'] + if isinstance(dps_keys, str): + dps_keys = [dps_keys] # Silencing SFT's logger if our logging is in DEBUG mode, because it # typically produces a lot of outputs! @@ -690,7 +692,7 @@ def _process_one_streamline_group( tractograms = format_filelist(tractograms, self.enforce_files_presence, folder=subj_dir) for tractogram_file in tractograms: - sft = self._load_and_process_sft(tractogram_file, header) + sft = self._load_and_process_sft(tractogram_file, header, dps_keys) if sft is not None: # Compute euclidean lengths (rasmm space) @@ -750,9 +752,9 @@ def _process_one_streamline_group( conn_matrix = np.load(conn_file) conn_matrix = conn_matrix > 0 - return final_sft, output_lengths, conn_matrix, conn_info + return final_sft, output_lengths, conn_matrix, conn_info, dps_keys - def _load_and_process_sft(self, tractogram_file, header): + def _load_and_process_sft(self, tractogram_file, header, dps_keys): # Check file extension _, file_extension = os.path.splitext(str(tractogram_file)) if file_extension not in ['.trk', '.tck']: @@ -773,6 +775,19 @@ def _load_and_process_sft(self, tractogram_file, header): .format(os.path.basename(tractogram_file))) sft = load_tractogram(str(tractogram_file), header) + # Check for required dps_keys + for dps_key in dps_keys: + if dps_key not in sft.data_per_streamline.keys(): + raise ValueError("DPS key {} is not present in file {}. 
Only " + "found the following keys: {}" + .format(dps_key, tractogram_file, + list(sft.data_per_streamline.keys()))) + + # Remove non-required dps_keys + for dps_key in list(sft.data_per_streamline.keys()): + if dps_key not in dps_keys: + del sft.data_per_streamline[dps_key] + # Resample or compress streamlines sft = resample_or_compress(sft, self.step_size, self.nb_points, diff --git a/dwi_ml/data/hdf5/utils.py b/dwi_ml/data/hdf5/utils.py index 4d0aca12..d8a6d990 100644 --- a/dwi_ml/data/hdf5/utils.py +++ b/dwi_ml/data/hdf5/utils.py @@ -78,9 +78,6 @@ def add_hdf5_creation_args(p: ArgumentParser): "(Final concatenated standardized volumes and \n" "final concatenated resampled/compressed " "streamlines.)") - p.add_argument('--dps_keys', type=str, nargs='+', default=[], - help="List of keys to keep in data_per_streamline. " - "Default: Empty.") def add_streamline_processing_args(p: ArgumentParser): diff --git a/dwi_ml/data/processing/streamlines/data_augmentation.py b/dwi_ml/data/processing/streamlines/data_augmentation.py index cd46a185..3e0e828a 100644 --- a/dwi_ml/data/processing/streamlines/data_augmentation.py +++ b/dwi_ml/data/processing/streamlines/data_augmentation.py @@ -81,6 +81,12 @@ def split_streamlines(sft: StatefulTractogram, rng: np.random.RandomState, for i in range(len(sft.streamlines)): old_streamline = sft.streamlines[i] old_dpp = sft.data_per_point[i] + + # Note: This getter gets lists of numpy arrays of shape + # (n_features, n_streamlines) for some reason. This is why we need to + # transpose the all_dps arrays at the end. Not sure why this is + # happening as the arrays are stored within the PerArrayDict as + # numpy arrays of shape (n_streamlines, n_features). old_dps = sft.data_per_streamline[i] # Cut if at least min_nb_points @@ -103,6 +109,19 @@ def split_streamlines(sft: StatefulTractogram, rng: np.random.RandomState, all_dpp = _extend_dict(all_dpp, old_dpp) all_dps = _extend_dict(all_dps, old_dps) + # Since _extend_dict appends many numpy arrays into a list, + # we need to merge them into a single array that can be fed + # to StatefulTractogram at data_per_streamlines. + # + # Note: at this point, all_dps is a dict of lists of numpy arrays + # of shape (n_features, n_streamlines). The StatefulTractogram + # expects a dict of numpy arrays of shape (n_streamlines, n_features). + # We need to concat along the second axis and transpose to get the + # correct shape. 
+ + for key in sft.data_per_streamline.keys(): + all_dps[key] = np.concatenate(all_dps[key], axis=1).transpose() + new_sft = StatefulTractogram.from_sft(all_streamlines, sft, data_per_point=all_dpp, data_per_streamline=all_dps) diff --git a/dwi_ml/unit_tests/test_dataset.py b/dwi_ml/unit_tests/test_dataset.py index 6cf53313..1fab9fb3 100755 --- a/dwi_ml/unit_tests/test_dataset.py +++ b/dwi_ml/unit_tests/test_dataset.py @@ -5,6 +5,7 @@ import h5py import torch +import numpy as np from dipy.io.stateful_tractogram import StatefulTractogram from dwi_ml.data.dataset.multi_subject_containers import \ @@ -21,13 +22,15 @@ TEST_EXPECTED_VOLUME_GROUPS, TEST_EXPECTED_NB_STREAMLINES, TEST_EXPECTED_MRI_SHAPE, TEST_EXPECTED_NB_SUBJECTS, TEST_EXPECTED_NB_FEATURES) -from dwi_ml.unit_tests.utils.data_and_models_for_tests import fetch_testing_data +from dwi_ml.unit_tests.utils.data_and_models_for_tests import \ + fetch_testing_data +dps_key_1 = 'mean_color_dps' +dps_key_2 = 'mock_2d_dps' -def test_multisubjectdataset(): - data_dir = fetch_testing_data() - logging.debug("Unit test: previous dirs") +def test_multisubjectdataset(script_runner): + data_dir = fetch_testing_data() hdf5_filename = os.path.join(data_dir, 'hdf5_file.hdf5') @@ -88,6 +91,7 @@ def _verify_mri(mri_data, training_set, group_number): def _verify_sft_data(sft_data, group_number): expected_nb = TEST_EXPECTED_NB_STREAMLINES[group_number] assert len(sft_data.as_sft()) == expected_nb + expected_mock_2d_dps = np.random.RandomState(42).rand(expected_nb, 42) # First streamline's first coordinate: # Also verifying accessing by index @@ -96,10 +100,28 @@ def _verify_sft_data(sft_data, group_number): assert len(list_one) == 1 assert len(list_one.streamlines[0][0, :]) == 3 # a x, y, z coordinate + # Both dps should be in the data_per_streamline + # of the sft. Also making sure that the data is + # the same as expected. + assert dps_key_1 in list_one.data_per_streamline.keys() + assert dps_key_2 in list_one.data_per_streamline.keys() + assert np.allclose( + list_one.data_per_streamline[dps_key_2][0], + expected_mock_2d_dps[0]) + # Assessing by slice list_4 = sft_data.as_sft(slice(0, 4)) assert len(list_4) == 4 + # Same as above, but with slices. Both dps + # should be in the data_per_streamline and + # the data should be the same as expected. 
+ assert dps_key_1 in list_4.data_per_streamline.keys() + assert dps_key_2 in list_4.data_per_streamline.keys() + assert np.allclose( + list_4.data_per_streamline[dps_key_2][0:4], + expected_mock_2d_dps[0:4]) + def _non_lazy_version(hdf5_filename): logging.debug("-------------- NON-LAZY version -----------------") diff --git a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py index f1bcf6c0..eb15b5b6 100644 --- a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py +++ b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py @@ -31,7 +31,7 @@ def fetch_testing_data(): # Access to the file dwi_ml.zip: # https://drive.google.com/uc?id=1beRWAorhaINCncttgwqVAP2rNOfx842Q name_as_dict = { - 'data_for_tests_dwi_ml.zip': "59c9275d2fe83b7e2d6154877ab32b8b"} + 'data_for_tests_dwi_ml.zip': "f8bd3bd88e10d939a7168468e1e99a00"} fetch_data(name_as_dict) return testing_data_dir diff --git a/scripts_python/dwiml_create_hdf5_dataset.py b/scripts_python/dwiml_create_hdf5_dataset.py index 2719449c..03294f8b 100644 --- a/scripts_python/dwiml_create_hdf5_dataset.py +++ b/scripts_python/dwiml_create_hdf5_dataset.py @@ -88,7 +88,6 @@ def prepare_hdf5_creator(args): creator = HDF5Creator(Path(args.dwi_ml_ready_folder), args.out_hdf5_file, training_subjs, validation_subjs, testing_subjs, groups_config, - args.dps_keys, args.step_size, args.nb_points, args.compress_th, diff --git a/scripts_python/tests/test_create_hdf5_dataset.py b/scripts_python/tests/test_create_hdf5_dataset.py index fd2ade91..2106767a 100644 --- a/scripts_python/tests/test_create_hdf5_dataset.py +++ b/scripts_python/tests/test_create_hdf5_dataset.py @@ -4,18 +4,38 @@ import os import tempfile -from dwi_ml.unit_tests.utils.data_and_models_for_tests import fetch_testing_data +from dwi_ml.unit_tests.utils.data_and_models_for_tests import \ + fetch_testing_data data_dir = fetch_testing_data() tmp_dir = tempfile.TemporaryDirectory() +# Note. Our test config file is: +# { +# "input": { +# "type": "volume", +# "files": ["anat/t1.nii.gz", "dwi/fa.nii.gz"], +# "standardization": "per_file", +# "std_mask": ["masks/wm.nii.gz"] +# }, +# "wm_mask": { +# "type": "volume", +# "files": ["masks/wm.nii.gz"], +# "standardization": "none" +# }, +# "streamlines": { +# "type": "streamlines", +# "files": ["example_bundle/Fornix.trk"], +# "dps_keys": ['mean_color_dps', 'mock_2d_dps'] +# } +# } def test_help_option(script_runner): ret = script_runner.run('dwiml_create_hdf5_dataset.py', '--help') assert ret.success -def test_execution_bst(script_runner): +def test_execution(script_runner): os.chdir(os.path.expanduser(tmp_dir.name)) dwi_ml_folder = os.path.join(data_dir, 'dwi_ml_ready') diff --git a/source/2_A_creating_the_hdf5.rst b/source/2_A_creating_the_hdf5.rst index e7a96156..0d4e704a 100644 --- a/source/2_A_creating_the_hdf5.rst +++ b/source/2_A_creating_the_hdf5.rst @@ -83,7 +83,8 @@ To create the hdf5 file, you will need a config file such as below. 
HDF groups w
         "files": ["tractograms/bundle1.trk", "tractograms/wholebrain.trk", "tractograms/*__wholebrain.trk"], ----> Will get, for instance, sub1000__bundle1.trk
         "connectivity_matrix": "my_file.npy",
         "connectivity_nb_blocs": 6 ---> OR
-        "connectivity_labels": labels_volume_group
+        "connectivity_labels": labels_volume_group,
+        "dps_keys": ["dps1", "dps2"]
     }
     "bad_streamlines": {
         "type": "streamlines",
@@ -138,6 +139,8 @@ Additional attributes for streamlines groups:
 
 - **connectivity_nb_blocs**: This explains that the connectivity matrix was created by dividing the volume space into regular blocs. See dwiml_compute_connectivity_matrix_from_blocs for a description. The value should be either an integer or a list of three integers.
 - **connectivity_labels**: This explains that the connectivity matrix was created by dividing the cortex into a list of regions associated with labels. The value must be the name of the associated labels file (typically a nifti file filled with integers).
+ - **dps_keys**: List of keys from the tractograms' data_per_streamline to keep in the hdf5. Each listed key must be present in every tractogram file of the group.
+
 
 2.4. Creating the hdf5
 **********************
diff --git a/source/2_B_advanced_hdf5_organization.rst b/source/2_B_advanced_hdf5_organization.rst
index 56a8dbbc..3793a2a4 100644
--- a/source/2_B_advanced_hdf5_organization.rst
+++ b/source/2_B_advanced_hdf5_organization.rst
@@ -31,8 +31,9 @@ Here is the output format created by dwiml_create_hdf5_dataset.py and recognized
         # (others:)
         hdf5['subj1']['group1']['connectivity_matrix']
         hdf5['subj1']['group1']['connectivity_matrix_type'] = 'from_blocs' or 'from_labels'
-        hdf5['subj1']['group1']['connectivity_label_volume'] (the labels' volume group) OR
+        hdf5['subj1']['group1']['connectivity_label_volume'] (the labels\' volume group) OR
         hdf5['subj1']['group1']['connectivity_nb_blocs'] (a list of three integers)
+        hdf5['subj1']['group1']['data_per_streamline'] (an HDF5 group of 2D numpy arrays)
 
     # For volumes, other available data:
     hdf5['sub1']['group1']['affine']
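[Illustrative aside, not part of the patch] Once the hdf5 has been created, the new group can be inspected directly with h5py. The file, subject and group names below follow the naming used in this documentation page and are placeholders::

    import h5py

    with h5py.File('hdf5_file.hdf5', 'r') as hdf:
        dps_group = hdf['subj1']['group1']['data_per_streamline']
        for key in dps_group:
            # One 2D dataset per requested dps key: (n_streamlines, n_features).
            print(key, dps_group[key].shape)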