Skip to content

Commit

Permalink
All tests passing! Major refactor of visu done
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Jan 18, 2024
1 parent 6855a48 commit d7fc300
Show file tree
Hide file tree
Showing 35 changed files with 1,892 additions and 1,003 deletions.
11 changes: 11 additions & 0 deletions dwi_ml/data/dataset/multi_subject_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,16 @@ def get_mri_data(self, subj_idx: int, group_idx: int,
Contrary to get_volume_verify_cache, this does not send data to
cache for later use.
Parameters
----------
subj_idx: int
The subject id.
group_idx: int
The volume group idx.
load_it: bool
If data is lazy, get the volume as a LazyMRIData (False) or load it
as non-lazy (if True).
"""
if self.subjs_data_list.is_lazy:
if load_it:
Expand Down Expand Up @@ -484,6 +494,7 @@ def load_data(self, load_training=True, load_validation=True,
self.streamline_groups = poss_strea_groups
self.streamlines_contain_connectivity = contains_connectivity

self.streamline_groups = list(self.streamline_groups)
group_info = (self.volume_groups, self.nb_features,
self.streamline_groups,
self.streamlines_contain_connectivity)
Expand Down
20 changes: 20 additions & 0 deletions dwi_ml/models/direction_getter_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,26 @@ def _prepare_dirs_for_loss(self, target_dirs: List[Tensor]):

def compute_loss(self, outputs: List[Tensor],
target_streamlines: List[Tensor], average_results=True):
"""
Parameters
----------
outputs: List[Tensor]
Your model's outputs
target_streamlines: List[Tensor]
The streamlines. Directions will be computed and formatted based
on child class requirements.
average_results: bool
If true, returns the average over all values.
Returns
-------
If compress_loss or average_results:
Tuple(tensor, n)
The average loss and the n points averaged.
Else:
List[Tensor]
The loss for each point in each streamline.
"""
if self.compress_loss and not average_results:
raise ValueError("Current implementation of compress_loss does not "
"allow returning non-averaged loss.")
Expand Down
40 changes: 40 additions & 0 deletions dwi_ml/models/main_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,22 @@ def forward(self, inputs, streamlines):
def compute_loss(self, model_outputs, target_streamlines):
    """
    Placeholder for the loss computation.

    This implementation does nothing but raise; a usable version must be
    supplied elsewhere (presumably by child classes).

    Parameters
    ----------
    model_outputs:
        The outputs of the model.
    target_streamlines:
        The streamlines used as target.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError

def merge_batches_outputs(self, all_outputs, new_batch):
    """
    Combine the outputs of a new batch with the outputs gathered so far.

    Intended for testing time. During training and validation, the outputs
    of each batch are thrown away once the loss is measured, so no merging
    is required. During testing, the outputs of many batches may need to be
    kept together, and how to do so depends on the format of the model's
    outputs.

    Parameters
    ----------
    all_outputs: Any or None
        The already-combined outputs of all previous batches.
    new_batch:
        The outputs of the batch to merge.

    Raises
    ------
    NotImplementedError
        Always, in this implementation.
    """
    raise NotImplementedError


class ModelWithNeighborhood(MainModelAbstract):
"""
Expand Down Expand Up @@ -752,3 +768,27 @@ def compute_loss(self, model_outputs: List[Tensor], target_streamlines,
def move_to(self, device):
    """
    Move the model to a device.

    The parent's move_to is called first, then the direction getter's.

    Parameters
    ----------
    device:
        The target device, forwarded as is to both move_to calls.
    """
    super().move_to(device)
    direction_getter = self.direction_getter
    direction_getter.move_to(device)

def merge_batches_outputs(self, all_outputs, new_batch, device=None):
    """
    Send the outputs of a new batch to a device and append them to the
    outputs already accumulated from previous batches.

    Parameters
    ----------
    all_outputs: Any or None
        The merged outputs of the previous batches, in the format returned
        by this method. None for the first batch.
    new_batch:
        The outputs of the new batch. If the direction getter's key
        contains 'gaussian': a tuple (means, sigmas), each one a list.
        Else: a list (one element per streamline). Elements must offer a
        `.to(device=...)` method (presumably torch tensors).
    device:
        Where to send the new batch's outputs.

    Returns
    -------
    all_outputs:
        Same format as new_batch. When all_outputs was not None, it is
        extended in place and returned.

    Raises
    ------
    NotImplementedError
        If the direction getter's key contains 'fisher'.
    """
    dg_key = self.direction_getter.key
    is_gaussian = 'gaussian' in dg_key

    # Refusing unsupported formats before touching the data, so that the
    # first batch fails the same (explicit) way as the following ones.
    if 'fisher' in dg_key and not is_gaussian:
        raise NotImplementedError(
            "Merging batches is not implemented for 'fisher' direction "
            "getters.")

    # 1. Send new batch to device
    if is_gaussian:
        # new_batch = (means, sigmas)
        new_batch = ([m.to(device=device) for m in new_batch[0]],
                     [s.to(device=device) for s in new_batch[1]])
    else:
        new_batch = [a.to(device=device) for a in new_batch]

    # 2. Concat
    if all_outputs is None:
        return new_batch

    if is_gaussian:
        # all_outputs = (means, sigmas)
        all_outputs[0].extend(new_batch[0])
        all_outputs[1].extend(new_batch[1])
    else:
        # all_outputs = a list of tensors per streamline.
        all_outputs.extend(new_batch)
    return all_outputs
7 changes: 2 additions & 5 deletions dwi_ml/models/projects/copy_previous_dirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,5 @@ def compute_loss(self, model_outputs: List[torch.Tensor],
if self.skip_first_point:
target_streamlines = [t[1:] for t in target_streamlines]

if self._context == 'visu':
return self.direction_getter.compute_loss(
model_outputs, target_streamlines, average_results)
else:
raise NotImplementedError
return self.direction_getter.compute_loss(
model_outputs, target_streamlines, average_results)
15 changes: 10 additions & 5 deletions dwi_ml/models/projects/learn2track_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ def __init__(self, experiment_name,
self.instantiate_direction_getter(self.rnn_model.output_size)

def set_context(self, context):
    """
    Set the context in which the model is used.

    Accepted values:
    - 'training', 'validation': Used by trainer. Nothing special.
    - 'tracking': Used by tracker. Returns only the last point.
    - 'preparing_backward': Used by tracker. Nothing special, but does
      not return only the last point.
    - 'visu': Nothing special. Used by tester.
    """
    allowed_contexts = ['training', 'validation', 'tracking', 'visu',
                        'preparing_backward']
    assert context in allowed_contexts
    self._context = context
Expand Down Expand Up @@ -253,7 +258,7 @@ def forward(self, x: List[torch.tensor],
"""
# Reminder.
# Correct interpolation and management of points should be done before.
if self._context is None:
if self.context is None:
raise ValueError("Please set context before usage.")

# Right now input is always flattened (interpolation is implemented
Expand All @@ -268,7 +273,7 @@ def forward(self, x: List[torch.tensor],
# Making sure we can use default 'enforce_sorted=True' with packed
# sequences.
unsorted_indices = None
if not self._context == 'tracking':
if not self.context == 'tracking':
# Ordering streamlines per length.
lengths = torch.as_tensor([len(s) for s in x])
_, sorted_indices = torch.sort(lengths, descending=True)
Expand Down Expand Up @@ -351,7 +356,7 @@ def forward(self, x: List[torch.tensor],
x = x + copy_prev_dir

# Unpacking.
if not self._context == 'tracking':
if not self.context == 'tracking':
# (during tracking: keeping as one single tensor.)
if 'gaussian' in self.dg_key or 'fisher' in self.dg_key:
# Separating mean, sigmas (gaussian) or mean, kappa (fisher)
Expand All @@ -370,7 +375,7 @@ def forward(self, x: List[torch.tensor],
if return_hidden:
# Return the hidden states too. Necessary for the generative
# (tracking) part, done step by step.
if not self._context == 'tracking':
if not self.context == 'tracking':
# (ex: when preparing backward tracking.
# Must also re-sort hidden states.)
if self.rnn_model.rnn_torch_key == 'lstm':
Expand Down Expand Up @@ -402,7 +407,7 @@ def copy_prev_dir(self, dirs):
# Converting the input directions into classes the same way as
# during loss, but convert to one-hot.
# The first previous dir (0) converts to index 0.
if self._context == 'tracking':
if self.context == 'tracking':
if dirs[0].shape[0] == 0:
copy_prev_dir = torch.zeros(
len(dirs),
Expand Down
82 changes: 74 additions & 8 deletions dwi_ml/models/projects/transformer_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,14 @@ def _load_params(cls, model_dir):
return params

def set_context(self, context):
    """
    Set the context in which the model is used.

    Accepted values:
    - 'training', 'validation': Used by trainer. Nothing special.
    - 'tracking': Used by tracker. Returns only the last point.
    - 'visu': Nothing special. Used by tester.
    - 'visu_weights': Returns the weights too.

    Note that 'preparing_backward' is not an accepted context here.
    """
    # A single check, with the complete list of contexts. A stricter check
    # without 'visu_weights' would make that context impossible to set.
    assert context in ['training', 'validation', 'tracking',
                       'visu', 'visu_weights']
    self._context = context

def _generate_future_mask(self, sz):
Expand Down Expand Up @@ -326,7 +333,6 @@ def _prepare_masks(self, unpadded_lengths, use_padding, batch_max_len):

def forward(self, inputs: List[torch.tensor],
input_streamlines: List[torch.tensor] = None,
return_weights=False,
average_heads=False):
"""
Params
Expand All @@ -345,8 +351,6 @@ def forward(self, inputs: List[torch.tensor],
adequately masked to hide future positions. The last direction is
not used.
- As target during training. The whole sequence is used.
return_weights: bool
If true, returns the weights of the attention layers.
average_heads: bool
If return_weights, you may choose to average the weights from
different heads together.
Expand All @@ -359,11 +363,15 @@ def forward(self, inputs: List[torch.tensor],
[total nb points all streamlines, out size]
- During tracking: [nb streamlines * 1, out size]
weights: Tuple
If return_weights: The weights (depending on the child model)
If context is 'visu_weights': The weights (depending on the child model)
"""
if self._context is None:
if self.context is None:
raise ValueError("Please set context before usage.")

return_weights = False
if self.context == 'visu_weights':
return_weights = True

# ----------- Checks
if input_streamlines is not None:
# If streamlines are necessary (depending on child class):
Expand Down Expand Up @@ -425,7 +433,7 @@ def forward(self, inputs: List[torch.tensor],
# input size to dg = [nb points total, d_model]
# final output size = [nb points total, regression or
# classification output size]
if self._context == 'tracking':
if self.context == 'tracking':
# No need to actually unpad, we only take the last unpadded point
# Ignoring both the beginning of the streamline (redundant from
# previous tracking step) and the end of the streamline (padded
Expand Down Expand Up @@ -473,7 +481,7 @@ def forward(self, inputs: List[torch.tensor],
outputs = constant_output + outputs

# Splitting back. During tracking: only one point per streamline.
if self._context != 'tracking':
if self.context != 'tracking':
outputs = list(torch.split(outputs, list(input_lengths)))

if return_weights:
Expand Down Expand Up @@ -506,6 +514,27 @@ def _run_input_embedding(self, inputs, use_padding, batch_max_len):
inputs = self.input_embedding_layer(inputs)
return inputs

def merge_batches_outputs(self, all_outputs, new_batch, device=None):
    """
    Merge the results of a new batch with the results of previous batches.

    Parameters
    ----------
    all_outputs: Any or None
        The merged results of the previous batches, in the format returned
        by this method. None for the first batch.
    new_batch:
        The results of the new batch. If context is 'visu_weights': a
        tuple (outputs, weights). Else: the outputs only.
    device:
        Where to send the new batch's data. Forwarded to the parent's
        merge_batches_outputs and to merge_batches_weights.

    Returns
    -------
    If context is 'visu_weights': a tuple (merged outputs, merged weights).
    Else: the merged outputs, as returned by the parent's method.
    """
    if self.context != 'visu_weights':
        # No weights. The device must be forwarded here too; else the
        # outputs would not be moved where the caller asked.
        return super().merge_batches_outputs(all_outputs, new_batch,
                                             device)

    new_outputs, new_weights = new_batch
    if all_outputs is None:
        outputs, weights = None, None
    else:
        outputs, weights = all_outputs

    # Outputs are merged by the parent; the weights' format depends on
    # the child model.
    merged_outputs = super().merge_batches_outputs(outputs, new_outputs,
                                                   device)
    merged_weights = self.merge_batches_weights(weights, new_weights,
                                                device)
    return merged_outputs, merged_weights

def merge_batches_weights(self, weights, new_weights, device):
    """
    Placeholder for the merging of a new batch's weights with the weights
    of previous batches.

    Parameters
    ----------
    weights:
        The weights merged so far.
    new_weights:
        The weights of the new batch.
    device:
        The device.

    Raises
    ------
    NotImplementedError
        Always, in this implementation.
    """
    raise NotImplementedError


class TransformerSrcOnlyModel(AbstractTransformerModel):
def __init__(self, **kw):
Expand Down Expand Up @@ -564,6 +593,16 @@ def _run_main_layer_forward(self, inputs, masks,

return outputs, (sa_weights,)

def merge_batches_weights(self, weights, new_weights, device):
    """
    Send the weights of a new batch to a device and append them to the
    weights already accumulated from previous batches.

    Parameters
    ----------
    weights: Tuple[list] or list or None
        The weights merged so far: the tuple of 1 returned by a previous
        call (a bare list is also accepted). None for the first batch.
    new_weights: Tuple
        The weights of the new batch: a tuple of 1 (encoder). Only its
        first element is used. Its items must offer a `.to(device)` method
        (presumably torch tensors).
    device:
        Where to send the new weights.

    Returns
    -------
    weights: Tuple[list]
        A tuple of 1, containing the merged list.
    """
    new_weights = [a.to(device) for a in new_weights[0]]

    if weights is None:
        return (new_weights,)

    # Previous results come back as the tuple of 1 returned above: the
    # list to extend is inside it. (A tuple has no extend method, and
    # wrapping it again would add one nesting level per batch.)
    merged = weights[0] if isinstance(weights, tuple) else weights
    merged.extend(new_weights)
    return (merged,)


class AbstractTransformerModelWithTarget(AbstractTransformerModel):
def __init__(self,
Expand Down Expand Up @@ -854,6 +893,22 @@ def _run_main_layer_forward(self, data, masks,
return_weights=return_weights, average_heads=average_heads)
return outputs, (sa_weights_encoder, sa_weights_decoder, mha_weights)

def merge_batches_weights(self, weights, new_weights, device):
    """
    Send the weights of a new batch to a device and append them to the
    weights already accumulated from previous batches.

    Parameters
    ----------
    weights: Tuple or None
        The weights merged so far: a tuple (encoder, decoder, cross) of
        three lists. None for the first batch.
    new_weights: Tuple
        The weights of the new batch: a tuple (encoder, decoder, cross).
        Items must offer a `.to(device)` method.
    device:
        Where to send the new weights.

    Returns
    -------
    weights: Tuple
        The merged (encoder, decoder, cross) lists. When weights was not
        None, its three lists are extended in place.
    """
    encoder_in, decoder_in, cross_in = new_weights
    moved = tuple([t.to(device) for t in group]
                  for group in (encoder_in, decoder_in, cross_in))

    if weights is None:
        return moved

    encoder_acc, decoder_acc, cross_acc = weights
    for accumulated, extra in zip((encoder_acc, decoder_acc, cross_acc),
                                  moved):
        accumulated.extend(extra)
    return encoder_acc, decoder_acc, cross_acc


class TransformerSrcAndTgtModel(AbstractTransformerModelWithTarget):
"""
Expand Down Expand Up @@ -927,3 +982,14 @@ def _run_main_layer_forward(self, concat_s_t, masks,
return_weights=return_weights, average_heads=average_heads)

return outputs, (sa_weights,)

def merge_batches_weights(self, weights, new_weights, device):
    """
    Send the weights of a new batch to a device and append them to the
    weights already accumulated from previous batches.

    Parameters
    ----------
    weights: Tuple[list] or list or None
        The weights merged so far: the tuple of 1 returned by a previous
        call (a bare list is also accepted). None for the first batch.
    new_weights: Tuple
        The weights of the new batch: a tuple of 1 (encoder). Only its
        first element is used. Its items must offer a `.to(device)` method
        (presumably torch tensors).
    device:
        Where to send the new weights.

    Returns
    -------
    weights: Tuple[list]
        A tuple of 1, containing the merged list.
    """
    new_weights = [a.to(device) for a in new_weights[0]]

    if weights is None:
        return (new_weights,)

    # Previous results come back as the tuple of 1 returned above: the
    # list to extend is inside it. (A tuple has no extend method, and
    # wrapping it again would add one nesting level per batch.)
    merged = weights[0] if isinstance(weights, tuple) else weights
    merged.extend(new_weights)
    return (merged,)

20 changes: 19 additions & 1 deletion dwi_ml/models/projects/transformers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from dwi_ml.models.positional_encoding import (
keys_to_positional_encodings)
from dwi_ml.models.projects.transformer_models import (
AbstractTransformerModel)
AbstractTransformerModel, OriginalTransformerModel,
TransformerSrcAndTgtModel, TransformerSrcOnlyModel)

sphere_choices = ['symmetric362', 'symmetric642', 'symmetric724',
'repulsion724', 'repulsion100', 'repulsion200']
Expand Down Expand Up @@ -118,3 +119,20 @@ def add_transformers_model_args(p):

g = p.add_argument_group("Output")
AbstractTransformerModel.add_args_tracking_model(g)


def find_transformer_class(model_type):
    """
    Find the Transformer class associated with a model type name.

    Parameters
    ----------
    model_type: str
        The name of the model class, as returned by
        verify_which_model_in_path.

    Returns
    -------
    model_cls:
        One of OriginalTransformerModel, TransformerSrcAndTgtModel or
        TransformerSrcOnlyModel (the class, not an instance).

    Raises
    ------
    ValueError
        If model_type is none of the three names above.
    """
    if model_type == 'OriginalTransformerModel':
        return OriginalTransformerModel
    if model_type == 'TransformerSrcAndTgtModel':
        return TransformerSrcAndTgtModel
    if model_type == 'TransformerSrcOnlyModel':
        return TransformerSrcOnlyModel

    # Adjacent string literals are concatenated as is: the space before
    # the parenthesis must be explicit.
    raise ValueError("Model type is not a recognized Transformer "
                     "({})".format(model_type))
23 changes: 0 additions & 23 deletions dwi_ml/testing/projects/copy_prev_dirs_tester.py

This file was deleted.

Loading

0 comments on commit d7fc300

Please sign in to comment.