From 550ae33905b08783cdf36fe66b9b8f0bbf26df2f Mon Sep 17 00:00:00 2001
From: arnaudbore
Date: Wed, 25 Sep 2024 12:09:22 -0400
Subject: [PATCH 1/6] add the possibility to resample with nb points per
 streamline

---
 dwi_ml/data/hdf5/hdf5_creation.py             | 43 +++++++++++--------
 dwi_ml/data/hdf5/utils.py                     | 13 +++---
 .../streamlines/data_augmentation.py          | 25 ++++++++---
 dwi_ml/io_utils.py                            | 11 +++--
 scripts_python/dwiml_create_hdf5_dataset.py   |  5 ++-
 5 files changed, 64 insertions(+), 33 deletions(-)

diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py
index 74715f1e..82414c52 100644
--- a/dwi_ml/data/hdf5/hdf5_creation.py
+++ b/dwi_ml/data/hdf5/hdf5_creation.py
@@ -136,7 +136,10 @@ class HDF5Creator:
     def __init__(self, root_folder: Path, out_hdf_filename: Path,
                  training_subjs: List[str], validation_subjs: List[str],
                  testing_subjs: List[str], groups_config: dict,
-                 step_size: float = None, compress: float = None,
+                 step_size: float = None,
+                 nb_points: int = None,
+                 compress: float = None,
+                 remove_invalid: bool = False,
                  enforce_files_presence: bool = True,
                  save_intermediate: bool = False,
                  intermediate_folder: Path = None):
@@ -156,8 +159,12 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
             Information from json file loaded as a dict.
         step_size: float
             Step size to resample streamlines. Default: None.
+        nb_points: int
+            Number of points per streamline. Default: None.
         compress: float
             Compress streamlines. Default: None.
+        remove_invalid: bool
+            Remove invalid streamlines. Default: False.
         enforce_files_presence: bool
             If true, will stop if some files are not available for a subject.
             Default: True.
@@ -175,7 +182,9 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
         self.testing_subjs = testing_subjs
         self.groups_config = groups_config
         self.step_size = step_size
+        self.nb_points = nb_points
         self.compress = compress
+        self.remove_invalid = remove_invalid
 
         # Optional
         self.save_intermediate = save_intermediate
@@ -359,6 +368,8 @@ def create_database(self):
             hdf_handle.attrs['testing_subjs'] = self.testing_subjs
             hdf_handle.attrs['step_size'] = self.step_size if \
                 self.step_size is not None else 'Not defined by user'
+            hdf_handle.attrs['nb_points'] = self.nb_points if \
+                self.nb_points is not None else 'Not defined by user'
             hdf_handle.attrs['compress'] = self.compress if \
                 self.compress is not None else 'Not defined by user'
 
@@ -632,6 +643,8 @@ def _process_one_streamline_group(
             Reference used to load and send the streamlines in voxel space and
             to create final merged SFT. If the file is a .trk, 'same' is used
             instead.
+        remove_invalid: bool
+            If True, invalid streamlines will be removed.
 
         Returns
         -------
@@ -641,11 +654,10 @@ def _process_one_streamline_group(
             The Euclidean length of each streamline
         """
         tractograms = self.groups_config[group]['files']
-
         if self.step_size and self.compress:
             raise ValueError(
                 "Only one option can be chosen: either resampling to "
-                "step_size or compressing, not both.")
+                "step_size, resampling to nb_points, or compressing.")
 
         # Silencing SFT's logger if our logging is in DEBUG mode, because it
         # typically produces a lot of outputs!
@@ -679,19 +691,12 @@ def _process_one_streamline_group(
         if self.save_intermediate:
             output_fname = self.intermediate_folder.joinpath(
                 subj_id + '_' + group + '.trk')
-            logging.debug('      *Saving intermediate streamline group {} '
-                          'into {}.'.format(group, output_fname))
+            logging.debug("      *Saving intermediate streamline group {} "
+                          "into {}.".format(group, output_fname))
             # Note. 
Do not remove the str below. Does not work well # with Path. save_tractogram(final_sft, str(output_fname)) - # Removing invalid streamlines - logging.debug(' *Total: {:,.0f} streamlines. Now removing ' - 'invalid streamlines.'.format(len(final_sft))) - final_sft.remove_invalid_streamlines() - logging.info(" Final number of streamlines: {:,.0f}." - .format(len(final_sft))) - conn_matrix = None conn_info = None if 'connectivity_matrix' in self.groups_config[group]: @@ -735,10 +740,11 @@ def _load_and_process_sft(self, tractogram_file, header): "We do not support file's type: {}. We only support .trk " "and .tck files.".format(tractogram_file)) if file_extension == '.trk': - if not is_header_compatible(str(tractogram_file), header): - raise ValueError("Streamlines group is not compatible with " - "volume groups\n ({})" - .format(tractogram_file)) + if header: + if not is_header_compatible(str(tractogram_file), header): + raise ValueError("Streamlines group is not compatible " + "with volume groups\n ({})" + .format(tractogram_file)) # overriding given header. header = 'same' @@ -748,6 +754,9 @@ def _load_and_process_sft(self, tractogram_file, header): sft = load_tractogram(str(tractogram_file), header) # Resample or compress streamlines - sft = resample_or_compress(sft, self.step_size, self.compress) + sft = resample_or_compress(sft, self.step_size, + self.nb_points, + self.compress, + self.remove_invalid) return sft diff --git a/dwi_ml/data/hdf5/utils.py b/dwi_ml/data/hdf5/utils.py index cbd7704b..d8a6d990 100644 --- a/dwi_ml/data/hdf5/utils.py +++ b/dwi_ml/data/hdf5/utils.py @@ -59,11 +59,11 @@ def add_hdf5_creation_args(p: ArgumentParser): help="A txt file containing the list of subjects ids to " "use for training. \n(Can be an empty file.)") p.add_argument('validation_subjs', - help="A txt file containing the list of subjects ids to use " - "for validation. \n(Can be an empty file.)") + help="A txt file containing the list of subjects ids" + " to use for validation. \n(Can be an empty file.)") p.add_argument('testing_subjs', - help="A txt file containing the list of subjects ids to use " - "for testing. \n(Can be an empty file.)") + help="A txt file containing the list of subjects ids" + " to use for testing. 
\n(Can be an empty file.)")
 
     # Optional arguments
     p.add_argument('--enforce_files_presence', type=bool, default=True,
@@ -76,9 +76,10 @@ def add_hdf5_creation_args(p: ArgumentParser):
                         "each subject inside the \nhdf5 folder, in sub-"
                         "folders named subjid_intermediate.\n"
                         "(Final concatenated standardized volumes and \n"
-                        "final concatenated resampled/compressed streamlines.)")
+                        "final concatenated resampled/compressed "
+                        "streamlines.)")
 
 
 def add_streamline_processing_args(p: ArgumentParser):
-    g = p.add_argument_group('Streamlines processing options:')
+    g = p.add_argument_group('Streamlines processing options')
     add_resample_or_compress_arg(g)
diff --git a/dwi_ml/data/processing/streamlines/data_augmentation.py b/dwi_ml/data/processing/streamlines/data_augmentation.py
index 48683def..18428cb9 100644
--- a/dwi_ml/data/processing/streamlines/data_augmentation.py
+++ b/dwi_ml/data/processing/streamlines/data_augmentation.py
@@ -9,18 +9,33 @@
 import numpy as np
 
 from scilpy.tractograms.streamline_operations import \
-    resample_streamlines_step_size, compress_sft
+    resample_streamlines_num_points, resample_streamlines_step_size, \
+    compress_sft
 
 
 def resample_or_compress(sft, step_size_mm: float = None,
-                         compress: float = None):
+                         nb_points: int = None,
+                         compress: float = None,
+                         remove_invalid: bool = False):
     if step_size_mm is not None:
         # Note. No matter the chosen space, resampling is done in mm.
-        logging.debug("      Resampling: {}".format(step_size_mm))
+        logging.debug("      Resampling (step size): {}mm".format(step_size_mm))
         sft = resample_streamlines_step_size(sft, step_size=step_size_mm)
-    if compress is not None:
-        logging.debug("      Compressing: {}".format(compress))
+    elif nb_points is not None:
+        logging.debug("      Resampling: {} points per streamline"
+                      .format(nb_points))
+        sft = resample_streamlines_num_points(sft, nb_points)
+    elif compress is not None:
+        logging.debug("      Compressing: {}".format(compress))
         sft = compress_sft(sft, compress)
+
+    if remove_invalid:
+        logging.debug("      Total: {:,.0f} streamlines. Now removing "
+                      "invalid streamlines.".format(len(sft)))
+        sft.remove_invalid_streamlines()
+        logging.info("      Final number of streamlines: {:,.0f}."
+                     .format(len(sft)))
+
     return sft
diff --git a/dwi_ml/io_utils.py b/dwi_ml/io_utils.py
index b16e5baf..12558acf 100644
--- a/dwi_ml/io_utils.py
+++ b/dwi_ml/io_utils.py
@@ -2,20 +2,23 @@
 import os
 from argparse import ArgumentParser
 
-from scilpy.io.utils import add_processes_arg
-
 
 def add_resample_or_compress_arg(p: ArgumentParser):
+    p.add_argument("--remove_invalid", action='store_true',
+                   help="If set, remove invalid streamlines.")
     g = p.add_mutually_exclusive_group()
     g.add_argument(
         '--step_size', type=float, metavar='s',
         help="Step size to resample the data (in mm). Default: None")
+    g.add_argument('--nb_points', type=int, metavar='n',
+                   help='Number of points per streamline in the output. '
+                        'Default: None')
     g.add_argument(
         '--compress', type=float, metavar='r', const=0.01, nargs='?',
         dest='compress_th',
         help="Compression ratio. Default: None. Default if set: 0.01.\n"
-             "If neither step_size nor compress are chosen, streamlines "
-             "will be kept \nas they are.")
+             "If none of step_size, nb_points or compress is chosen, \n"
+             "streamlines will be kept as they are.")
 
 
 def add_arg_existing_experiment_path(p: ArgumentParser):
diff --git a/scripts_python/dwiml_create_hdf5_dataset.py b/scripts_python/dwiml_create_hdf5_dataset.py
index 7266cd15..f0a82f9a 100644
--- a/scripts_python/dwiml_create_hdf5_dataset.py
+++ b/scripts_python/dwiml_create_hdf5_dataset.py
@@ -87,7 +87,10 @@ def prepare_hdf5_creator(args):
     # Instantiate a creator and perform checks
     creator = HDF5Creator(Path(args.dwi_ml_ready_folder), args.out_hdf5_file,
                           training_subjs, validation_subjs, testing_subjs,
-                          groups_config, args.step_size, args.compress_th,
+                          groups_config, args.step_size,
+                          args.nb_points,
+                          args.compress_th,
+                          args.remove_invalid,
                           args.enforce_files_presence, args.save_intermediate,
                           intermediate_subdir)

From f7c9bb237a072f4286812e5474c8db842a7983fb Mon Sep 17 00:00:00 2001
From: arnaudbore
Date: Wed, 25 Sep 2024 15:24:06 -0400
Subject: [PATCH 2/6] add dps

---
 dwi_ml/data/hdf5/hdf5_creation.py           | 12 +++++++++---
 dwi_ml/data/hdf5/utils.py                   |  3 +++
 scripts_python/dwiml_create_hdf5_dataset.py |  4 +++-
 3 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py
index 82414c52..e8f3c663 100644
--- a/dwi_ml/data/hdf5/hdf5_creation.py
+++ b/dwi_ml/data/hdf5/hdf5_creation.py
@@ -136,6 +136,7 @@ class HDF5Creator:
     def __init__(self, root_folder: Path, out_hdf_filename: Path,
                  training_subjs: List[str], validation_subjs: List[str],
                  testing_subjs: List[str], groups_config: dict,
+                 dps_keys: List[str] = [],
                  step_size: float = None,
                  nb_points: int = None,
                  compress: float = None,
@@ -157,6 +158,8 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
             List of subject names for each data set.
         groups_config: dict
             Information from json file loaded as a dict.
+        dps_keys: List[str]
+            List of keys to keep in data_per_streamline. Default: None.
         step_size: float
             Step size to resample streamlines. Default: None.
@@ -181,6 +184,7 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
         self.validation_subjs = validation_subjs
         self.testing_subjs = testing_subjs
         self.groups_config = groups_config
+        self.dps_keys = dps_keys
         self.step_size = step_size
         self.nb_points = nb_points
         self.compress = compress
@@ -609,9 +613,11 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
 
         if len(sft.data_per_point) > 0:
             logging.debug('sft contained data_per_point. Data not kept.')
-        if len(sft.data_per_streamline) > 0:
-            logging.debug('sft contained data_per_streamlines. Data not '
-                          'kept.')
+
+        for dps_key in self.dps_keys:
+            logging.debug("    Include dps \"{}\" in the HDF5.".format(dps_key))
+            streamlines_group.create_dataset('dps_' + dps_key,
+                                             data=sft.data_per_streamline[dps_key])
 
         # Accessing private Dipy values, but necessary.
         # We need to deconstruct the streamlines into arrays with
diff --git a/dwi_ml/data/hdf5/utils.py b/dwi_ml/data/hdf5/utils.py
index d8a6d990..4d0aca12 100644
--- a/dwi_ml/data/hdf5/utils.py
+++ b/dwi_ml/data/hdf5/utils.py
@@ -78,6 +78,9 @@ def add_hdf5_creation_args(p: ArgumentParser):
                         "(Final concatenated standardized volumes and \n"
                         "final concatenated resampled/compressed "
                         "streamlines.)")
+    p.add_argument('--dps_keys', type=str, nargs='+', default=[],
+                   help="List of keys to keep in data_per_streamline. "
+                        "Default: Empty.")
 
 
 def add_streamline_processing_args(p: ArgumentParser):
diff --git a/scripts_python/dwiml_create_hdf5_dataset.py b/scripts_python/dwiml_create_hdf5_dataset.py
index f0a82f9a..2719449c 100644
--- a/scripts_python/dwiml_create_hdf5_dataset.py
+++ b/scripts_python/dwiml_create_hdf5_dataset.py
@@ -87,7 +87,9 @@ def prepare_hdf5_creator(args):
     # Instantiate a creator and perform checks
     creator = HDF5Creator(Path(args.dwi_ml_ready_folder), args.out_hdf5_file,
                           training_subjs, validation_subjs, testing_subjs,
-                          groups_config, args.step_size,
+                          groups_config,
+                          args.dps_keys,
+                          args.step_size,
                           args.nb_points,
                           args.compress_th,
                           args.remove_invalid,

From 0c98a7196f4ab00b81c9bf3c42f04f40290aacfc Mon Sep 17 00:00:00 2001
From: arnaudbore
Date: Wed, 25 Sep 2024 15:44:43 -0400
Subject: [PATCH 3/6] fix tests

---
 .github/workflows/test_package.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/test_package.yml b/.github/workflows/test_package.yml
index 84d54cb5..c33f268c 100644
--- a/.github/workflows/test_package.yml
+++ b/.github/workflows/test_package.yml
@@ -25,6 +25,7 @@ jobs:
 
       - name: Install dependencies
         run: |
+          export SETUPTOOLS_USE_DISTUTILS=stdlib
           pip install --upgrade pip
           pip install pytest
           pip install -e .

From 1a913d77bc040d6703982e5551c758681c203f5b Mon Sep 17 00:00:00 2001
From: arnaudbore
Date: Wed, 25 Sep 2024 16:15:43 -0400
Subject: [PATCH 4/6] fix tests

---
 dwi_ml/models/main_models.py                  | 13 ++++++++-----
 dwi_ml/testing/testers.py                     |  1 +
 dwi_ml/training/batch_loaders.py              |  1 +
 .../dwiml_visualize_noise_on_streamlines.py   |  4 +++-
 4 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/dwi_ml/models/main_models.py b/dwi_ml/models/main_models.py
index 596f9e0f..cf3af707 100644
--- a/dwi_ml/models/main_models.py
+++ b/dwi_ml/models/main_models.py
@@ -17,8 +17,7 @@
     prepare_neighborhood_vectors
 from dwi_ml.experiment_utils.prints import format_dict_to_str
 from dwi_ml.io_utils import add_resample_or_compress_arg
-from dwi_ml.models.direction_getter_models import keys_to_direction_getters, \
-    AbstractDirectionGetterModel
+from dwi_ml.models.direction_getter_models import keys_to_direction_getters
 from dwi_ml.models.embeddings import (keys_to_embeddings, NNEmbedding,
                                       NoEmbedding)
 from dwi_ml.models.utils.direction_getters import add_direction_getter_args
@@ -37,6 +36,7 @@ class MainModelAbstract(torch.nn.Module):
     def __init__(self, experiment_name: str,
                  # Target preprocessing params for the batch loader + tracker
                  step_size: float = None,
+                 nb_points: int = None,
                  compress_lines: float = False,
                  # Other
                  log_level=logging.root.level):
@@ -74,17 +74,20 @@ def __init__(self, experiment_name: str,
 
         # To tell our batch loader how to resample streamlines during training
         # (should also be the step size during tractography).
-        if step_size and compress_lines:
-            raise ValueError("You may choose either resampling or compressing,"
-                             "but not both.")
+        if sum(bool(v) for v in (step_size, nb_points, compress_lines)) > 1:
+            raise ValueError("You may choose either resampling (step_size or "
+                             "nb_points) or compressing, not more than one.")
         elif step_size and step_size <= 0:
             raise ValueError("Step size can't be 0 or less!")
+        elif nb_points and nb_points <= 0:
+            raise ValueError("Number of points can't be 0 or less!")
         # Note. When using
         # scilpy.tracking.tools.resample_streamlines_step_size, a warning
        # is shown if step_size < 0.1 or > np.max(sft.voxel_sizes), saying
        # that the value is suspicious. 
Not raising the same warnings here # as you may be wanting to test weird things to understand better # your model. + self.nb_points = nb_points self.step_size = step_size self.compress_lines = compress_lines diff --git a/dwi_ml/testing/testers.py b/dwi_ml/testing/testers.py index 4664273d..4a353afd 100644 --- a/dwi_ml/testing/testers.py +++ b/dwi_ml/testing/testers.py @@ -109,6 +109,7 @@ def run_model_on_sft(self, sft, compute_loss=False): The mean eos error per line. """ sft = resample_or_compress(sft, self.model.step_size, + self.model.nb_points, self.model.compress_lines) sft.to_vox() sft.to_corner() diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 2152468b..743830bb 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -199,6 +199,7 @@ def _data_augmentation_sft(self, sft): "the hdf5 dataset. Not compressing again.") else: sft = resample_or_compress(sft, self.model.step_size, + self.model.nb_points, self.model.compress_lines) # Splitting streamlines diff --git a/scripts_python/dwiml_visualize_noise_on_streamlines.py b/scripts_python/dwiml_visualize_noise_on_streamlines.py index c07120ec..6a9d15f5 100644 --- a/scripts_python/dwiml_visualize_noise_on_streamlines.py +++ b/scripts_python/dwiml_visualize_noise_on_streamlines.py @@ -68,7 +68,9 @@ def main(): subj_sft_data = subj_data.sft_data_list[streamline_group_idx] sft = subj_sft_data.as_sft() - sft = resample_or_compress(sft, args.step_size, args.compress_th) + sft = resample_or_compress(sft, args.step_size, + args.nb_points, + args.compress_th) sft.to_vox() sft.to_corner() From 559668020ba2f491cc9a283d1db9116fef8dbece Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Thu, 26 Sep 2024 12:50:56 -0400 Subject: [PATCH 5/6] answer em comments --- dwi_ml/data/hdf5/hdf5_creation.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index e8f3c663..cad6b6bf 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -145,6 +145,8 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path, save_intermediate: bool = False, intermediate_folder: Path = None): """ + Params step_size, nb_points and compress are mutually exclusive. + Params ------ root_folder: Path @@ -159,7 +161,7 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path, groups_config: dict Information from json file loaded as a dict. dps_keys: List[str] - List of keys to keep in data_per_streamline. Default: None. + List of keys to keep in data_per_streamline. Default: []. step_size: float Step size to resample streamlines. Default: None. nb_points: int @@ -201,7 +203,7 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path, self._analyse_config_file() # -------- Performing checks - + self._check_streamlines_operations() # Check that all subjects exist. 
logging.debug("Preparing hdf5 creator for \n"
                      "  training subjs {}, \n"
@@ -353,6 +355,19 @@ def flatten_list(a_list):
                 self.enforce_files_presence,
                 folder=subj_input_dir)
 
+    def _check_streamlines_operations(self):
+        valid = True
+        if self.step_size and self.nb_points:
+            valid = False
+        elif self.step_size and self.compress:
+            valid = False
+        elif self.nb_points and self.compress:
+            valid = False
+        if not valid:
+            raise ValueError(
+                "Only one option can be chosen: either resampling to "
+                "step_size, resampling to nb_points, or compressing.")
+
     def create_database(self):
         """
         Generates a hdf5 dataset from a group of subjects. Hdf5 dataset will
@@ -746,11 +761,11 @@ def _load_and_process_sft(self, tractogram_file, header):
                 "We do not support file's type: {}. We only support .trk "
                 "and .tck files.".format(tractogram_file))
         if file_extension == '.trk':
-            if header:
-                if not is_header_compatible(str(tractogram_file), header):
-                    raise ValueError("Streamlines group is not compatible "
-                                     "with volume groups\n ({})"
-                                     .format(tractogram_file))
+            if header and not is_header_compatible(str(tractogram_file),
+                                                   header):
+                raise ValueError("Streamlines group is not compatible "
+                                 "with volume groups\n ({})"
+                                 .format(tractogram_file))
 
             # overriding given header.
             header = 'same'

From f94453f8c950b9128dc5ca3e99f4b215c8fe1ec1 Mon Sep 17 00:00:00 2001
From: arnaudbore
Date: Fri, 27 Sep 2024 12:08:48 -0400
Subject: [PATCH 6/6] second round answer Em comments

---
 dwi_ml/data/hdf5/hdf5_creation.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py
index cad6b6bf..3221d007 100644
--- a/dwi_ml/data/hdf5/hdf5_creation.py
+++ b/dwi_ml/data/hdf5/hdf5_creation.py
@@ -630,6 +630,11 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
             logging.debug('sft contained data_per_point. Data not kept.')
 
         for dps_key in self.dps_keys:
+            if dps_key not in sft.data_per_streamline:
+                raise ValueError(
+                    "The data_per_streamline key '{}' was not found in "
+                    "the sft. Check your tractogram file.".format(dps_key))
+
             logging.debug("    Include dps \"{}\" in the HDF5.".format(dps_key))
             streamlines_group.create_dataset('dps_' + dps_key,
                                              data=sft.data_per_streamline[dps_key])
@@ -675,10 +680,6 @@ def _process_one_streamline_group(
             The Euclidean length of each streamline
         """
         tractograms = self.groups_config[group]['files']
-        if self.step_size and self.compress:
-            raise ValueError(
-                "Only one option can be chosen: either resampling to "
-                "step_size, resampling to nb_points, or compressing.")
 
         # Silencing SFT's logger if our logging is in DEBUG mode, because it
         # typically produces a lot of outputs!
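
Usage note (illustration only, not part of the patch series): a minimal
sketch of how the options introduced above are meant to combine, based on
the signatures shown in the hunks. The file name 'bundle.trk' and the dps
key 'algo_weights' are hypothetical placeholders. At most one of
step_size_mm, nb_points or compress may be set; remove_invalid can be
combined with any of them:

    from dipy.io.streamline import load_tractogram

    from dwi_ml.data.processing.streamlines.data_augmentation import \
        resample_or_compress

    # For a .trk file, the tractogram's own header can serve as reference.
    sft = load_tractogram('bundle.trk', 'same')

    # Resample every streamline to a fixed number of points, then drop
    # invalid streamlines.
    sft = resample_or_compress(sft, step_size_mm=None, nb_points=256,
                               compress=None, remove_invalid=True)

The matching dataset-creation call from the command line would look roughly
like this (paths and subject lists are placeholders; --dps_keys names the
data_per_streamline keys to carry into the HDF5, which must exist in the
tractograms per the PATCH 6 check):

    dwiml_create_hdf5_dataset.py dwi_ml_ready/ dataset.hdf5 config.json \
        training_subjs.txt validation_subjs.txt testing_subjs.txt \
        --nb_points 256 --remove_invalid --dps_keys algo_weights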