From 1b49d7bea0f669f5d7e292c75c70cf038d51fd0f Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 22 Apr 2024 13:27:25 -0400 Subject: [PATCH] Fix transformers with Gaussian/Fisher usage --- dwi_ml/models/projects/transformer_models.py | 9 ++++++++- scripts_python/tests/test_all_steps_tts.py | 1 + 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/dwi_ml/models/projects/transformer_models.py b/dwi_ml/models/projects/transformer_models.py index 3714a049..1ecc2664 100644 --- a/dwi_ml/models/projects/transformer_models.py +++ b/dwi_ml/models/projects/transformer_models.py @@ -497,7 +497,14 @@ def forward(self, inputs: List[torch.tensor], # Splitting back. During tracking: only one point per streamline. if self.context != 'tracking': - outputs = list(torch.split(outputs, list(input_lengths))) + if 'gaussian' in self.dg_key or 'fisher' in self.dg_key: + # Separating mean, sigmas (gaussian) or mean, kappa (fisher) + x, x2 = outputs + x = list(torch.split(x, list(input_lengths))) + x2 = list(torch.split(x2, list(input_lengths))) + outputs = (x, x2) + else: + outputs = list(torch.split(outputs, list(input_lengths))) if return_weights: # Padding weights to max length, else we won't be able to stack diff --git a/scripts_python/tests/test_all_steps_tts.py b/scripts_python/tests/test_all_steps_tts.py index fe13d6ba..3acfc5e1 100644 --- a/scripts_python/tests/test_all_steps_tts.py +++ b/scripts_python/tests/test_all_steps_tts.py @@ -41,6 +41,7 @@ def test_execution(script_runner, experiments_path): ret = script_runner.run('tt_train_model.py', experiments_path, experiment_name, hdf5_file, input_group_name, streamline_group_name, + '--dg_key', 'gaussian', '--model', 'TTS', '--max_epochs', '1', '--batch_size_training', '5', '--batch_size_units', 'nb_streamlines',