Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Big update in visu (loss , transformer attention weights) #226

Merged
merged 22 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions dwi_ml/data/dataset/multi_subject_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,16 @@ def get_mri_data(self, subj_idx: int, group_idx: int,

Contrary to get_volume_verify_cache, this does not send data to
cache for later use.

Parameters
----------
subj_idx: int
The subject id.
group_idx: int
The volume group idx.
load_it: bool
If data is lazy, get the volume as a LazyMRIData (False) or load it
as non-lazy (if True).
"""
if self.subjs_data_list.is_lazy:
if load_it:
Expand Down Expand Up @@ -484,6 +494,7 @@ def load_data(self, load_training=True, load_validation=True,
self.streamline_groups = poss_strea_groups
self.streamlines_contain_connectivity = contains_connectivity

self.streamline_groups = list(self.streamline_groups)
group_info = (self.volume_groups, self.nb_features,
self.streamline_groups,
self.streamlines_contain_connectivity)
Expand Down
7 changes: 5 additions & 2 deletions dwi_ml/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def add_arg_existing_experiment_path(p: ArgumentParser):
help='Path to the directory containing the experiment.\n'
'(Should contain a model subdir with a file \n'
'parameters.json and a file best_model_state.pkl.)')
p.add_argument('--use_latest_epoch', action='store_true',
help="If true, use model at latest epoch rather than "
"default (best model).")


def add_memory_args(p: ArgumentParser, add_lazy_options=False,
Expand All @@ -44,8 +47,8 @@ def add_memory_args(p: ArgumentParser, add_lazy_options=False,
add_processes_arg(ram_options)
ram_options.add_argument(
'--use_gpu', action='store_true',
help="If set, use GPU for processing. Cannot be used \ntogether "
"with --processes.")
help="If set, use GPU for processing. Cannot be used together "
"with \noption --processes.")
else:
p.add_argument('--use_gpu', action='store_true',
help="If set, use GPU for processing.")
Expand Down
20 changes: 20 additions & 0 deletions dwi_ml/models/direction_getter_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,26 @@ def _prepare_dirs_for_loss(self, target_dirs: List[Tensor]):

def compute_loss(self, outputs: List[Tensor],
target_streamlines: List[Tensor], average_results=True):
"""
Parameters
----------
outputs: List[Tensor]
Your model's outputs
target_streamlines: List[Tensor]
The streamlines. Directions will be computed and formatted based
on child class requirements.
average_results: bool
If true, returns the average over all values.

Returns
-------
If compress_loss or average_results:
Tuple(tensor, n)
The average loss and the n points averaged.
Else:
List[Tensor]
The loss for each point in each streamline.
"""
if self.compress_loss and not average_results:
raise ValueError("Current implementation of compress_loss does "
"not allow returning non-averaged loss.")
Expand Down
40 changes: 40 additions & 0 deletions dwi_ml/models/main_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,22 @@ def forward(self, inputs, streamlines):
def compute_loss(self, model_outputs, target_streamlines):
raise NotImplementedError

def merge_batches_outputs(self, all_outputs, new_batch):
"""
To be used at testing time. At training or validation time, outputs are
discarded after each batch; only the loss is measured. At testing time,
it may be necessary to merge batches. The way to do it will depend on
your model's format.

Parameters
----------
all_outputs: Any or None
All previous outputs from previous batches, already combined.
new_batch:
The batch to merge
"""
raise NotImplementedError


class ModelWithNeighborhood(MainModelAbstract):
"""
Expand Down Expand Up @@ -752,3 +768,27 @@ def compute_loss(self, model_outputs: List[Tensor], target_streamlines,
def move_to(self, device):
super().move_to(device)
self.direction_getter.move_to(device)

def merge_batches_outputs(self, all_outputs, new_batch, device=None):
# 1. Send new batch to device
if 'gaussian' in self.direction_getter.key:
# all_outputs = (means, sigmas)
new_batch = ([m.to(device=device) for m in new_batch[0]],
[s.to(device=device) for s in new_batch[1]])
else:
new_batch = [a.to(device) for a in new_batch]

# 2. Concat
if all_outputs is None:
return new_batch
else:
if 'gaussian' in self.direction_getter.key:
# all_outputs = (means, sigmas)
all_outputs[0].extend(new_batch[0])
all_outputs[1].extend(new_batch[1])
elif 'fisher' in self.direction_getter.key:
raise NotImplementedError
else:
# all_outputs = a list of tensors per streamline.
all_outputs.extend(new_batch)
return all_outputs
7 changes: 2 additions & 5 deletions dwi_ml/models/projects/copy_previous_dirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,5 @@ def compute_loss(self, model_outputs: List[torch.Tensor],
if self.skip_first_point:
target_streamlines = [t[1:] for t in target_streamlines]

if self._context == 'visu':
return self.direction_getter.compute_loss(
model_outputs, target_streamlines, average_results)
else:
raise NotImplementedError
return self.direction_getter.compute_loss(
model_outputs, target_streamlines, average_results)
15 changes: 10 additions & 5 deletions dwi_ml/models/projects/learn2track_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ def __init__(self, experiment_name,
self.instantiate_direction_getter(self.rnn_model.output_size)

def set_context(self, context):
# Training, validation: Used by trainer. Nothing special.
# Tracking: Used by tracker. Returns only the last point.
# Preparing_backward: Used by tracker. Nothing special, but does
# not return only the last point.
# Visu: Nothing special. Used by tester.
assert context in ['training', 'validation', 'tracking', 'visu',
'preparing_backward']
self._context = context
Expand Down Expand Up @@ -253,7 +258,7 @@ def forward(self, x: List[torch.tensor],
"""
# Reminder.
# Correct interpolation and management of points should be done before.
if self._context is None:
if self.context is None:
raise ValueError("Please set context before usage.")

# Right now input is always flattened (interpolation is implemented
Expand All @@ -268,7 +273,7 @@ def forward(self, x: List[torch.tensor],
# Making sure we can use default 'enforce_sorted=True' with packed
# sequences.
unsorted_indices = None
if not self._context == 'tracking':
if not self.context == 'tracking':
# Ordering streamlines per length.
lengths = torch.as_tensor([len(s) for s in x])
_, sorted_indices = torch.sort(lengths, descending=True)
Expand Down Expand Up @@ -351,7 +356,7 @@ def forward(self, x: List[torch.tensor],
x = x + copy_prev_dir

# Unpacking.
if not self._context == 'tracking':
if not self.context == 'tracking':
# (during tracking: keeping as one single tensor.)
if 'gaussian' in self.dg_key or 'fisher' in self.dg_key:
# Separating mean, sigmas (gaussian) or mean, kappa (fisher)
Expand All @@ -370,7 +375,7 @@ def forward(self, x: List[torch.tensor],
if return_hidden:
# Return the hidden states too. Necessary for the generative
# (tracking) part, done step by step.
if not self._context == 'tracking':
if not self.context == 'tracking':
# (ex: when preparing backward tracking.
# Must also re-sort hidden states.)
if self.rnn_model.rnn_torch_key == 'lstm':
Expand Down Expand Up @@ -402,7 +407,7 @@ def copy_prev_dir(self, dirs):
# Converting the input directions into classes the same way as
# during loss, but convert to one-hot.
# The first previous dir (0) converts to index 0.
if self._context == 'tracking':
if self.context == 'tracking':
if dirs[0].shape[0] == 0:
copy_prev_dir = torch.zeros(
len(dirs),
Expand Down
Loading
Loading