diff --git a/dwi_ml/training/projects/learn2track_trainer.py b/dwi_ml/training/projects/learn2track_trainer.py index ee7275b9..4f10474e 100644 --- a/dwi_ml/training/projects/learn2track_trainer.py +++ b/dwi_ml/training/projects/learn2track_trainer.py @@ -17,8 +17,8 @@ class Learn2TrackTrainer(DWIMLTrainerForTrackingOneInput): """ - Trainer for Learn2Track. Nearly the same as in dwi_ml, but we add the - clip_grad parameter to avoid exploding gradients, typical in RNN. + Trainer for Learn2Track. Nearly the same as in parent class, but the + generation-validation phase (tracking) uses the hidden states. """ model: Learn2TrackModel diff --git a/dwi_ml/training/projects/transformer_trainer.py b/dwi_ml/training/projects/transformer_trainer.py index 93754657..62f40ed4 100644 --- a/dwi_ml/training/projects/transformer_trainer.py +++ b/dwi_ml/training/projects/transformer_trainer.py @@ -20,13 +20,15 @@ def __init__(self, **kwargs): """ super().__init__(**kwargs) - def propagate_multiple_lines(self, lines: List[torch.Tensor], ids_per_subj): + def propagate_multiple_lines(self, lines: List[torch.Tensor], + ids_per_subj): assert self.model.step_size is not None, \ "We can't propagate compressed streamlines." # Getting the first inputs tmp_lines = [line[:-1, :] for line in lines] - batch_inputs = self.batch_loader.load_batch_inputs(tmp_lines, ids_per_subj) + batch_inputs = self.batch_loader.load_batch_inputs(tmp_lines, + ids_per_subj) del tmp_lines def update_memory_after_removing_lines(can_continue: np.ndarray, __): @@ -53,7 +55,8 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): final_lines = [] for subj_idx, line_idx in ids_per_subj.items(): - with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') as hdf_handle: + with h5py.File(self.batch_loader.dataset.hdf5_file, 'r' + ) as hdf_handle: subj_id = self.batch_loader.context_subset.subjects[subj_idx] logging.debug("Loading subj {} ({})'s tracking mask." .format(subj_idx, subj_id)) @@ -65,7 +68,8 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): final_lines.extend(propagate_multiple_lines( lines[line_idx], update_memory_after_removing_lines, get_dirs_at_last_pos, theta=theta, - step_size=self.model.step_size, verify_opposite_direction=False, + step_size=self.model.step_size, + verify_opposite_direction=False, mask=tracking_mask, max_nbr_pts=max_nbr_pts, append_last_point=False, normalize_directions=True)) diff --git a/dwi_ml/training/utils/trainer.py b/dwi_ml/training/utils/trainer.py index b9c9ce89..074535f4 100644 --- a/dwi_ml/training/utils/trainer.py +++ b/dwi_ml/training/utils/trainer.py @@ -46,6 +46,10 @@ def add_training_args(p: argparse.ArgumentParser, '--max_batches_per_epoch_validation', type=int, default=1000, metavar='n', help="Maximum number of batches per epoch during validation.") + training_group.add_argument( + '--clip_grad', type=float, default=None, + help="Value to which the gradient norms to avoid exploding gradients." + "\nDefault = None (not clipping).") if add_a_tracking_validation_phase: training_group.add_argument( diff --git a/scripts_python/l2t_train_model.py b/scripts_python/l2t_train_model.py index 126bd9c3..5e360e62 100755 --- a/scripts_python/l2t_train_model.py +++ b/scripts_python/l2t_train_model.py @@ -40,16 +40,9 @@ def prepare_arg_parser(): add_mandatory_args_experiment_and_hdf5_path(p) add_args_batch_sampler(p) add_args_batch_loader(p) - training_group = add_training_args(p, add_a_tracking_validation_phase=True) + add_training_args(p, add_a_tracking_validation_phase=True) add_memory_args(p, add_lazy_options=True, add_rng=True) add_verbose_arg(p) - - # Additional arg for projects - training_group.add_argument( - '--clip_grad', type=float, default=None, - help="Value to which the gradient norms to avoid exploding gradients." - "\nDefault = None (not clipping).") - add_model_args(p) return p diff --git a/scripts_python/tt_train_model.py b/scripts_python/tt_train_model.py index 47f7f653..914474a5 100755 --- a/scripts_python/tt_train_model.py +++ b/scripts_python/tt_train_model.py @@ -143,7 +143,7 @@ def init_from_args(args, sub_loggers_level): max_batches_per_epoch_training=args.max_batches_per_epoch_training, max_batches_per_epoch_validation=args.max_batches_per_epoch_validation, patience=args.patience, patience_delta=args.patience_delta, - from_checkpoint=False, + from_checkpoint=False, clip_grad=args.clip_grad, # (generation validation:) add_a_tracking_validation_phase=args.add_a_tracking_validation_phase, tracking_phase_frequency=args.tracking_phase_frequency,