Skip to content

Commit

Permalink
Merge pull request scil-vital#233 from EmmaRenauld/fix_transformer_gauss
Browse files Browse the repository at this point in the history
Fix transformers with Gaussian/Fisher usage
  • Loading branch information
EmmaRenauld authored Apr 22, 2024
2 parents 26540f0 + 1b49d7b commit 27177d8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
9 changes: 8 additions & 1 deletion dwi_ml/models/projects/transformer_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,14 @@ def forward(self, inputs: List[torch.tensor],

# Splitting back. During tracking: only one point per streamline.
if self.context != 'tracking':
outputs = list(torch.split(outputs, list(input_lengths)))
if 'gaussian' in self.dg_key or 'fisher' in self.dg_key:
# Separating mean, sigmas (gaussian) or mean, kappa (fisher)
x, x2 = outputs
x = list(torch.split(x, list(input_lengths)))
x2 = list(torch.split(x2, list(input_lengths)))
outputs = (x, x2)
else:
outputs = list(torch.split(outputs, list(input_lengths)))

if return_weights:
# Padding weights to max length, else we won't be able to stack
Expand Down
1 change: 1 addition & 0 deletions scripts_python/tests/test_all_steps_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_execution(script_runner, experiments_path):
ret = script_runner.run('tt_train_model.py',
experiments_path, experiment_name, hdf5_file,
input_group_name, streamline_group_name,
'--dg_key', 'gaussian',
'--model', 'TTS',
'--max_epochs', '1', '--batch_size_training', '5',
'--batch_size_units', 'nb_streamlines',
Expand Down

0 comments on commit 27177d8

Please sign in to comment.