
Commit

Fix tests
EmmaRenauld committed Sep 27, 2023
1 parent ef1491b commit eb0337d
Showing 9 changed files with 75 additions and 44 deletions.
51 changes: 25 additions & 26 deletions dwi_ml/models/direction_getter_models.py
@@ -216,22 +216,19 @@ def compute_loss(self, outputs: List[Tensor],
         # Compute directions
         target_dirs = compute_directions(target_streamlines)

-        # For compress_loss: remember raw target dirs.
-        line_dirs = None
+        # For compress_loss and weight_with_angle: remember raw target dirs.
+        # Also, do not average now, we will do our own averaging.
+        target_dirs_copy = None
+        tmp_average_results = average_results
         if self.weight_loss_with_angle or self.compress_loss:
-            line_dirs = [t.detach().clone() for t in target_dirs]
+            target_dirs_copy = [t.detach().clone() for t in target_dirs]
+            tmp_average_results = False

         # Modify directions based on child model requirements.
         # Ex: Add eos label. Convert to classes. Etc.
         target_dirs = self._prepare_dirs_for_loss(target_dirs)
         lengths = [len(t) for t in target_dirs]

-        # For compress_loss and weight_with_angle, do not average now, we
-        # will do our own averaging.
-        tmp_average_results = average_results
-        if self.compress_loss or self.weight_loss_with_angle:
-            tmp_average_results = False
-
         # Stack and compute loss based on child model's loss definition.
         outputs, target_dirs = self.stack_batch(outputs, target_dirs)
         loss, n = self._compute_loss(outputs, target_dirs, tmp_average_results)
@@ -243,12 +240,12 @@ def compute_loss(self, outputs: List[Tensor],
                 eos_loss = [line_loss[-1] for line_loss in loss]
                 loss = [line_loss[:-1] for line_loss in loss]
                 loss = weight_value_with_angle(
-                    values=loss, streamlines=None, dirs=line_dirs)
+                    values=loss, streamlines=None, dirs=target_dirs_copy)
                 for i in range(len(loss)):
                     loss[i] = torch.hstack((loss[i], eos_loss[i]))
             else:
                 loss = weight_value_with_angle(
-                    values=loss, streamlines=None, dirs=line_dirs)
+                    values=loss, streamlines=None, dirs=target_dirs_copy)
             if not self.compress_loss:
                 if average_results:
                     loss = torch.hstack(loss)
@@ -258,7 +255,7 @@ def compute_loss(self, outputs: List[Tensor],
         if self.compress_loss:
             loss = list(torch.split(loss, lengths))
             final_loss, final_n = compress_streamline_values(
-                streamlines=None, dirs=line_dirs, values=loss,
+                streamlines=None, dirs=target_dirs_copy, values=loss,
                 compress_eps=self.compress_eps)
             logging.info("Converted {} data points into {} compressed data "
                          "points".format(sum(lengths), final_n))
@@ -421,7 +418,7 @@ def _compute_loss(self, learned_directions: Tensor, target_dirs: Tensor,

         n = 1
         if average_results:
-            loss_dir, n = _mean_and_weight(loss_dirs)
+            loss_dirs, n = _mean_and_weight(loss_dirs)

         # 2. EOS loss:
         if self.add_eos:
@@ -433,7 +430,7 @@ def _compute_loss(self, learned_directions: Tensor, target_dirs: Tensor,
             loss_eos = binary_cross_entropy_eos(learned_eos, target_eos,
                                                 average_results)

-            return loss_dirs + self.eos_weight * loss_eos, n
+            return loss_dirs + self.eos_weight * loss_eos, n
         else:
             return loss_dirs, n

@@ -809,14 +806,16 @@ class SingleGaussianDG(AbstractDirectionGetterModel):
     ===========> WE SUGGEST TO :
     USE GRADIENT CLIPPING **AND** low learning rate.
     """
-    def __init__(self, normalize_targets: float = None, **kwargs):
+    def __init__(self, normalize_targets: float = None,
+                 **kwargs):
         # 3D gaussian supports compressed streamlines
         super().__init__(key='gaussian',
                          supports_compressed_streamlines=True,
                          loss_description='negative log-likelihood',
                          **kwargs)

         self.normalize_targets = normalize_targets
+        self.entropy_weight = 0.0  # Not tested yet.

         # Layers
         # 3 values as mean, 3 values as sigma
@@ -899,17 +898,18 @@ def _compute_loss(self, learned_gaussian_params: Tuple[Tensor, Tensor],

         # Create an official function-probability distribution from the means
         # and variances
-        np.set_printoptions(precision=3)
         distribution = MultivariateNormal(
             means, covariance_matrix=torch.diag_embed(sigmas ** 2))
-        nll_loss = -distribution.log_prob(target_dirs)
+        nll_loss = -distribution.log_prob(target_dirs[:, 0:3])

-        # Trying to ensure that sigma values are not too small.
-        # Else, normal values become very big, and gradients too.
-        entropy = distribution.entropy()
-        logging.info("Sigma: {}. Entropy: {}".format(torch.mean(sigmas),
-                                                     torch.mean(entropy)))
-        #nll_loss = nll_loss - 0.1 * entropy
+        if self.entropy_weight > 0:
+            # Trying to ensure that sigma values are not too small.
+            # Entropy values range between 0 and log(K). 0 = high probability.
+            # We want a high entropy / low certainty = we will minimize -entropy.
+            entropy = distribution.entropy()
+            logging.info("Computing batch loss with sigma {}, entropy: {}"
+                         .format(torch.mean(sigmas), torch.mean(entropy)))
+            nll_loss = nll_loss - self.entropy_weight * entropy

         n = 1
         if average_results:
@@ -918,8 +918,7 @@ def _compute_loss(self, learned_gaussian_params: Tuple[Tensor, Tensor],
         # 2. EOS loss:
         if self.add_eos:
             # Binary cross-entropy
-            real_eos = target_dirs[:, -1]
-            loss_eos = binary_cross_entropy_eos(learned_eos, real_eos,
+            loss_eos = binary_cross_entropy_eos(learned_eos, target_dirs[:, -1],
                                                 average_results)
             return nll_loss + self.eos_weight * loss_eos, n
         else:
@@ -958,7 +957,7 @@ def _get_tracking_direction_det(self, learned_gaussian_params: Tensor,
         """
         # Returns the direction of the max of the Gaussian = the mean.
         means, sigmas = learned_gaussian_params
-        dirs = means[0:3]
+        dirs = means[:, 0:3]

         if self.add_eos:
             eos_prob = torch.sigmoid(means[:, -1])
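For context on the `log_prob(target_dirs[:, 0:3])` and `entropy_weight` changes above, a minimal standalone sketch of the negative log-likelihood under a diagonal-covariance Gaussian; the tensor shapes and values here are assumptions for illustration only.

```python
import torch
from torch.distributions import MultivariateNormal

# Hypothetical batch: 10 predicted means/sigmas, targets with an EOS column.
means = torch.randn(10, 3)
sigmas = torch.rand(10, 3) + 0.1
target_dirs = torch.randn(10, 4)  # last column = EOS target

# Diagonal covariance built from the predicted sigmas.
distribution = MultivariateNormal(
    means, covariance_matrix=torch.diag_embed(sigmas ** 2))

# Only the first three columns are directions; the EOS column is excluded.
nll_loss = -distribution.log_prob(target_dirs[:, 0:3])  # shape (10,)

# Optional entropy regularization, disabled by default (entropy_weight = 0.0).
entropy_weight = 0.0
if entropy_weight > 0:
    nll_loss = nll_loss - entropy_weight * distribution.entropy()
```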
4 changes: 2 additions & 2 deletions dwi_ml/models/utils/direction_getters.py
@@ -101,9 +101,9 @@ def check_args_direction_getter(args):
                          "Mises mixture. Ignored.")

     # Regression and normalisation
-    if 'regression' or 'gaussian' in args.dg_key:
+    if 'regression' in args.dg_key or 'gaussian' in args.dg_key:
         dg_args['normalize_targets'] = args.normalize_targets
-    else:
+    elif args.normalize_targets:
         raise ValueError("--normalize_targets is only an option for "
                          "regression and gaussian models.")

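The condition fixed above is a classic Python pitfall: `'regression' or 'gaussian' in args.dg_key` parses as `('regression') or ('gaussian' in args.dg_key)`, and a non-empty string literal is always truthy, so the branch ran for every model. A minimal sketch with a hypothetical key:

```python
dg_key = 'sphere-classification'  # hypothetical direction-getter key

# Old, buggy condition: always truthy, whatever the key.
print('regression' or 'gaussian' in dg_key)            # -> 'regression'

# Fixed condition: each membership test is explicit.
print('regression' in dg_key or 'gaussian' in dg_key)  # -> False
```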
20 changes: 17 additions & 3 deletions dwi_ml/testing/testers.py
@@ -119,7 +119,7 @@ def run_model_on_sft(self, sft, add_zeros_if_no_eos=True,
         force_compress_loss: bool
             If true, compresses the loss even if that is not the model's
             parameter.
-        change_weight_with_angle: bool
+        weight_with_angle: bool
             If true, modify the model's weight_loss_with_angle param.
         """
         if uncompress_loss and force_compress_loss:
@@ -141,9 +141,15 @@ def run_model_on_sft(self, sft, add_zeros_if_no_eos=True,
         batch_size = self.batch_size or len(sft)
         nb_batches = int(np.ceil(len(sft) / batch_size))

+        if 'gaussian' in self.model.direction_getter.key:
+            outputs = ([], [])
+        elif 'fisher' in self.model.direction_getter.key:
+            raise NotImplementedError
+        else:
+            outputs = []
+
         losses = []
         compressed_n = []
-        outputs = []
         batch_start = 0
         batch_end = batch_size
         with torch.no_grad():
@@ -180,7 +186,15 @@ def run_model_on_sft(self, sft, add_zeros_if_no_eos=True,
                     losses.extend([line_loss.cpu() for line_loss in
                                    tmp_losses])

-                outputs.extend([o.cpu() for o in tmp_outputs])
+                # ToDo. See if we can simplify to fit with all models
+                if 'gaussian' in self.model.direction_getter.key:
+                    tmp_means, tmp_sigmas = tmp_outputs
+                    outputs[0].extend([m.cpu() for m in tmp_means])
+                    outputs[1].extend([s.cpu() for s in tmp_sigmas])
+                elif 'fisher' in self.model.direction_getter.key:
+                    raise NotImplementedError
+                else:
+                    outputs.extend([o.cpu() for o in tmp_outputs])

                 batch_start = batch_end
                 batch_end = min(batch_start + batch_size, len(sft))
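The shape of the accumulated outputs now depends on the direction getter: Gaussian models return a `(means, sigmas)` pair per batch rather than a single tensor. A minimal sketch of the accumulation pattern, with a hypothetical key and fake batch tensors:

```python
import torch

key = 'gaussian'  # hypothetical direction-getter key

# Gaussian getters return a (means, sigmas) pair per batch, so results are
# accumulated into a pair of lists; other getters use a single list.
if 'gaussian' in key:
    outputs = ([], [])
elif 'fisher' in key:
    raise NotImplementedError
else:
    outputs = []

# One fake batch of (means, sigmas), each of shape (4, 3).
tmp_outputs = (torch.randn(4, 3), torch.rand(4, 3))

if 'gaussian' in key:
    tmp_means, tmp_sigmas = tmp_outputs
    outputs[0].extend(m.cpu() for m in tmp_means)
    outputs[1].extend(s.cpu() for s in tmp_sigmas)
else:
    outputs.extend(o.cpu() for o in tmp_outputs)
```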
28 changes: 23 additions & 5 deletions dwi_ml/testing/visu_loss.py
@@ -161,6 +161,7 @@ def combine_displacement_with_ref(out_dirs, sft, step_size_mm=None):
     color_x = []
     color_y = []
     color_z = []
+
     for i, s in enumerate(sft.streamlines):
         this_s_len = len(s)

@@ -253,20 +254,37 @@ def run_visu_save_colored_displacement(
         save_tractogram(worst_sft, worst_sft_name)

     # Save displacement
-    args.pick_idx = list(range(10))
     if args.out_displacement_sft:
         if args.out_colored_sft:
             # We have run model on all streamlines. Picking a few now.
             sft, idx = pick_a_few(
                 sft, best_idx, worst_idx, args.pick_at_random,
                 args.pick_best_and_worst, args.pick_idx)
-            outputs = [outputs[i] for i in idx]

+            # Either concat, run, split or (chosen:) loop
+            # ToDo. See if we can simplify to fit with all models
+            if 'gaussian' in model.direction_getter.key:
+                means, sigmas = outputs
+                means = [means[i] for i in idx]
+                lengths = [len(line) for line in means]
+                outputs = (torch.vstack(means),
+                           torch.vstack([sigmas[i] for i in idx]))
+
+            elif 'fisher' in model.direction_getter.key:
+                raise NotImplementedError
+            else:
+                outputs = [outputs[i] for i in idx]
+                lengths = [len(line) for line in outputs]
+                outputs = torch.vstack(outputs)
+
         # Use eos_thresh of 1 to be sure we don't output a NaN
         with torch.no_grad():
-            out_dirs = [model.get_tracking_directions(
-                s_output, algo='det', eos_stopping_thresh=1.0).numpy()
-                for s_output in outputs]
+            out_dirs = model.get_tracking_directions(
+                outputs, algo='det', eos_stopping_thresh=1.0)
+
+            out_dirs = torch.split(out_dirs, lengths)
+
+            out_dirs = [o.numpy() for o in out_dirs]

         # Save error together with ref
         sft = combine_displacement_with_ref(out_dirs, sft, model.step_size)
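The displacement code above switches from looping over per-streamline outputs to one stacked call followed by `torch.split`. A minimal sketch of that stack-run-split pattern, with hypothetical per-streamline tensors (the doubling step stands in for `get_tracking_directions`):

```python
import torch

# Hypothetical per-streamline outputs of different lengths.
lines = [torch.randn(5, 3), torch.randn(8, 3), torch.randn(3, 3)]
lengths = [len(line) for line in lines]

stacked = torch.vstack(lines)  # (16, 3): one call on the whole batch
out = stacked * 2.0            # stand-in for get_tracking_directions

# Split back into per-streamline chunks using the remembered lengths.
out_dirs = [o.numpy() for o in torch.split(out, lengths)]
print([o.shape for o in out_dirs])  # [(5, 3), (8, 3), (3, 3)]
```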
8 changes: 3 additions & 5 deletions dwi_ml/training/trainers.py
@@ -1081,7 +1081,7 @@ def check_stopping_cause(checkpoint_state, new_patience=None,
 class DWIMLTrainerOneInput(DWIMLAbstractTrainer):
     batch_loader: DWIMLBatchLoaderOneInput

-    def run_one_batch(self, data, average_results=True):
+    def run_one_batch(self, data):
         """
         Run a batch of data through the model (calling its forward method)
         and return the mean loss. If training, run the backward method too.
@@ -1096,8 +1096,6 @@ def run_one_batch(self, data, average_results=True):
             - final_streamline_ids_per_subj: the dict of streamlines ids from
               the list of all streamlines (if we concatenate all sfts'
               streamlines).
-        average_results: bool
-            If true, returns the averaged loss (as defined by the model).

         Returns
         -------
@@ -1153,10 +1151,10 @@ def run_one_batch(self, data, average_results=True):
                 targets, self.device)

             results = self.model.compute_loss(model_outputs, targets,
-                                              average_results=average_results)
+                                              average_results=True)
         else:
             results = self.model.compute_loss(model_outputs,
-                                              average_results=average_results)
+                                              average_results=True)

         if self.use_gpu:
             log_gpu_memory_usage(logger)
4 changes: 3 additions & 1 deletion dwi_ml/unit_tests/test_train_trainerOneInput.py
@@ -43,7 +43,8 @@ def test_trainer_and_models(experiments_path):
     trainer.train_and_validate()

     # Initializing model 2
-    logging.info("\n\n-----------TESTING TEST MODEL # 2: WITH PD ------------")
+    logging.info("\n\n-----------TESTING TEST MODEL # 2: WITH PD AND "
+                 "DIRECTION GETTER ------------")
     model2 = TrackingModelForTestWithPD()
     batch_sampler, batch_loader = _create_sampler_and_loader(dataset, model)

@@ -85,4 +86,5 @@ def _create_trainer(batch_sampler, batch_loader, model, experiments_path,

 if __name__ == '__main__':
     tmp_dir = tempfile.TemporaryDirectory()
+    logging.getLogger().setLevel('INFO')
     test_trainer_and_models(tmp_dir.name)
2 changes: 1 addition & 1 deletion scripts_python/dwiml_compute_loss_copy_previous.py
@@ -7,7 +7,7 @@
 Printing the average loss function for a given dataset when we simply copy the
 previous direction.

-Target := SFT.streamlines's directions[1:]
+Target := SFT.streamlines' directions[1:]
 Y := Previous directions.
 loss = DirectionGetter(Target, Y)
 """
1 change: 1 addition & 0 deletions scripts_python/tests/test_all_steps_learn2track.py
@@ -47,6 +47,7 @@ def test_training(script_runner, experiments_path):
                             input_group_name, streamline_group_name,
                             '--max_epochs', '1', '--batch_size_training', '5',
                             '--batch_size_validation', '5',
+                            '--dg_key', 'gaussian',
                             '--batch_size_units', 'nb_streamlines',
                             '--max_batches_per_epoch_training', '2',
                             '--max_batches_per_epoch_validation', '1',
1 change: 0 additions & 1 deletion scripts_python/tests/test_compute_loss_copy_previous.py
@@ -36,4 +36,3 @@ def test_running(script_runner, experiments_path):
                             '--pick_at_random',
                             hdf5_file, subj_id, streamline_group_name)
     assert ret.success
-