Merge pull request scil-vital#226 from EmmaRenauld/new_visu_transfo
Big update in visu (loss, transformer attention weights)
EmmaRenauld authored Apr 9, 2024
2 parents f8166eb + 5989e17 commit 26540f0
Showing 44 changed files with 2,779 additions and 1,555 deletions.
11 changes: 11 additions & 0 deletions dwi_ml/data/dataset/multi_subject_containers.py
@@ -208,6 +208,16 @@ def get_mri_data(self, subj_idx: int, group_idx: int,
Unlike get_volume_verify_cache, this does not send data to the
cache for later use.

Parameters
----------
subj_idx: int
    The subject id.
group_idx: int
    The volume group index.
load_it: bool
    If the data is lazy, either return the volume as a LazyMRIData
    (False) or load it as a non-lazy volume (True).
"""
if self.subjs_data_list.is_lazy:
if load_it:
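
The body of the lazy branch is truncated in this view. A hedged sketch of how the new load_it flag could be called, where subset is an assumed handle to this container (variable names and index values are illustrative, not from this PR):

# Hedged sketch: 'subset' is an assumed MultiSubjectDataset-like container.
# On a lazy dataset, load_it=False keeps the volume as a LazyMRIData;
# load_it=True loads it as a non-lazy volume.
lazy_volume = subset.get_mri_data(subj_idx=0, group_idx=0, load_it=False)
full_volume = subset.get_mri_data(subj_idx=0, group_idx=0, load_it=True)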
@@ -484,6 +494,7 @@ def load_data(self, load_training=True, load_validation=True,
self.streamline_groups = poss_strea_groups
self.streamlines_contain_connectivity = contains_connectivity

self.streamline_groups = list(self.streamline_groups)
group_info = (self.volume_groups, self.nb_features,
self.streamline_groups,
self.streamlines_contain_connectivity)
7 changes: 5 additions & 2 deletions dwi_ml/io_utils.py
@@ -31,6 +31,9 @@ def add_arg_existing_experiment_path(p: ArgumentParser):
help='Path to the directory containing the experiment.\n'
'(Should contain a model subdir with a file \n'
'parameters.json and a file best_model_state.pkl.)')
p.add_argument('--use_latest_epoch', action='store_true',
help="If set, use the model at the latest epoch rather "
"than the default (best model).")


def add_memory_args(p: ArgumentParser, add_lazy_options=False,
@@ -44,8 +47,8 @@ def add_memory_args(p: ArgumentParser, add_lazy_options=False,
add_processes_arg(ram_options)
ram_options.add_argument(
'--use_gpu', action='store_true',
help="If set, use GPU for processing. Cannot be used \ntogether "
"with --processes.")
help="If set, use GPU for processing. Cannot be used together "
"with \noption --processes.")
else:
p.add_argument('--use_gpu', action='store_true',
help="If set, use GPU for processing.")
20 changes: 20 additions & 0 deletions dwi_ml/models/direction_getter_models.py
@@ -207,6 +207,26 @@ def _prepare_dirs_for_loss(self, target_dirs: List[Tensor]):

def compute_loss(self, outputs: List[Tensor],
target_streamlines: List[Tensor], average_results=True):
"""
Parameters
----------
outputs: List[Tensor]
Your model's outputs
target_streamlines: List[Tensor]
The streamlines. Directions will be computed and formatted based
on child class requirements.
average_results: bool
If true, returns the average over all values.
Returns
-------
If compress_loss or average_results:
Tuple(tensor, n)
The average loss and the n points averaged.
Else:
List[Tensor]
The loss for each point in each streamline.
"""
if self.compress_loss and not average_results:
raise ValueError("Current implementation of compress_loss does "
"not allow returning non-averaged loss.")
40 changes: 40 additions & 0 deletions dwi_ml/models/main_models.py
@@ -207,6 +207,22 @@ def forward(self, inputs, streamlines):
def compute_loss(self, model_outputs, target_streamlines):
raise NotImplementedError

def merge_batches_outputs(self, all_outputs, new_batch):
"""
To be used at testing time. At training or validation time, outputs are
discarded after each batch; only the loss is measured. At testing time,
it may be necessary to merge batches. How to do so depends on your
model's output format.

Parameters
----------
all_outputs: Any or None
    All outputs from previous batches, already combined.
new_batch:
    The new batch's outputs, to be merged with all_outputs.
"""
raise NotImplementedError


class ModelWithNeighborhood(MainModelAbstract):
"""
@@ -752,3 +768,27 @@ def compute_loss(self, model_outputs: List[Tensor], target_streamlines,
def move_to(self, device):
super().move_to(device)
self.direction_getter.move_to(device)

def merge_batches_outputs(self, all_outputs, new_batch, device=None):
# 1. Send new batch to device
if 'gaussian' in self.direction_getter.key:
# all_outputs = (means, sigmas)
new_batch = ([m.to(device=device) for m in new_batch[0]],
[s.to(device=device) for s in new_batch[1]])
else:
new_batch = [a.to(device) for a in new_batch]

# 2. Concat
if all_outputs is None:
return new_batch
else:
if 'gaussian' in self.direction_getter.key:
# all_outputs = (means, sigmas)
all_outputs[0].extend(new_batch[0])
all_outputs[1].extend(new_batch[1])
elif 'fisher' in self.direction_getter.key:
raise NotImplementedError
else:
# all_outputs = a list of tensors per streamline.
all_outputs.extend(new_batch)
return all_outputs
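
A standalone mimic of the list branch above (the 'else' case), showing the intended accumulation across test batches with dummy tensors:

import torch

def merge_lists(all_outputs, new_batch):
    # Mimics the non-gaussian branch: lists of per-streamline tensors.
    if all_outputs is None:
        return new_batch
    all_outputs.extend(new_batch)
    return all_outputs

batch1 = [torch.rand(5, 3), torch.rand(8, 3)]   # two streamlines
batch2 = [torch.rand(6, 3)]                     # one streamline
merged = None
for batch in [batch1, batch2]:
    merged = merge_lists(merged, batch)
print(len(merged))                              # 3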
7 changes: 2 additions & 5 deletions dwi_ml/models/projects/copy_previous_dirs.py
@@ -52,8 +52,5 @@ def compute_loss(self, model_outputs: List[torch.Tensor],
if self.skip_first_point:
target_streamlines = [t[1:] for t in target_streamlines]

if self._context == 'visu':
return self.direction_getter.compute_loss(
model_outputs, target_streamlines, average_results)
else:
raise NotImplementedError
return self.direction_getter.compute_loss(
model_outputs, target_streamlines, average_results)
15 changes: 10 additions & 5 deletions dwi_ml/models/projects/learn2track_model.py
@@ -189,6 +189,11 @@ def __init__(self, experiment_name,
self.instantiate_direction_getter(self.rnn_model.output_size)

def set_context(self, context):
# Training, validation: used by the trainer. Nothing special.
# Tracking: used by the tracker. Returns only the last point.
# Preparing_backward: used by the tracker. Nothing special, but does
# not return only the last point.
# Visu: used by the tester. Nothing special.
assert context in ['training', 'validation', 'tracking', 'visu',
'preparing_backward']
self._context = context
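
A standalone mimic of the context contract. The 'context' property is an assumption, consistent with the self._context to self.context change further down in this diff; the real class is not reproduced here:

class ContextDemo:
    def __init__(self):
        self._context = None

    def set_context(self, context):
        assert context in ['training', 'validation', 'tracking', 'visu',
                           'preparing_backward']
        self._context = context

    @property
    def context(self):
        return self._context

    def forward(self, x):
        if self.context is None:
            raise ValueError("Please set context before usage.")
        # During tracking, only the last point is returned.
        return x[-1:] if self.context == 'tracking' else x

model = ContextDemo()
model.set_context('tracking')
print(model.forward([1, 2, 3]))   # [3]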
@@ -253,7 +258,7 @@ def forward(self, x: List[torch.tensor],
"""
# Reminder.
# Correct interpolation and management of points should be done before.
if self._context is None:
if self.context is None:
raise ValueError("Please set context before usage.")

# Right now input is always flattened (interpolation is implemented
@@ -268,7 +273,7 @@ def forward(self, x: List[torch.tensor],
# Making sure we can use default 'enforce_sorted=True' with packed
# sequences.
unsorted_indices = None
if not self._context == 'tracking':
if not self.context == 'tracking':
# Ordering streamlines per length.
lengths = torch.as_tensor([len(s) for s in x])
_, sorted_indices = torch.sort(lengths, descending=True)
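
A standalone mimic of this sorting step, showing why it permits the default enforce_sorted=True when packing sequences:

import torch
from torch.nn.utils.rnn import pack_sequence

x = [torch.rand(5, 3), torch.rand(9, 3), torch.rand(2, 3)]
lengths = torch.as_tensor([len(s) for s in x])
_, sorted_indices = torch.sort(lengths, descending=True)
x_sorted = [x[i] for i in sorted_indices]

# pack_sequence requires decreasing lengths unless enforce_sorted=False.
packed = pack_sequence(x_sorted)
print(packed.batch_sizes)   # tensor([3, 3, 2, 2, 2, 1, 1, 1, 1])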
@@ -351,7 +356,7 @@ def forward(self, x: List[torch.tensor],
x = x + copy_prev_dir

# Unpacking.
if not self._context == 'tracking':
if not self.context == 'tracking':
# (during tracking: keeping as one single tensor.)
if 'gaussian' in self.dg_key or 'fisher' in self.dg_key:
# Separating mean, sigmas (gaussian) or mean, kappa (fisher)
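
The actual separation is cut off in this hunk. A hedged mimic, assuming means and sigmas are concatenated as 3 + 3 features per point (the layout is an assumption, not shown here):

import torch

# Assumption: 6 output features per point = 3 means followed by 3 sigmas.
x = torch.rand(10, 6)
means, sigmas = torch.split(x, 3, dim=-1)
print(means.shape, sigmas.shape)   # torch.Size([10, 3]) torch.Size([10, 3])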
@@ -370,7 +375,7 @@ def forward(self, x: List[torch.tensor],
if return_hidden:
# Return the hidden states too. Necessary for the generative
# (tracking) part, done step by step.
if not self._context == 'tracking':
if not self.context == 'tracking':
# (e.g., when preparing backward tracking.
# Must also re-sort hidden states.)
if self.rnn_model.rnn_torch_key == 'lstm':
@@ -402,7 +407,7 @@ def copy_prev_dir(self, dirs):
# Converting the input directions into classes the same way as
# during loss, but convert to one-hot.
# The first previous dir (0) converts to index 0.
if self._context == 'tracking':
if self.context == 'tracking':
if dirs[0].shape[0] == 0:
copy_prev_dir = torch.zeros(
len(dirs),
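
The torch.zeros call is truncated in this view. A hedged mimic of a one-hot previous-direction encoding as described in the comments above; the number of classes is an assumption, and only "the first previous dir converts to index 0" comes from the code:

import torch

n_classes = 21                               # assumption: 20 dirs + class 0
prev_dir_classes = torch.tensor([0, 7, 12])  # 0 = no previous direction yet
one_hot = torch.nn.functional.one_hot(prev_dir_classes, n_classes).float()
print(one_hot.shape)                         # torch.Size([3, 21])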