diff --git a/dwi_ml/models/direction_getter_models.py b/dwi_ml/models/direction_getter_models.py index a35ca73e..a7841d91 100644 --- a/dwi_ml/models/direction_getter_models.py +++ b/dwi_ml/models/direction_getter_models.py @@ -216,22 +216,19 @@ def compute_loss(self, outputs: List[Tensor], # Compute directions target_dirs = compute_directions(target_streamlines) - # For compress_loss: remember raw target dirs. - line_dirs = None + # For compress_loss and weight_with_angle: remember raw target dirs. + # Also, do not average now, we will do our own averaging. + target_dirs_copy = None + tmp_average_results = average_results if self.weight_loss_with_angle or self.compress_loss: - line_dirs = [t.detach().clone() for t in target_dirs] + target_dirs_copy = [t.detach().clone() for t in target_dirs] + tmp_average_results = False # Modify directions based on child model requirements. # Ex: Add eos label. Convert to classes. Etc. target_dirs = self._prepare_dirs_for_loss(target_dirs) lengths = [len(t) for t in target_dirs] - # For compress_loss and weight_with_angle, do not average now, we - # will do our own averaging. - tmp_average_results = average_results - if self.compress_loss or self.weight_loss_with_angle: - tmp_average_results = False - # Stack and compute loss based on child model's loss definition. outputs, target_dirs = self.stack_batch(outputs, target_dirs) loss, n = self._compute_loss(outputs, target_dirs, tmp_average_results) @@ -243,12 +240,12 @@ def compute_loss(self, outputs: List[Tensor], eos_loss = [line_loss[-1] for line_loss in loss] loss = [line_loss[:-1] for line_loss in loss] loss = weight_value_with_angle( - values=loss, streamlines=None, dirs=line_dirs) + values=loss, streamlines=None, dirs=target_dirs_copy) for i in range(len(loss)): loss[i] = torch.hstack((loss[i], eos_loss[i])) else: loss = weight_value_with_angle( - values=loss, streamlines=None, dirs=line_dirs) + values=loss, streamlines=None, dirs=target_dirs_copy) if not self.compress_loss: if average_results: loss = torch.hstack(loss) @@ -258,7 +255,7 @@ def compute_loss(self, outputs: List[Tensor], if self.compress_loss: loss = list(torch.split(loss, lengths)) final_loss, final_n = compress_streamline_values( - streamlines=None, dirs=line_dirs, values=loss, + streamlines=None, dirs=target_dirs_copy, values=loss, compress_eps=self.compress_eps) logging.info("Converted {} data points into {} compressed data " "points".format(sum(lengths), final_n)) @@ -421,7 +418,7 @@ def _compute_loss(self, learned_directions: Tensor, target_dirs: Tensor, n = 1 if average_results: - loss_dir, n = _mean_and_weight(loss_dirs) + loss_dirs, n = _mean_and_weight(loss_dirs) # 2. EOS loss: if self.add_eos: @@ -433,7 +430,7 @@ def _compute_loss(self, learned_directions: Tensor, target_dirs: Tensor, loss_eos = binary_cross_entropy_eos(learned_eos, target_eos, average_results) - return loss_dirs + self.eos_weight * loss_eos, n + return loss_dirs + self.eos_weight * loss_eos, n else: return loss_dirs, n @@ -809,7 +806,8 @@ class SingleGaussianDG(AbstractDirectionGetterModel): ===========> WE SUGGEST USING GRADIENT CLIPPING **AND** A LOW LEARNING RATE.
""" - def __init__(self, normalize_targets: float = None, **kwargs): + def __init__(self, normalize_targets: float = None, + **kwargs): # 3D gaussian supports compressed streamlines super().__init__(key='gaussian', supports_compressed_streamlines=True, @@ -817,6 +815,7 @@ def __init__(self, normalize_targets: float = None, **kwargs): **kwargs) self.normalize_targets = normalize_targets + self.entropy_weight = 0.0 # Not tested yet. # Layers # 3 values as mean, 3 values as sigma @@ -899,17 +898,18 @@ def _compute_loss(self, learned_gaussian_params: Tuple[Tensor, Tensor], # Create an official function-probability distribution from the means # and variances - np.set_printoptions(precision=3) distribution = MultivariateNormal( means, covariance_matrix=torch.diag_embed(sigmas ** 2)) - nll_loss = -distribution.log_prob(target_dirs) + nll_loss = -distribution.log_prob(target_dirs[:, 0:3]) - # Trying to ensure that sigma values are not too small. - # Else, normal values become very big, and gradients too. - entropy = distribution.entropy() - logging.info("Sigma: {}. Entropy: {}".format(torch.mean(sigmas), - torch.mean(entropy))) - #nll_loss = nll_loss - 0.1 * entropy + if self.entropy_weight > 0: + # Trying to ensure that sigma values are not too small. + # Entropy values range between 0 and log(K). 0 = high probability. + # We want a high entropy / low certainty = we will minimize -entropy. + entropy = distribution.entropy() + logging.info("Computing batch loss with sigma {}, entropy: {}" + .format(torch.mean(sigmas), torch.mean(entropy))) + nll_loss = nll_loss - self.entropy_weight * entropy n = 1 if average_results: @@ -918,8 +918,7 @@ def _compute_loss(self, learned_gaussian_params: Tuple[Tensor, Tensor], # 2. EOS loss: if self.add_eos: # Binary cross-entropy - real_eos = target_dirs[:, -1] - loss_eos = binary_cross_entropy_eos(learned_eos, real_eos, + loss_eos = binary_cross_entropy_eos(learned_eos, target_dirs[:, -1], average_results) return nll_loss + self.eos_weight * loss_eos, n else: @@ -958,7 +957,7 @@ def _get_tracking_direction_det(self, learned_gaussian_params: Tensor, """ # Returns the direction of the max of the Gaussian = the mean. means, sigmas = learned_gaussian_params - dirs = means[0:3] + dirs = means[:, 0:3] if self.add_eos: eos_prob = torch.sigmoid(means[:, -1]) diff --git a/dwi_ml/models/utils/direction_getters.py b/dwi_ml/models/utils/direction_getters.py index 73889aef..3ce65a3c 100644 --- a/dwi_ml/models/utils/direction_getters.py +++ b/dwi_ml/models/utils/direction_getters.py @@ -101,9 +101,9 @@ def check_args_direction_getter(args): "Mises mixture. Ignored.") # Regression and normalisation - if 'regression' or 'gaussian' in args.dg_key: + if 'regression' in args.dg_key or 'gaussian' in args.dg_key: dg_args['normalize_targets'] = args.normalize_targets - else: + elif args.normalize_targets: raise ValueError("--normalize_targets is only an option for " "regression and gaussian models.") diff --git a/dwi_ml/testing/testers.py b/dwi_ml/testing/testers.py index 5d3d37a7..0af90fc6 100644 --- a/dwi_ml/testing/testers.py +++ b/dwi_ml/testing/testers.py @@ -119,7 +119,7 @@ def run_model_on_sft(self, sft, add_zeros_if_no_eos=True, force_compress_loss: bool If true, compresses the loss even if that is not the model's parameter. - change_weight_with_angle: bool + weight_with_angle: bool If true, modify model's wieght_loss_with_angle param. 
""" if uncompress_loss and force_compress_loss: @@ -141,9 +141,15 @@ def run_model_on_sft(self, sft, add_zeros_if_no_eos=True, batch_size = self.batch_size or len(sft) nb_batches = int(np.ceil(len(sft) / batch_size)) + if 'gaussian' in self.model.direction_getter.key: + outputs = ([], []) + elif 'fisher' in self.model.direction_getter.key: + raise NotImplementedError + else: + outputs = [] + losses = [] compressed_n = [] - outputs = [] batch_start = 0 batch_end = batch_size with torch.no_grad(): @@ -180,7 +186,15 @@ def run_model_on_sft(self, sft, add_zeros_if_no_eos=True, losses.extend([line_loss.cpu() for line_loss in tmp_losses]) - outputs.extend([o.cpu() for o in tmp_outputs]) + # ToDo. See if we can simplify to fit with all models + if 'gaussian' in self.model.direction_getter.key: + tmp_means, tmp_sigmas = tmp_outputs + outputs[0].extend([m.cpu() for m in tmp_means]) + outputs[1].extend([s.cpu() for s in tmp_sigmas]) + elif 'fisher' in self.model.direction_getter.key: + raise NotImplementedError + else: + outputs.extend([o.cpu() for o in tmp_outputs]) batch_start = batch_end batch_end = min(batch_start + batch_size, len(sft)) diff --git a/dwi_ml/testing/visu_loss.py b/dwi_ml/testing/visu_loss.py index 18b420ff..680d22b1 100644 --- a/dwi_ml/testing/visu_loss.py +++ b/dwi_ml/testing/visu_loss.py @@ -161,6 +161,7 @@ def combine_displacement_with_ref(out_dirs, sft, step_size_mm=None): color_x = [] color_y = [] color_z = [] + for i, s in enumerate(sft.streamlines): this_s_len = len(s) @@ -253,20 +254,37 @@ def run_visu_save_colored_displacement( save_tractogram(worst_sft, worst_sft_name) # Save displacement + args.pick_idx = list(range(10)) if args.out_displacement_sft: if args.out_colored_sft: # We have run model on all streamlines. Picking a few now. sft, idx = pick_a_few( sft, best_idx, worst_idx, args.pick_at_random, args.pick_best_and_worst, args.pick_idx) - outputs = [outputs[i] for i in idx] - # Either concat, run, split or (chosen:) loop + # ToDo. See if we can simplify to fit with all models + if 'gaussian' in model.direction_getter.key: + means, sigmas = outputs + means = [means[i] for i in idx] + lengths = [len(line) for line in means] + outputs = (torch.vstack(means), + torch.vstack([sigmas[i] for i in idx])) + + elif 'fisher' in model.direction_getter.key: + raise NotImplementedError + else: + outputs = [outputs[i] for i in idx] + lengths = [len(line) for line in outputs] + outputs = torch.vstack(outputs) + # Use eos_thresh of 1 to be sure we don't output a NaN with torch.no_grad(): - out_dirs = [model.get_tracking_directions( - s_output, algo='det', eos_stopping_thresh=1.0).numpy() - for s_output in outputs] + out_dirs = model.get_tracking_directions( + outputs, algo='det', eos_stopping_thresh=1.0) + + out_dirs = torch.split(out_dirs, lengths) + + out_dirs = [o.numpy() for o in out_dirs] # Save error together with ref sft = combine_displacement_with_ref(out_dirs, sft, model.step_size) diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index 471f2a0f..17aa9e91 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -1081,7 +1081,7 @@ def check_stopping_cause(checkpoint_state, new_patience=None, class DWIMLTrainerOneInput(DWIMLAbstractTrainer): batch_loader: DWIMLBatchLoaderOneInput - def run_one_batch(self, data, average_results=True): + def run_one_batch(self, data): """ Run a batch of data through the model (calling its forward method) and return the mean loss. If training, run the backward method too. 
@@ -1096,8 +1096,6 @@ def run_one_batch(self, data, average_results=True): - final_streamline_ids_per_subj: the dict of streamlines ids from the list of all streamlines (if we concatenate all sfts' streamlines). - average_results: bool - If true, returns the averaged loss (as defined by the model). Returns ------- @@ -1153,10 +1151,10 @@ def run_one_batch(self, data, average_results=True): targets, self.device) results = self.model.compute_loss(model_outputs, targets, - average_results=average_results) + average_results=True) else: results = self.model.compute_loss(model_outputs, - average_results=average_results) + average_results=True) if self.use_gpu: log_gpu_memory_usage(logger) diff --git a/dwi_ml/unit_tests/test_train_trainerOneInput.py b/dwi_ml/unit_tests/test_train_trainerOneInput.py index 5af3f7b3..0a11a105 100644 --- a/dwi_ml/unit_tests/test_train_trainerOneInput.py +++ b/dwi_ml/unit_tests/test_train_trainerOneInput.py @@ -43,7 +43,8 @@ def test_trainer_and_models(experiments_path): trainer.train_and_validate() # Initializing model 2 - logging.info("\n\n-----------TESTING TEST MODEL # 2: WITH PD ------------") + logging.info("\n\n-----------TESTING TEST MODEL # 2: WITH PD AND " + "DIRECTION GETTER ------------") model2 = TrackingModelForTestWithPD() batch_sampler, batch_loader = _create_sampler_and_loader(dataset, model) @@ -85,4 +86,5 @@ def _create_trainer(batch_sampler, batch_loader, model, experiments_path, if __name__ == '__main__': tmp_dir = tempfile.TemporaryDirectory() + logging.getLogger().setLevel('INFO') test_trainer_and_models(tmp_dir.name) diff --git a/scripts_python/dwiml_compute_loss_copy_previous.py b/scripts_python/dwiml_compute_loss_copy_previous.py index 4715b67e..ac094f3f 100644 --- a/scripts_python/dwiml_compute_loss_copy_previous.py +++ b/scripts_python/dwiml_compute_loss_copy_previous.py @@ -7,7 +7,7 @@ Printing the average loss function for a given dataset when we simply copy the previous direction. - Target := SFT.streamlines's directions[1:] + Target := SFT.streamlines' directions[1:] Y := Previous directions. loss = DirectionGetter(Target, Y) """ diff --git a/scripts_python/tests/test_all_steps_learn2track.py b/scripts_python/tests/test_all_steps_learn2track.py index 8155300e..20f6346d 100644 --- a/scripts_python/tests/test_all_steps_learn2track.py +++ b/scripts_python/tests/test_all_steps_learn2track.py @@ -47,6 +47,7 @@ def test_training(script_runner, experiments_path): input_group_name, streamline_group_name, '--max_epochs', '1', '--batch_size_training', '5', '--batch_size_validation', '5', + '--dg_key', 'gaussian', '--batch_size_units', 'nb_streamlines', '--max_batches_per_epoch_training', '2', '--max_batches_per_epoch_validation', '1', diff --git a/scripts_python/tests/test_compute_loss_copy_previous.py b/scripts_python/tests/test_compute_loss_copy_previous.py index 8a1eb324..bb413f0d 100644 --- a/scripts_python/tests/test_compute_loss_copy_previous.py +++ b/scripts_python/tests/test_compute_loss_copy_previous.py @@ -36,4 +36,3 @@ def test_running(script_runner, experiments_path): '--pick_at_random', hdf5_file, subj_id, streamline_group_name) assert ret.success -
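Editor's note on the entropy term added in SingleGaussianDG._compute_loss: below is a minimal standalone sketch of the idea. The function name and the (N, 3) shapes are illustrative assumptions; the model's actual method also handles EOS and averaging.

import torch
from torch.distributions import MultivariateNormal

def gaussian_nll_with_entropy(means, sigmas, target_dirs, entropy_weight=0.0):
    # means, sigmas, target_dirs: (N, 3) tensors (hypothetical shapes).
    dist = MultivariateNormal(
        means, covariance_matrix=torch.diag_embed(sigmas ** 2))
    nll = -dist.log_prob(target_dirs)  # one loss value per point
    if entropy_weight > 0:
        # Entropy grows with sigma: subtracting it penalizes collapsed,
        # overconfident sigmas whose NLL values and gradients would explode.
        nll = nll - entropy_weight * dist.entropy()
    return nll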
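Editor's note on the condition fix in check_args_direction_getter: the old test hit a classic Python pitfall. 'regression' or 'gaussian' in args.dg_key parses as 'regression' or ('gaussian' in args.dg_key), and a non-empty string literal is always truthy, so the branch was always taken. A quick demonstration:

dg_key = 'sphere-classification'
print(bool('regression' or 'gaussian' in dg_key))       # True: the bug
print('regression' in dg_key or 'gaussian' in dg_key)   # False, as intended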
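Editor's note on the visu_loss.py change: the per-streamline loop is replaced by one get_tracking_directions call on the stacked outputs, then torch.split restores the per-streamline grouping from the saved lengths. A toy sketch of that stack/split pattern (the tensors here are stand-ins, not the model's real outputs):

import torch

lengths = [4, 6]                        # hypothetical streamline lengths
stacked = torch.randn(sum(lengths), 3)  # stand-in for stacked model outputs

# One batched call would happen here; then split back per streamline:
per_line = [t.numpy() for t in torch.split(stacked, lengths)]
assert [len(p) for p in per_line] == lengths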