Add clip grad to transformers
EmmaRenauld committed May 13, 2024
1 parent c6fd0a8 commit 46d2b27
Showing 5 changed files with 16 additions and 15 deletions.
dwi_ml/training/projects/learn2track_trainer.py (2 additions, 2 deletions)

@@ -17,8 +17,8 @@

 class Learn2TrackTrainer(DWIMLTrainerForTrackingOneInput):
     """
-    Trainer for Learn2Track. Nearly the same as in dwi_ml, but we add the
-    clip_grad parameter to avoid exploding gradients, typical in RNN.
+    Trainer for Learn2Track. Nearly the same as in the parent class, but the
+    generation-validation phase (tracking) uses the hidden states.
     """
     model: Learn2TrackModel

dwi_ml/training/projects/transformer_trainer.py (8 additions, 4 deletions)

@@ -20,13 +20,15 @@ def __init__(self, **kwargs):
"""
super().__init__(**kwargs)

def propagate_multiple_lines(self, lines: List[torch.Tensor], ids_per_subj):
def propagate_multiple_lines(self, lines: List[torch.Tensor],
ids_per_subj):
assert self.model.step_size is not None, \
"We can't propagate compressed streamlines."

# Getting the first inputs
tmp_lines = [line[:-1, :] for line in lines]
batch_inputs = self.batch_loader.load_batch_inputs(tmp_lines, ids_per_subj)
batch_inputs = self.batch_loader.load_batch_inputs(tmp_lines,
ids_per_subj)
del tmp_lines

def update_memory_after_removing_lines(can_continue: np.ndarray, __):
@@ -53,7 +55,8 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos):
         final_lines = []
         for subj_idx, line_idx in ids_per_subj.items():

-            with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') as hdf_handle:
+            with h5py.File(self.batch_loader.dataset.hdf5_file, 'r'
+                           ) as hdf_handle:
                 subj_id = self.batch_loader.context_subset.subjects[subj_idx]
                 logging.debug("Loading subj {} ({})'s tracking mask."
                               .format(subj_idx, subj_id))
@@ -65,7 +68,8 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos):
             final_lines.extend(propagate_multiple_lines(
                 lines[line_idx], update_memory_after_removing_lines,
                 get_dirs_at_last_pos, theta=theta,
-                step_size=self.model.step_size, verify_opposite_direction=False,
+                step_size=self.model.step_size,
+                verify_opposite_direction=False,
                 mask=tracking_mask, max_nbr_pts=max_nbr_pts,
                 append_last_point=False, normalize_directions=True))

dwi_ml/training/utils/trainer.py (4 additions, 0 deletions)

@@ -46,6 +46,10 @@ def add_training_args(p: argparse.ArgumentParser,
         '--max_batches_per_epoch_validation', type=int, default=1000,
         metavar='n',
         help="Maximum number of batches per epoch during validation.")
+    training_group.add_argument(
+        '--clip_grad', type=float, default=None,
+        help="Value at which to clip the gradient norms, to avoid exploding"
+             " gradients.\nDefault = None (no clipping).")

     if add_a_tracking_validation_phase:
         training_group.add_argument(
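The diff adds the --clip_grad option to the shared training arguments, but the trainer code that consumes it is not shown here. As a minimal sketch of what the option controls, assuming the trainer applies PyTorch's standard torch.nn.utils.clip_grad_norm_ between the backward pass and the optimizer step (training_step below is a hypothetical helper, not dwi_ml code):

    import torch

    def training_step(model, optimizer, compute_loss, batch, clip_grad=None):
        # Hypothetical helper: one training step with optional clipping.
        # clip_grad mirrors the new --clip_grad option (None = no clipping).
        optimizer.zero_grad()
        loss = compute_loss(model, batch)
        loss.backward()
        if clip_grad is not None:
            # Rescale all gradients so their joint L2 norm is at most
            # clip_grad, avoiding the exploding gradients typical in RNNs.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        optimizer.step()
        return loss.item()

Clipping by norm preserves the gradient direction; it only shrinks the update when the norm exceeds the chosen threshold.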
scripts_python/l2t_train_model.py (1 addition, 8 deletions)

@@ -40,16 +40,9 @@ def prepare_arg_parser():
     add_mandatory_args_experiment_and_hdf5_path(p)
     add_args_batch_sampler(p)
     add_args_batch_loader(p)
-    training_group = add_training_args(p, add_a_tracking_validation_phase=True)
+    add_training_args(p, add_a_tracking_validation_phase=True)
     add_memory_args(p, add_lazy_options=True, add_rng=True)
     add_verbose_arg(p)
-
-    # Additional arg for projects
-    training_group.add_argument(
-        '--clip_grad', type=float, default=None,
-        help="Value to which the gradient norms to avoid exploding gradients."
-             "\nDefault = None (not clipping).")
-
     add_model_args(p)

     return p
scripts_python/tt_train_model.py (1 addition, 1 deletion)

@@ -143,7 +143,7 @@ def init_from_args(args, sub_loggers_level):
         max_batches_per_epoch_training=args.max_batches_per_epoch_training,
         max_batches_per_epoch_validation=args.max_batches_per_epoch_validation,
         patience=args.patience, patience_delta=args.patience_delta,
-        from_checkpoint=False,
+        from_checkpoint=False, clip_grad=args.clip_grad,
         # (generation validation:)
         add_a_tracking_validation_phase=args.add_a_tracking_validation_phase,
         tracking_phase_frequency=args.tracking_phase_frequency,
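End to end, the new option simply flows from the command line through argparse into the trainer's constructor, as the two script diffs above show. A standalone sketch of that round trip, using plain argparse rather than the repository's add_training_args helper:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument('--clip_grad', type=float, default=None,
                   help="Value at which to clip the gradient norms.")

    # Flag omitted: clip_grad stays None, so the trainer skips clipping.
    assert p.parse_args([]).clip_grad is None
    # e.g. "--clip_grad 1.0" on the command line yields a float:
    assert p.parse_args(['--clip_grad', '1.0']).clip_grad == 1.0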
