Fix DPP values saved. WIP.
EmmaRenauld committed Jan 18, 2024
1 parent d7fc300 commit 4afa173
Showing 8 changed files with 355 additions and 359 deletions.
103 changes: 49 additions & 54 deletions dwi_ml/models/projects/transformer_models.py
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
import logging
from time import time
from typing import Union, List, Tuple, Optional

from dipy.data import get_sphere
@@ -11,7 +10,8 @@

from dwi_ml.data.processing.streamlines.sos_eos_management import \
add_label_as_last_dim, convert_dirs_to_class
from dwi_ml.data.processing.streamlines.post_processing import compute_directions
from dwi_ml.data.processing.streamlines.post_processing import \
compute_directions
from dwi_ml.data.spheres import TorchSphere
from dwi_ml.models.embeddings import keys_to_embeddings
from dwi_ml.models.main_models import (ModelWithDirectionGetter,
@@ -78,6 +78,29 @@ def pad_and_stack_batch(data: List[torch.Tensor], pad_first: bool,
return torch.stack(data)


def merge_one_weight_type(weights, new_weights, device):
# Weight is a list per layer of tensors of shape
# [nb_streamlines, nb_heads, batch_max_len, batch_max_len].
new_weights = [layer_weight.to(device) for layer_weight in new_weights]
new_max_len = new_weights[0].shape[2]

if weights is None:
return new_weights
else:
old_max_len = weights[0].shape[2]

# Pad if necessary. We could pad everything to max_len, but that
# would probably be heavy for no reason.
pad_w = max(0, new_max_len - old_max_len)
pad_n = max(0, old_max_len - new_max_len)
weights = [torch.cat((
pad(w, (0, pad_w, 0, pad_w)),
pad(n, (0, pad_n, 0, pad_n))),
dim=0) for w, n in zip(weights, new_weights)]

return weights
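
# A minimal sketch of the expected behaviour (not from the commit; shapes are
# illustrative and `pad` is assumed to be torch.nn.functional.pad):
#
#   old = [torch.rand(2, 4, 5, 5)]   # 2 streamlines, 4 heads, max_len = 5
#   new = [torch.rand(3, 4, 7, 7)]   # 3 streamlines, 4 heads, max_len = 7
#   merged = merge_one_weight_type(old, new, device='cpu')
#   # merged[0].shape == (5, 4, 7, 7): the old maps are padded on their last
#   # two dimensions up to the new max_len, then concatenated along dim 0.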


class AbstractTransformerModel(ModelWithNeighborhood, ModelWithDirectionGetter,
ModelOneInputWithEmbedding):
"""
@@ -332,8 +355,7 @@ def _prepare_masks(self, unpadded_lengths, use_padding, batch_max_len):
return mask_future, mask_padding

def forward(self, inputs: List[torch.tensor],
input_streamlines: List[torch.tensor] = None,
average_heads=False):
input_streamlines: List[torch.tensor] = None):
"""
Params
------
@@ -351,9 +373,6 @@ def forward(self, inputs: List[torch.tensor],
adequately masked to hide future positions. The last direction is
not used.
- As target during training. The whole sequence is used.
average_heads: bool
If return_weights, you may choose to average the weights from
different heads together.
Returns
-------
@@ -394,12 +413,7 @@ def forward(self, inputs: List[torch.tensor],
use_padding = not np.all(input_lengths == input_lengths[0])
batch_max_len = np.max(input_lengths)
if CLEAR_CACHE:
now = time()
logging.debug("Transformer: Maximal length in batch is {}"
.format(batch_max_len))
torch.cuda.empty_cache()
now2 = time()
logging.debug("Cleared cache in {} secs.".format(now2 - now))

# ----------- Prepare masks
masks = self._prepare_masks(input_lengths, use_padding, batch_max_len)
@@ -422,7 +436,7 @@ def forward(self, inputs: List[torch.tensor],

# 2. Main transformer
outputs, weights = self._run_main_layer_forward(
data, masks, return_weights, average_heads)
data, masks, return_weights)

# Here, data = one tensor, padded.
# Unpad now and either
@@ -485,6 +499,9 @@ def forward(self, inputs: List[torch.tensor],
outputs = list(torch.split(outputs, list(input_lengths)))

if return_weights:
# Padding weights to max length, else we won't be able to stack
# outputs. This way, all weights are a list, per layer, of
# tensors of shape [nb_streamlines, nb_heads, max_len, max_len]
return outputs, weights

return outputs
@@ -498,8 +515,7 @@ def _run_embeddings(self, data, use_padding, batch_max_len):
def _run_position_encoding(self, data):
raise NotImplementedError

def _run_main_layer_forward(self, data, masks, return_weights,
average_heads):
def _run_main_layer_forward(self, data, masks, return_weights):
raise NotImplementedError

def _run_input_embedding(self, inputs, use_padding, batch_max_len):
@@ -522,6 +538,7 @@ def merge_batches_outputs(self, all_outputs, new_batch, device=None):
outputs, weights = None, None
else:
outputs, weights = all_outputs

new_outputs = super().merge_batches_outputs(outputs, new_outputs,
device)
new_weights = self.merge_batches_weights(weights, new_weights,
@@ -582,26 +599,21 @@ def _run_position_encoding(self, inputs):
inputs = self.dropout(inputs)
return inputs

def _run_main_layer_forward(self, inputs, masks,
return_weights, average_heads):
def _run_main_layer_forward(self, inputs, masks, return_weights):
# Encoder only.

# mask_future, mask_padding = masks
outputs, sa_weights = self.modified_torch_transformer(
src=inputs, mask=masks[0], src_key_padding_mask=masks[1],
return_weights=return_weights, average_heads=average_heads)
return_weights=return_weights)

return outputs, (sa_weights,)

def merge_batches_weights(self, weights, new_weights, device):
# weights is a single attention tensor (encoder): a tuple of 1.
new_weights = [a.to(device) for a in new_weights[0]]

# Weights is a single attention tensor (encoder): a tuple of 1.
if weights is None:
return (new_weights,)
else:
weights.extend(new_weights)
return (weights,)
weights = (None,)
return (merge_one_weight_type(weights[0], new_weights[0], device), )
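
# Sketch of the encoder-only case, under the same assumptions as above
# ('model' stands for any instance of this class): the weights stay a 1-tuple
# whose single element is the per-layer list of self-attention maps.
#
#   batch1 = ([torch.rand(2, 4, 5, 5)],)   # (sa_weights,) as returned above
#   batch2 = ([torch.rand(3, 4, 7, 7)],)
#   merged = model.merge_batches_weights(None, batch1, 'cpu')
#   merged = model.merge_batches_weights(merged, batch2, 'cpu')
#   # merged[0][0].shape == (5, 4, 7, 7)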


class AbstractTransformerModelWithTarget(AbstractTransformerModel):
@@ -703,8 +715,7 @@ def _run_embeddings(self, data, use_padding, batch_max_len):
def _run_position_encoding(self, data):
raise NotImplementedError

def _run_main_layer_forward(self, data, masks, return_weights,
average_heads):
def _run_main_layer_forward(self, data, masks, return_weights):
raise NotImplementedError

def format_prev_dir_(self, dirs):
@@ -871,8 +882,7 @@ def _run_position_encoding(self, data):

return inputs, targets

def _run_main_layer_forward(self, data, masks,
return_weights, average_heads):
def _run_main_layer_forward(self, data, masks, return_weights):
"""Original Main transformer
Returns
@@ -890,24 +900,15 @@ def _run_main_layer_forward(self, data, masks,
src_mask=masks[0], tgt_mask=masks[0], memory_mask=masks[0],
src_key_padding_mask=masks[1], tgt_key_padding_mask=masks[1],
memory_key_padding_mask=masks[1],
return_weights=return_weights, average_heads=average_heads)
return_weights=return_weights)
return outputs, (sa_weights_encoder, sa_weights_decoder, mha_weights)

def merge_batches_weights(self, weights, new_weights, device):
# weights is a Tuple[encoder, decoder, cross]
new_weights_e, new_weights_d, new_weights_c = new_weights
new_weights_e = [a.to(device) for a in new_weights_e]
new_weights_d = [a.to(device) for a in new_weights_d]
new_weights_c = [a.to(device) for a in new_weights_c]

if weights is None:
return new_weights_e, new_weights_d, new_weights_c
else:
weights_e, weights_d, weights_c = weights
weights_e.extend(new_weights_e)
weights_d.extend(new_weights_d)
weights_c.extend(new_weights_c)
return weights_e, weights_d, weights_c
weights = (None, None, None)
return (merge_one_weight_type(weights[0], new_weights[0], device),
merge_one_weight_type(weights[1], new_weights[1], device),
merge_one_weight_type(weights[2], new_weights[2], device))
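
# Note: the tuple order follows _run_main_layer_forward's output, i.e.
# (sa_weights_encoder, sa_weights_decoder, mha_weights); each attention type
# is padded and concatenated independently by merge_one_weight_type.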


class TransformerSrcAndTgtModel(AbstractTransformerModelWithTarget):
@@ -972,24 +973,18 @@ def _run_position_encoding(self, data):
data = self.dropout(data)
return data

def _run_main_layer_forward(self, concat_s_t, masks,
return_weights, average_heads):
def _run_main_layer_forward(self, concat_s_t, masks, return_weights):
# Encoder only.

# mask_future, mask_padding = masks
outputs, sa_weights = self.modified_torch_transformer(
src=concat_s_t, mask=masks[0], src_key_padding_mask=masks[1],
return_weights=return_weights, average_heads=average_heads)
return_weights=return_weights)

return outputs, (sa_weights,)

def merge_batches_weights(self, weights, new_weights, device):
# weights is a single attention tensor (encoder): a tuple of 1.
new_weights = [a.to(device) for a in new_weights[0]]

# Weights is a single attention tensor (encoder): a tuple of 1.
if weights is None:
return (new_weights,)
else:
weights.extend(new_weights)
return (weights,)

weights = (None,)
return (merge_one_weight_type(weights[0], new_weights[0], device), )
