Commit 0d2870b: Move std mask option to config file
EmmaRenauld committed Jan 15, 2024 (1 parent: ebfc9d3)
Showing 3 changed files with 49 additions and 71 deletions.
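In short: --std_mask is gone from the command line; the standardization mask is now declared per volume group, directly in the json groups config. A minimal sketch of a group entry after this commit, written as the dict the json file is loaded into (the keys 'files', 'standardization' and 'std_mask' are the ones the new code reads; the group and file names are hypothetical):

    groups_config = {
        "input": {
            # '*' is replaced by each subject's id.
            "files": ["dwi/*__dwi.nii.gz"],
            # One of 'all', 'independent', 'per_file' or 'none'.
            "standardization": "per_file",
            # A single name or a list; several masks are combined (voxel-wise OR).
            "std_mask": ["masks/*__wm_mask.nii.gz"]
        }
    }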
98 changes: 45 additions & 53 deletions dwi_ml/data/hdf5/hdf5_creation.py
@@ -107,8 +107,7 @@ class HDF5Creator:
     def __init__(self, root_folder: Path, out_hdf_filename: Path,
                  training_subjs: List[str], validation_subjs: List[str],
                  testing_subjs: List[str], groups_config: dict,
-                 std_mask: str, step_size: float = None,
-                 compress: float = None,
+                 step_size: float = None, compress: float = None,
                  enforce_files_presence: bool = True,
                  save_intermediate: bool = False,
                  intermediate_folder: Path = None):
@@ -126,8 +125,6 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
             List of subject names for each data set.
         groups_config: dict
             Information from json file loaded as a dict.
-        std_mask: str
-            Name of the standardization mask inside each subject's folder.
         step_size: float
             Step size to resample streamlines. Default: None.
         compress: float
@@ -152,7 +149,6 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
         self.compress = compress

         # Optional
-        self.std_mask = std_mask  # (could be None)
         self.save_intermediate = save_intermediate
         self.enforce_files_presence = enforce_files_presence
         self.intermediate_folder = intermediate_folder
@@ -295,20 +291,11 @@ def _check_files_presence(self):
         config_file_list = sum(nested_lookup('files', self.groups_config), [])
         config_file_list += nested_lookup(
             'connectivity_matrix', self.groups_config)
+        config_file_list += nested_lookup('std_mask', self.groups_config)

         for subj_id in self.all_subjs:
             subj_input_dir = Path(self.root_folder).joinpath(subj_id)

-            # Find subject's standardization mask
-            if self.std_mask is not None:
-                for sub_mask in self.std_mask:
-                    sub_std_mask_file = subj_input_dir.joinpath(
-                        sub_mask.replace('*', subj_id))
-                    if not sub_std_mask_file.is_file():
-                        raise FileNotFoundError(
-                            "Standardization mask {} not found for subject {}!"
-                            .format(sub_std_mask_file, subj_id))
-
             # Find subject's files from group_config
             for this_file in config_file_list:
                 this_file = this_file.replace('*', subj_id)
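Here nested_lookup, from the nested-lookup package, returns every value found under a key anywhere in the nested config dict, so std_mask entries are picked up from whichever groups declare them. A small illustration with a hypothetical config:

    from nested_lookup import nested_lookup

    config = {"input": {"files": ["dwi/*__dwi.nii.gz"],
                        "std_mask": "masks/*__mask.nii.gz"},
              "target": {"files": ["anat/*__t1.nii.gz"]}}

    nested_lookup('std_mask', config)        # ['masks/*__mask.nii.gz']
    sum(nested_lookup('files', config), [])  # flattened across groups
    # Note: a std_mask given as a list comes back nested one level deeper.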
@@ -368,31 +355,14 @@ def _create_one_subj(self, subj_id, hdf_handle):

         subj_hdf_group = hdf_handle.create_group(subj_id)

-        # Find subject's standardization mask
-        subj_std_mask_data = None
-        if self.std_mask is not None:
-            for sub_mask in self.std_mask:
-                sub_mask = sub_mask.replace('*', subj_id)
-                logging.info("     - Loading standardization mask {}"
-                             .format(sub_mask))
-                sub_mask_file = subj_input_dir.joinpath(sub_mask)
-                sub_mask_img = nib.load(sub_mask_file)
-                sub_mask_data = np.asanyarray(sub_mask_img.dataobj) > 0
-                if subj_std_mask_data is None:
-                    subj_std_mask_data = sub_mask_data
-                else:
-                    subj_std_mask_data = np.logical_or(sub_mask_data,
-                                                       subj_std_mask_data)
-
         # Add the subj data based on groups in the json config file
-        ref = self._create_volume_groups(
-            subj_id, subj_input_dir, subj_std_mask_data, subj_hdf_group)
+        ref = self._create_volume_groups(subj_id, subj_input_dir,
+                                         subj_hdf_group)

         self._create_streamline_groups(ref, subj_input_dir, subj_id,
                                        subj_hdf_group)

-    def _create_volume_groups(self, subj_id, subj_input_dir,
-                              subj_std_mask_data, subj_hdf_group):
+    def _create_volume_groups(self, subj_id, subj_input_dir, subj_hdf_group):
         """
         Create the hdf5 groups for all volume groups in the config_file for a
         given subject.
@@ -407,7 +377,7 @@ def _create_volume_groups(self, subj_id, subj_input_dir,

             (group_data, group_affine,
              group_header, group_res) = self._process_one_volume_group(
-                group, subj_id, subj_input_dir, subj_std_mask_data)
+                group, subj_id, subj_input_dir)
             if ref_header is None:
                 ref_header = group_header
             else:
@@ -431,8 +401,7 @@ def _create_volume_groups(self, subj_id, subj_input_dir,
         return ref_header

     def _process_one_volume_group(self, group: str, subj_id: str,
-                                  subj_input_path: Path,
-                                  subj_std_mask_data: np.ndarray = None):
+                                  subj_input_dir: Path):
         """
         Processes each volume group from the json config file for a given
         subject:
@@ -448,10 +417,8 @@ def _process_one_volume_group(self, group: str, subj_id: str,
             Group name.
         subj_id: str
             The subject's id.
-        subj_input_path: Path
+        subj_input_dir: Path
             Path where the files from file_list should be found.
-        subj_std_mask_data: np.ndarray of bools, optional
-            Binary mask that will be used for data standardization.

         Returns
         -------
@@ -460,32 +427,57 @@ def _process_one_volume_group(self, group: str, subj_id: str,
         group_affine: np.ndarray
             Affine for the group.
         """
-        standardization = self.groups_config[group]['standardization']
+        std_mask = None
+        std_option = 'none'
+        if 'standardization' in self.groups_config[group]:
+            std_option = self.groups_config[group]['standardization']
+        if 'std_mask' in self.groups_config[group]:
+            if std_option == 'none':
+                logging.warning("You provided a std_mask for volume group {}, "
+                                "but std_option is 'none'. Skipping.".format(group))
+            else:
+                # Load subject's standardization mask. Can be a list of files.
+                std_masks = self.groups_config[group]['std_mask']
+                if isinstance(std_masks, str):
+                    std_masks = [std_masks]
+
+                for sub_mask in std_masks:
+                    sub_mask = sub_mask.replace('*', subj_id)
+                    logging.info("     - Loading standardization mask {}"
+                                 .format(sub_mask))
+                    sub_mask_file = subj_input_dir.joinpath(sub_mask)
+                    sub_mask_img = nib.load(sub_mask_file)
+                    sub_mask_data = np.asanyarray(sub_mask_img.dataobj) > 0
+                    if std_mask is None:
+                        std_mask = sub_mask_data
+                    else:
+                        std_mask = np.logical_or(sub_mask_data, std_mask)
+
         file_list = self.groups_config[group]['files']

         # First file will define data dimension and affine
         file_name = file_list[0].replace('*', subj_id)
-        first_file = subj_input_path.joinpath(file_name)
+        first_file = subj_input_dir.joinpath(file_name)
         logging.info("     - Processing file {}".format(file_name))
         group_data, group_affine, group_res, group_header = load_file_to4d(
             first_file)

-        if standardization == 'per_file':
+        if std_option == 'per_file':
             logging.debug('    *Standardizing sub-data')
-            group_data = standardize_data(group_data, subj_std_mask_data,
+            group_data = standardize_data(group_data, std_mask,
                                           independent=False)

         # Other files must fit (data shape, affine, voxel size)
         # It is not a promise that data has been correctly registered, but it
         # is a minimal check.
         for file_name in file_list[1:]:
             file_name = file_name.replace('*', subj_id)
-            data = _load_and_verify_file(file_name, subj_input_path, group,
+            data = _load_and_verify_file(file_name, subj_input_dir, group,
                                          group_affine, group_res)

-            if standardization == 'per_file':
+            if std_option == 'per_file':
                 logging.debug('    *Standardizing sub-data')
-                data = standardize_data(data, subj_std_mask_data,
+                data = standardize_data(data, std_mask,
                                         independent=False)

             # Append file data to hdf group.
@@ -497,15 +489,15 @@ def _process_one_volume_group(self, group: str, subj_id: str,
                     'Wrong dimensions?'.format(file_name, group))

         # Standardize data (per channel) (if not done 'per_file' yet).
-        if standardization == 'independent':
+        if std_option == 'independent':
             logging.debug('    *Standardizing data on each feature.')
-            group_data = standardize_data(group_data, subj_std_mask_data,
+            group_data = standardize_data(group_data, std_mask,
                                           independent=True)
-        elif standardization == 'all':
+        elif std_option == 'all':
             logging.debug('    *Standardizing data as a whole.')
-            group_data = standardize_data(group_data, subj_std_mask_data,
+            group_data = standardize_data(group_data, std_mask,
                                           independent=False)
-        elif standardization not in ['none', 'per_file']:
+        elif std_option not in ['none', 'per_file']:
             raise ValueError("standardization must be one of "
                              "['all', 'independent', 'per_file', 'none']")
14 changes: 1 addition & 13 deletions dwi_ml/data/hdf5/utils.py
@@ -44,7 +44,7 @@ def add_hdf5_creation_args(p: ArgumentParser):
                         "-> https://dwi-ml.readthedocs.io/en/latest/"
                         "creating_hdf5.html")
     p.add_argument('out_hdf5_file',
-                   help="Path and name of the output hdf5 file.\n If "
+                   help="Path and name of the output hdf5 file. \nIf "
                         "--save_intermediate is set, the intermediate files "
                         "will be saved in \nthe same location, in a folder "
                         "name based on date and hour of creation.\n"
@@ -79,18 +79,6 @@ def add_hdf5_creation_args(p: ArgumentParser):
                         "final concatenated resampled/compressed streamlines.)")


-def add_mri_processing_args(p: ArgumentParser):
-    g = p.add_argument_group('Volumes processing options:')
-    g.add_argument(
-        '--std_mask', nargs='+', metavar='m',
-        help="Mask defining the voxels used for data standardization. \n"
-             "-> Should be the name of a file inside dwi_ml_ready/{subj_id}.\n"
-             "-> You may add wildcards (*) that will be replaced by the "
-             "subject's id. \n"
-             "-> If none is given, all non-zero voxels will be used.\n"
-             "-> If more than one are given, masks will be combined.")
-
-
 def add_streamline_processing_args(p: ArgumentParser):
     g = p.add_argument_group('Streamlines processing options:')
     add_resample_or_compress_arg(g)
8 changes: 3 additions & 5 deletions scripts_python/dwiml_create_hdf5_dataset.py
@@ -32,8 +32,7 @@

 from dwi_ml.data.hdf5.hdf5_creation import HDF5Creator
 from dwi_ml.data.hdf5.utils import (
-    add_hdf5_creation_args, add_mri_processing_args,
-    add_streamline_processing_args)
+    add_hdf5_creation_args, add_streamline_processing_args)
 from dwi_ml.experiment_utils.timer import Timer
 from dwi_ml.io_utils import add_logging_arg

@@ -89,8 +88,8 @@ def prepare_hdf5_creator(args):
     # Instantiate a creator and perform checks
     creator = HDF5Creator(Path(args.dwi_ml_ready_folder), args.out_hdf5_file,
                           training_subjs, validation_subjs, testing_subjs,
-                          groups_config, args.std_mask, args.step_size,
-                          args.compress, args.enforce_files_presence,
+                          groups_config, args.step_size, args.compress,
+                          args.enforce_files_presence,
                           args.save_intermediate, intermediate_subdir)

     return creator
@@ -101,7 +100,6 @@ def _parse_args():
         formatter_class=argparse.RawTextHelpFormatter)

     add_hdf5_creation_args(p)
-    add_mri_processing_args(p)
     add_streamline_processing_args(p)
     add_overwrite_arg(p)
     add_logging_arg(p)
