Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to add noise on targets #203

Merged
merged 1 commit into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions dwi_ml/training/batch_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract,
streamline_group_name: str, rng: int,
split_ratio: float = 0.,
noise_gaussian_size_forward: float = 0.,
noise_gaussian_size_loss: float = 0.,
reverse_ratio: float = 0., log_level=logging.root.level):
"""
Parameters
Expand All @@ -83,8 +84,10 @@ def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract,
are using interface seeding, this is not necessary.
noise_gaussian_size_forward : float
DATA AUGMENTATION: Add random Gaussian noise to streamline
coordinates with given variance. This corresponds to the std of the
Gaussian. Value is given in voxel world. Noise is truncated to
coordinates, with the given size. Noise is added AFTER
interpolation of underlying data.
The size is the standard deviation (std) of the Gaussian, not its variance.
Value is given in voxel world. Noise is truncated to
+/- 2*noise_gaussian_size.
** Suggestion. Make sure that
2*(noise_gaussian_size) < step_size/2 (in vox)
Expand All @@ -93,6 +96,8 @@ def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract,
rewinds of step_size/2, but not further, so the direction of the
segment won't flip. Suggestion, you could choose ~0.1 * step-size.
Default = 0.
noise_gaussian_size_loss : float
Same as noise_gaussian_size_forward, but the noise is added to the
streamlines used as targets for the loss (during training only).
Default = 0.
reverse_ratio: float
DATA AUGMENTATION: If set, reversed a part of the streamlines in
the batch. You could want to reverse ALL your data and then use
Expand Down Expand Up @@ -120,7 +125,8 @@ def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract,
self.np_rng = np.random.RandomState(self.rng)

# Data augmentation for streamlines:
self.noise_gaussian_size_train = noise_gaussian_size_forward
self.noise_gaussian_size_forward = noise_gaussian_size_forward
self.noise_gaussian_size_loss = noise_gaussian_size_loss
self.split_ratio = split_ratio
self.reverse_ratio = reverse_ratio
if self.split_ratio and not 0 <= self.split_ratio <= 1:
Expand All @@ -133,7 +139,8 @@ def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract,
# For later use, context
self.context = None
self.context_subset = None
self.context_noise_size = None
self.context_noise_size_forward = None
self.context_noise_size_loss = None

@property
def params_for_checkpoint(self):
Expand All @@ -144,7 +151,8 @@ def params_for_checkpoint(self):
params = {
'streamline_group_name': self.streamline_group_name,
'rng': self.rng,
'noise_gaussian_size_forward': self.noise_gaussian_size_train,
'noise_gaussian_size_forward': self.noise_gaussian_size_forward,
'noise_gaussian_size_loss': self.noise_gaussian_size_loss,
'reverse_ratio': self.reverse_ratio,
'split_ratio': self.split_ratio,
}
Expand All @@ -154,10 +162,12 @@ def set_context(self, context: str):
if self.context != context:
if context == 'training':
self.context_subset = self.dataset.training_set
self.context_noise_size = self.noise_gaussian_size_train
self.context_noise_size_forward = self.noise_gaussian_size_forward
self.context_noise_size_loss = self.noise_gaussian_size_loss
elif context == 'validation':
self.context_subset = self.dataset.validation_set
self.context_noise_size = 0.
self.context_noise_size_forward = 0.
self.context_noise_size_loss = 0.
else:
raise ValueError("Context should be either 'training' or "
"'validation'.")
Expand Down Expand Up @@ -199,16 +209,30 @@ def _data_augmentation_sft(self, sft):

return sft

def add_noise_streamlines(self, batch_streamlines, device):
def add_noise_streamlines_forward(self, batch_streamlines, device):
# This method is called by the trainer only before the forward method.
# Targets are not modified for the loss computation.
# Adding noise to coordinates. Streamlines are in voxel space by now.
# Noise is considered in voxel space.
if (self.context_noise_size_forward is not None and
self.context_noise_size_forward > 0):
logger.debug(" Adding noise {}"
.format(self.context_noise_size_forward))
batch_streamlines = add_noise_to_tensor(
batch_streamlines, self.context_noise_size_forward, device)
return batch_streamlines

def add_noise_streamlines_loss(self, batch_streamlines, device):
# This method is called by the trainer only before computing the loss.
# It adds noise to the targets; forward inputs are not modified here.
# Adding noise to coordinates. Streamlines are in voxel space by now.
# Noise is considered in voxel space.
if self.context_noise_size is not None and self.context_noise_size > 0:
if (self.context_noise_size_loss is not None and
self.context_noise_size_loss > 0):
logger.debug(" Adding noise {}"
.format(self.context_noise_size))
.format(self.context_noise_size_loss))
batch_streamlines = add_noise_to_tensor(
batch_streamlines, self.context_noise_size, device)
batch_streamlines, self.context_noise_size_loss, device)
return batch_streamlines

def load_batch_streamlines(
Expand Down
5 changes: 4 additions & 1 deletion dwi_ml/training/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,7 @@ def run_one_batch(self, data, average_results=True):
logger.debug('*** Computing forward propagation')
if self.model.forward_uses_streamlines:
# Now possibly add noise to streamlines (training / valid)
streamlines_f = self.batch_loader.add_noise_streamlines(
streamlines_f = self.batch_loader.add_noise_streamlines_forward(
streamlines_f, self.device)

# Possibly computing directions twice (during forward and loss)
Expand All @@ -1129,6 +1129,9 @@ def run_one_batch(self, data, average_results=True):

logger.debug('*** Computing loss')
if self.model.loss_uses_streamlines:
targets = self.batch_loader.add_noise_streamlines_loss(
targets, self.device)

results = self.model.compute_loss(model_outputs, targets,
average_results=average_results)
else:
Expand Down
12 changes: 10 additions & 2 deletions dwi_ml/training/utils/batch_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,20 @@ def add_args_batch_loader(p: argparse.ArgumentParser):
bl_g.add_argument(
'--noise_gaussian_size_forward', type=float, metavar='s', default=0.,
help="If set, add random Gaussian noise to streamline coordinates \n"
"with given variance. This corresponds to the std of the \n"
"Gaussian. [0]\n**Make sure noise is smaller than your step size "
"with given variance. Noise is added AFTER interpolation of "
"underlying data. \nExample of use: when concatenating previous "
"direction to input.\n"
"This corresponds to the std of the Gaussian. [0]\n"
"**Make sure noise is smaller than your step size "
"to avoid \nflipping direction! (We can't verify if --step_size "
"is not \nspecified here, but if it is, we limit noise to \n"
"+/- 0.5 * step-size.).\n"
"** We also limit noise to +/- 2 * noise_gaussian_size.\n"
"Suggestion: 0.1 * step-size.")
bl_g.add_argument(
'--noise_gaussian_size_loss', type=float, metavar='s', default=0.,
help='Idem, but noise is added to targets instead (during training '
'only).')
bl_g.add_argument(
'--split_ratio', type=float, metavar='r', default=0.,
help="Percentage of streamlines to randomly split into 2, in each \n"
Expand All @@ -44,6 +51,7 @@ def prepare_batch_loader(dataset, model, args, sub_loggers_level):
streamline_group_name=args.streamline_group_name,
# STREAMLINES AUGMENTATION
noise_gaussian_size_forward=args.noise_gaussian_size_forward,
noise_gaussian_size_loss=args.noise_gaussian_size_loss,
reverse_ratio=args.reverse_ratio, split_ratio=args.split_ratio,
# OTHER
rng=args.rng, log_level=sub_loggers_level)
Expand Down