From d3c552f89eaa8087c574314f92d6059fbbeddb61 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Wed, 22 May 2024 15:59:42 -0400 Subject: [PATCH] Copy Transformer model from latest version. New is_causal arg --- dwi_ml/models/projects/transformer_models.py | 43 ++- .../models/projects/transformer_sublayers.py | 304 ++++++++++++++++ .../models/utils/transformers_from_torch.py | 337 ++++++++---------- requirements.txt | 7 - 4 files changed, 474 insertions(+), 217 deletions(-) create mode 100644 dwi_ml/models/projects/transformer_sublayers.py diff --git a/dwi_ml/models/projects/transformer_models.py b/dwi_ml/models/projects/transformer_models.py index 1ecc2664..4409a963 100644 --- a/dwi_ml/models/projects/transformer_models.py +++ b/dwi_ml/models/projects/transformer_models.py @@ -574,8 +574,17 @@ def __init__(self, **kw): self.d_model, self.nheads, dim_feedforward=self.ffnn_hidden_size, dropout=self.dropout_rate, activation=self.activation, batch_first=True, norm_first=self.norm_first) - self.modified_torch_transformer = ModifiedTransformerEncoder( - main_layer_encoder, self.n_layers_e, norm=None) + + # Receiving weird warning: enable_nested_tensor is True, + # but self.use_nested_tensor is False because encoder_layer.norm_first + # was True. + enable_nested = False if self.norm_first else True + + # Note about norm: this is a final normalization step. Not linked to + # the normalization decided with self.norm_first. + self.transformer_encoder_only = ModifiedTransformerEncoder( + main_layer_encoder, self.n_layers_e, norm=None, + enable_nested_tensor=enable_nested) @property def d_model(self): @@ -611,9 +620,9 @@ def _run_main_layer_forward(self, inputs, masks, return_weights): # Encoder only. # mask_future, mask_padding = masks - outputs, sa_weights = self.modified_torch_transformer( + outputs, sa_weights = self.transformer_encoder_only( src=inputs, mask=masks[0], src_key_padding_mask=masks[1], - return_weights=return_weights) + is_causal=True, return_weights=return_weights) return outputs, (sa_weights,) @@ -844,8 +853,17 @@ def __init__(self, input_embedded_size, n_layers_d: int, **kw): dim_feedforward=self.ffnn_hidden_size, dropout=self.dropout_rate, activation=self.activation, batch_first=True, norm_first=self.norm_first) - encoder = ModifiedTransformerEncoder(encoder_layer, self.n_layers_e, - norm=None) + + # Receiving weird warning: enable_nested_tensor is True, + # but self.use_nested_tensor is False because encoder_layer.norm_first + # was True. + enable_nested = False if self.norm_first else True + + # Note about norm: this is a final normalization step. Not linked to + # the normalization decided with self.norm_first. 
+        encoder = ModifiedTransformerEncoder(
+            encoder_layer, self.n_layers_e, norm=None,
+            enable_nested_tensor=enable_nested)
 
         # Decoder
         decoder_layer = ModifiedTransformerDecoderLayer(
@@ -856,7 +874,7 @@ def __init__(self, input_embedded_size, n_layers_d: int, **kw):
         decoder = ModifiedTransformerDecoder(decoder_layer, n_layers_d,
                                              norm=None)
 
-        self.modified_torch_transformer = ModifiedTransformer(
+        self.transformer = ModifiedTransformer(
             self.d_model, self.nheads, self.n_layers_e, n_layers_d,
             self.ffnn_hidden_size, self.dropout_rate, self.activation,
             encoder, decoder, batch_first=True,
@@ -904,11 +922,12 @@ def _run_main_layer_forward(self, data, masks, return_weights):
         # embed_x, embed_t = data
         # mask_future, mask_padding = masks
         outputs, sa_weights_encoder, sa_weights_decoder, mha_weights = \
-            self.modified_torch_transformer(
+            self.transformer(
                 src=data[0], tgt=data[1],
                 src_mask=masks[0], tgt_mask=masks[0], memory_mask=masks[0],
                 src_key_padding_mask=masks[1], tgt_key_padding_mask=masks[1],
-                memory_key_padding_mask=masks[1],
+                memory_key_padding_mask=masks[1], src_is_causal=True,
+                tgt_is_causal=True, memory_is_causal=True,
                 return_weights=return_weights)
 
         return outputs, (sa_weights_encoder, sa_weights_decoder, mha_weights)
@@ -955,7 +974,7 @@ def __init__(self, **kw):
             self.d_model, self.nheads, dim_feedforward=self.ffnn_hidden_size,
             dropout=self.dropout_rate, activation=self.activation,
             batch_first=True, norm_first=self.norm_first)
-        self.modified_torch_transformer = ModifiedTransformerEncoder(
+        self.transformer_encoder_only = ModifiedTransformerEncoder(
             main_layer_encoder, self.n_layers_e, norm=None)
 
     @property
@@ -987,9 +1006,9 @@ def _run_main_layer_forward(self, concat_s_t, masks, return_weights):
         # Encoder only.
         # mask_future, mask_padding = masks
-        outputs, sa_weights = self.modified_torch_transformer(
+        outputs, sa_weights = self.transformer_encoder_only(
             src=concat_s_t, mask=masks[0], src_key_padding_mask=masks[1],
-            return_weights=return_weights)
+            is_causal=True, return_weights=return_weights)
 
         return outputs, (sa_weights,)
 
diff --git a/dwi_ml/models/projects/transformer_sublayers.py b/dwi_ml/models/projects/transformer_sublayers.py
new file mode 100644
index 00000000..e2c65731
--- /dev/null
+++ b/dwi_ml/models/projects/transformer_sublayers.py
@@ -0,0 +1,304 @@
+"""
+Child classes of Torch Transformers. Changes are:
+
+- EncoderLayer: now returns attention weights; Q, K, V weights not shared.
+- DecoderLayer: idem, and also returns the cross-attention (MHA) weights.
+
+"""
+import logging
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import (TransformerDecoderLayer, TransformerEncoderLayer,
+                      MultiheadAttention, Parameter)
+
+logger = logging.getLogger('model_logger')
+
+
+def do_not_share_linear_weights(attn: MultiheadAttention, d_model):
+    """
+    I added a request for this parameter to be accessible.
+    https://github.com/pytorch/pytorch/issues/92990
+
+    Copied from MultiheadAttention's init method.
+    """
+
+    factory_kwargs = {'device': None, 'dtype': None}
+
+    # Overriding some parameters in the self attention.
+    # Ugly, but Torch does not have a parameter to NOT share the linear
+    # weights. In their code, they only avoid sharing weights when the
+    # dimensions are not the same, which is not our case. This is saved in
+    # their parameter _qkv_same_embed_dim. By changing it, we change their
+    # forward call to the MultiheadAttention in self.self_attn.
+ attn._qkv_same_embed_dim = False + attn.q_proj_weight = Parameter( + torch.empty((d_model, d_model), **factory_kwargs)) + attn.k_proj_weight = Parameter( + torch.empty((d_model, d_model), **factory_kwargs)) + attn.v_proj_weight = Parameter( + torch.empty((d_model, d_model), **factory_kwargs)) + attn.register_parameter('in_proj_weight', None) + attn._reset_parameters() + + +class ModifiedTransformerEncoderLayer(TransformerEncoderLayer): + def __init__(self, d_model, nhead, **kw): + super().__init__(d_model, nhead, **kw) + + do_not_share_linear_weights(self.self_attn, d_model) + + def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + is_causal: bool = False, + # New args: + return_weights=False, average_heads=False): + """ + Copy-pasted from torch. Now returns weights. + """ + src_key_padding_mask = F._canonical_mask( + mask=src_key_padding_mask, + mask_name="src_key_padding_mask", + other_type=F._none_or_dtype(src_mask), + other_name="src_mask", + target_type=src.dtype + ) + + src_mask = F._canonical_mask( + mask=src_mask, + mask_name="src_mask", + other_type=None, + other_name="", + target_type=src.dtype, + check_other=False, + ) + + # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf + why_not_sparsity_fast_path = '' + if not src.dim() == 3: + why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}" + elif self.training: + why_not_sparsity_fast_path = "training is enabled" + elif not self.self_attn.batch_first: + why_not_sparsity_fast_path = "self_attn.batch_first was not True" + elif not self.self_attn._qkv_same_embed_dim: + why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True" + elif not self.activation_relu_or_gelu: + why_not_sparsity_fast_path = "activation_relu_or_gelu was not True" + elif not (self.norm1.eps == self.norm2.eps): + why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps" + elif src.is_nested and ( + src_key_padding_mask is not None or src_mask is not None): + why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input" + elif self.self_attn.num_heads % 2 == 1: + why_not_sparsity_fast_path = "num_head is odd" + elif torch.is_autocast_enabled(): + why_not_sparsity_fast_path = "autocast is enabled" + if not why_not_sparsity_fast_path: + tensor_args = ( + src, + self.self_attn.in_proj_weight, + self.self_attn.in_proj_bias, + self.self_attn.out_proj.weight, + self.self_attn.out_proj.bias, + self.norm1.weight, + self.norm1.bias, + self.norm2.weight, + self.norm2.bias, + self.linear1.weight, + self.linear1.bias, + self.linear2.weight, + self.linear2.bias, + ) + + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. 
+            _supported_device_type = [
+                "cpu", "cuda",
+                torch.utils.backend_registration._privateuse1_backend_name]
+            if torch.overrides.has_torch_function(tensor_args):
+                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
+            elif not all((x.device.type in _supported_device_type) for x in
+                         tensor_args):
+                why_not_sparsity_fast_path = (
+                    "some Tensor argument's device is neither one of "
+                    f"{_supported_device_type}")
+            elif torch.is_grad_enabled() and any(
+                    x.requires_grad for x in tensor_args):
+                why_not_sparsity_fast_path = (
+                    "grad is enabled and at least one of query or the "
+                    "input/output projection weights or biases requires_grad")
+
+            if not why_not_sparsity_fast_path:
+                merged_mask, mask_type = self.self_attn.merge_masks(src_mask,
+                                                                    src_key_padding_mask,
+                                                                    src)
+                # MODIFIED:
+                if return_weights:
+                    raise NotImplementedError(
+                        "Did not expect to reach here. Not ready to return "
+                        "weights. Please contact dwi_ml developers")
+                return torch._transformer_encoder_layer_fwd(
+                    src,
+                    self.self_attn.embed_dim,
+                    self.self_attn.num_heads,
+                    self.self_attn.in_proj_weight,
+                    self.self_attn.in_proj_bias,
+                    self.self_attn.out_proj.weight,
+                    self.self_attn.out_proj.bias,
+                    self.activation_relu_or_gelu == 2,
+                    self.norm_first,
+                    self.norm1.eps,
+                    self.norm1.weight,
+                    self.norm1.bias,
+                    self.norm2.weight,
+                    self.norm2.bias,
+                    self.linear1.weight,
+                    self.linear1.bias,
+                    self.linear2.weight,
+                    self.linear2.bias,
+                    merged_mask,
+                    mask_type,
+                )
+
+        x = src
+        if self.norm_first:
+            # Norm, SA, Add, Norm, FF, Add
+            sa, sa_weights = self._sa_block(
+                self.norm1(x), src_mask, src_key_padding_mask,
+                is_causal=is_causal,
+                # New args:
+                return_weights=return_weights, average_heads=average_heads)
+            x = x + sa
+            x = x + self._ff_block(self.norm2(x))
+        else:
+            # SA, Add, Norm, FF, Add, Norm
+            sa, sa_weights = self._sa_block(
+                x, src_mask, src_key_padding_mask, is_causal=is_causal,
+                # New args:
+                return_weights=return_weights, average_heads=average_heads)
+            x = self.norm1(x + sa)
+            x = self.norm2(x + self._ff_block(x))
+
+        return x, sa_weights
+
+    # self-attention block
+    def _sa_block(self, x: Tensor,
+                  attn_mask: Optional[Tensor],
+                  key_padding_mask: Optional[Tensor],
+                  is_causal: bool = False,
+                  # New args:
+                  return_weights=False, average_heads=False):
+        x, weights = self.self_attn(
+            x, x, x,
+            attn_mask=attn_mask, key_padding_mask=key_padding_mask,
+            is_causal=is_causal,
+            # Modified args:
+            need_weights=return_weights, average_attn_weights=average_heads)
+
+        return self.dropout1(x), weights
+
+
+class ModifiedTransformerDecoderLayer(TransformerDecoderLayer):
+    """
+    Decoder Layer, in the case where we do not have a start of sequence (SOS)
+    token, and our mask contains only -inf for the first position. Output of
+    self-attention becomes nan after the softmax step. Setting to 0.
+
+    Also, now returning attention weights.
+    """
+    def __init__(self, d_model, nhead, **kw):
+        super().__init__(d_model, nhead, **kw)
+
+        do_not_share_linear_weights(self.self_attn, d_model)
+        do_not_share_linear_weights(self.multihead_attn, d_model)
+
+    def forward(self, tgt: Tensor, memory: Tensor,
+                tgt_mask: Tensor = None, memory_mask: Tensor = None,
+                tgt_key_padding_mask: Tensor = None,
+                memory_key_padding_mask: Tensor = None,
+                tgt_is_causal: bool = False,
+                memory_is_causal: bool = False,
+                # New args:
+                return_weights=False, average_heads=False):
+        """
+        Copy-pasted from torch. Now returns weights + converts nan to 0.
+        Weights are None if return_weights is False.
+        """
+        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
+        x = tgt
+        if self.norm_first:
+            # Norm, SA, Add, Norm, MHA, Add, Norm, FF, Add
+            sa, sa_weights = self._sa_block(
+                self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal,
+                # New args:
+                return_weights=return_weights, average_heads=average_heads)
+            x = x + sa
+
+            mha, mha_weights = self._mha_block(
+                self.norm2(x), memory, memory_mask, memory_key_padding_mask,
+                memory_is_causal,
+                # New args:
+                return_weights=return_weights, average_heads=average_heads)
+            x = x + mha
+            x = x + self._ff_block(self.norm3(x))
+        else:
+            # SA, Add, Norm, MHA, Add, Norm, FF, Add, Norm.
+            sa, sa_weights = self._sa_block(
+                x, tgt_mask, tgt_key_padding_mask, tgt_is_causal,
+                # New args:
+                return_weights=return_weights, average_heads=average_heads)
+            x = self.norm1(x + sa)
+
+            mha, mha_weights = self._mha_block(
+                x, memory, memory_mask, memory_key_padding_mask,
+                memory_is_causal,
+                # New args:
+                return_weights=return_weights, average_heads=average_heads)
+            x = self.norm2(x + mha)
+            x = self.norm3(x + self._ff_block(x))
+
+        return x, mha_weights, sa_weights
+
+    # self-attention block
+    def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor],
+                  key_padding_mask: Optional[Tensor],
+                  is_causal: bool = False,
+                  # New args:
+                  return_weights=False, average_heads=False):
+        """
+        Copy-pasted from torch. Now returns weights.
+        """
+        x, weights = self.self_attn(
+            x, x, x,
+            attn_mask=attn_mask, key_padding_mask=key_padding_mask,
+            is_causal=is_causal,
+            # Modified args:
+            need_weights=return_weights, average_attn_weights=average_heads)
+
+        return self.dropout1(x), weights
+
+    # multihead attention block
+    def _mha_block(self, x: Tensor, mem: Tensor,
+                   attn_mask: Optional[Tensor],
+                   key_padding_mask: Optional[Tensor],
+                   is_causal: bool = False,
+                   # New args:
+                   return_weights=False, average_heads=False):
+        """
+        Copy-pasted from torch. Can now use need_weights=True.
+        """
+        x = self.multihead_attn(
+            x, mem, mem,
+            attn_mask=attn_mask, key_padding_mask=key_padding_mask,
+            is_causal=is_causal,
+            # Modified args:
+            need_weights=return_weights, average_attn_weights=average_heads)
+
+        # multihead_attn always returns an (output, weights) tuple; unpack it
+        # before the dropout (weights is None when need_weights is False).
+        if return_weights:
+            x, weights = x
+        else:
+            x, weights = x[0], None
+
+        return self.dropout2(x), weights
diff --git a/dwi_ml/models/utils/transformers_from_torch.py b/dwi_ml/models/utils/transformers_from_torch.py
index a9972c6b..7710b871 100644
--- a/dwi_ml/models/utils/transformers_from_torch.py
+++ b/dwi_ml/models/utils/transformers_from_torch.py
@@ -6,82 +6,129 @@ to decide if we want to share the linear weights for Q, K, V.
- Encoder: Idem - Decoder: Idem -- EncoderLayer: Idem -- DecoderLayer: Idem """ import logging from typing import Optional import torch +import torch.nn.functional as F from torch import Tensor -from torch.nn import (Transformer, - TransformerDecoderLayer, TransformerDecoder, - TransformerEncoderLayer, TransformerEncoder, - MultiheadAttention, Parameter) +from torch.nn import Transformer, TransformerDecoder, TransformerEncoder +from torch.nn.modules.transformer import _get_seq_len, _detect_is_causal_mask from dwi_ml.experiment_utils.memory import log_gpu_memory_usage +from dwi_ml.models.projects.transformer_sublayers import \ + ModifiedTransformerDecoderLayer, ModifiedTransformerEncoderLayer logger = logging.getLogger('model_logger') -class ModifiedTransformer(Transformer): - def __init__(self, *args, **kw): - super().__init__(*args, **kw) - - def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor = None, - tgt_mask: Tensor = None, memory_mask: Tensor = None, - src_key_padding_mask: Tensor = None, - tgt_key_padding_mask: Tensor = None, - memory_key_padding_mask: Tensor = None, - return_weights=False, average_heads=False): - """ - Copy-pasted from torch. Now returns weights. - """ - logger.debug("Entering main Transformer's forward.") - log_gpu_memory_usage(logger) - memory, sa_weights_encoder = self.encoder( - src, mask=src_mask, src_key_padding_mask=src_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - - output, sa_weights_decoder, mha_weights = self.decoder( - tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - - return output, sa_weights_encoder, sa_weights_decoder, mha_weights - - class ModifiedTransformerEncoder(TransformerEncoder): - def __init__(self, *args, **kw): - super().__init__(*args, **kw) + def __init__(self, encoder_layer, *args, **kw): + if not isinstance(encoder_layer, ModifiedTransformerEncoderLayer): + raise ValueError("Encoder layer should be of type {}. Got {}" + .format(ModifiedTransformerEncoderLayer.__name__, + type(encoder_layer))) + super().__init__(encoder_layer, *args, **kw) def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + is_causal: Optional[bool] = None, + # New args: return_weights=False, average_heads=False): """ Copy-pasted from torch. Now returns weights. - Erased all the fast-path check: it is not used anyway if it is - training, and if we supply both src_key_padding_mask and mask, which is - our case. - Layers must be TransformerEncoderLayerGetWeights layers. 
""" - if src_key_padding_mask is not None: - _skpm_dtype = src_key_padding_mask.dtype - if _skpm_dtype != torch.bool and not \ - torch.is_floating_point(src_key_padding_mask): - raise AssertionError("only bool and floating types of " - "key_padding_mask are supported") + src_key_padding_mask = F._canonical_mask( + mask=src_key_padding_mask, + mask_name="src_key_padding_mask", + other_type=F._none_or_dtype(mask), + other_name="mask", + target_type=src.dtype + ) + + mask = F._canonical_mask( + mask=mask, + mask_name="mask", + other_type=None, + other_name="", + target_type=src.dtype, + check_other=False, + ) + output = src + convert_to_nested = False + first_layer = self.layers[0] src_key_padding_mask_for_layers = src_key_padding_mask + why_not_sparsity_fast_path = '' + str_first_layer = "self.layers[0]" + batch_first = first_layer.self_attn.batch_first + if not hasattr(self, "use_nested_tensor"): + why_not_sparsity_fast_path = "use_nested_tensor attribute not present" + elif not self.use_nested_tensor: + why_not_sparsity_fast_path = "self.use_nested_tensor (set in init) was not True" + elif first_layer.training: + why_not_sparsity_fast_path = f"{str_first_layer} was in training mode" + elif not src.dim() == 3: + why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}" + elif src_key_padding_mask is None: + why_not_sparsity_fast_path = "src_key_padding_mask was None" + elif (((not hasattr(self, "mask_check")) or self.mask_check) + and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())): + why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned" + elif output.is_nested: + why_not_sparsity_fast_path = "NestedTensor input is not supported" + elif mask is not None: + why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied" + elif torch.is_autocast_enabled(): + why_not_sparsity_fast_path = "autocast is enabled" + + if not why_not_sparsity_fast_path: + tensor_args = ( + src, + first_layer.self_attn.in_proj_weight, + first_layer.self_attn.in_proj_bias, + first_layer.self_attn.out_proj.weight, + first_layer.self_attn.out_proj.bias, + first_layer.norm1.weight, + first_layer.norm1.bias, + first_layer.norm2.weight, + first_layer.norm2.bias, + first_layer.linear1.weight, + first_layer.linear1.bias, + first_layer.linear2.weight, + first_layer.linear2.bias, + ) + _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name] + if torch.overrides.has_torch_function(tensor_args): + why_not_sparsity_fast_path = "some Tensor argument has_torch_function" + elif src.device.type not in _supported_device_type: + why_not_sparsity_fast_path = f"src device is neither one of {_supported_device_type}" + elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args): + why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad") + + if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None): + convert_to_nested = True + output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False) + src_key_padding_mask_for_layers = None + + seq_len = _get_seq_len(src, batch_first) + is_causal = _detect_is_causal_mask(mask, is_causal, seq_len) + + # THIS IS THE MODIFIED PART sa_weights = [None] * len(self.layers) - for mod, i in zip(self.layers, range(len(self.layers))): output, sa_weights[i] = mod( - 
                output, src_mask=mask,
+                output, src_mask=mask, is_causal=is_causal,
                 src_key_padding_mask=src_key_padding_mask_for_layers,
+                # New args:
                 return_weights=return_weights, average_heads=average_heads)
+        # END OF MODIFIED PART
+
+        if convert_to_nested:
+            output = output.to_padded_tensor(0., src.size())
 
         if self.norm is not None:
             output = self.norm(output)
@@ -90,30 +137,45 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None,
 
 
 class ModifiedTransformerDecoder(TransformerDecoder):
-    def __init__(self, *args, **kw):
-        super().__init__(*args, **kw)
+
+    def __init__(self, decoder_layer, *args, **kw):
+        if not isinstance(decoder_layer, ModifiedTransformerDecoderLayer):
+            raise ValueError("Decoder layer should be of type {}. Got {}"
+                             .format(ModifiedTransformerDecoderLayer.__name__,
+                                     type(decoder_layer)))
+        super().__init__(decoder_layer, *args, **kw)
 
     def forward(self, tgt: Tensor, memory: Tensor,
                 tgt_mask: Optional[Tensor] = None,
                 memory_mask: Optional[Tensor] = None,
                 tgt_key_padding_mask: Optional[Tensor] = None,
                 memory_key_padding_mask: Optional[Tensor] = None,
+                tgt_is_causal: Optional[bool] = None,
+                memory_is_causal: bool = False,
+                # New args:
                 return_weights=False, average_heads=False):
         """
         Copy-pasted from torch. Now returns weights.
-        Layers must be TransformerDecoderLayerGetWeightsNoSOS layers.
         """
         output = tgt
+
+        seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
+        tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
+
+        # THIS IS THE MODIFIED PART
         mha_weights = [None] * len(self.layers)
         sa_weights = [None] * len(self.layers)
-
         for mod, i in zip(self.layers, range(len(self.layers))):
             output, mha_weights[i], sa_weights[i] = \
                 mod(output, memory, tgt_mask=tgt_mask,
                     memory_mask=memory_mask,
                     tgt_key_padding_mask=tgt_key_padding_mask,
                     memory_key_padding_mask=memory_key_padding_mask,
+                    tgt_is_causal=tgt_is_causal,
+                    memory_is_causal=memory_is_causal,
+                    # New args:
                    return_weights=return_weights, average_heads=average_heads)
+        # END OF MODIFIED PART
 
         if self.norm is not None:
             output = self.norm(output)
@@ -121,160 +183,39 @@ def forward(self, tgt: Tensor, memory: Tensor,
         return output, sa_weights, mha_weights
 
 
-def do_not_share_linear_weights(attn: MultiheadAttention, d_model):
-    """
-    I added a request for this parameter to be accessible.
-    https://github.com/pytorch/pytorch/issues/92990
-    """
-    factory_kwargs = {'device': None, 'dtype': None}
-
-    # Overriding some parameters in the self attention.
-    # Ugly but.... Torch does not have a parameter to NOT share linear
-    # weights. In their code, their only NOT share weights when dimensions
-    # are not the same. This is not our case. This is saved in their
-    # parameter _qkv_same_embed_dim. By changing this, we change their
-    # forward call to the MultiHeadAttention in self.self_attn.
- attn._qkv_same_embed_dim = False - attn.q_proj_weight = Parameter( - torch.empty((d_model, d_model), **factory_kwargs)) - attn.k_proj_weight = Parameter( - torch.empty((d_model, d_model), **factory_kwargs)) - attn.v_proj_weight = Parameter( - torch.empty((d_model, d_model), **factory_kwargs)) - attn.register_parameter('in_proj_weight', None) - attn._reset_parameters() - - -class ModifiedTransformerEncoderLayer(TransformerEncoderLayer): - def __init__(self, d_model, nhead, **kw): - super().__init__(d_model, nhead, **kw) - - do_not_share_linear_weights(self.self_attn, d_model) - - def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - return_weights=False, average_heads=False): - """ - Copy-pasted from torch. Now returns weights. - Erased all the fast-track checks. - """ - x = src - if self.norm_first: - # Norm, SA, Add, Norm, FF, Add - sa, sa_weights = self._sa_block( - self.norm1(x), src_mask, src_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - x = x + sa - x = x + self._ff_block(self.norm2(x)) - else: - # SA, Add, Norm, FF, Add, Norm - sa, sa_weights = self._sa_block( - x, src_mask, src_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - x = self.norm1(x + sa) - x = self.norm2(x + self._ff_block(x)) - - return x, sa_weights - - # self-attention block - def _sa_block(self, x: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], - return_weights=False, average_heads=False): - output = self.self_attn( - x, x, x, - attn_mask=attn_mask, key_padding_mask=key_padding_mask, - need_weights=return_weights, average_attn_weights=average_heads) - x, weights = output # if return_weights is False, weights is None - - return self.dropout1(x), weights - - -class ModifiedTransformerDecoderLayer(TransformerDecoderLayer): - """ - Decoder Layer, in the case where we do not have a start of sequence (SOS) - token, and our mask contains only -inf for the first position. Output of - self-attention becomes nan after the softmax step. Setting to 0. - - Also, now returning attention weights. - """ - def __init__(self, d_model, nhead, **kw): - super().__init__(d_model, nhead, **kw) +class ModifiedTransformer(Transformer): + encoder: ModifiedTransformerEncoder + decoder: ModifiedTransformerDecoder - do_not_share_linear_weights(self.self_attn, d_model) - do_not_share_linear_weights(self.multihead_attn, d_model) + def __init__(self, *args, **kw): + super().__init__(*args, **kw) - def forward(self, tgt: Tensor, memory: Tensor, + def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor = None, tgt_mask: Tensor = None, memory_mask: Tensor = None, + src_key_padding_mask: Tensor = None, tgt_key_padding_mask: Tensor = None, memory_key_padding_mask: Tensor = None, + src_is_causal: bool = None, tgt_is_causal: bool = None, + memory_is_causal: bool = False, + # New args: return_weights=False, average_heads=False): """ - Copy-pasted from torch. Now returns weights + converts nan to 0. - Weights are None if return_weights is False. - """ - # see Fig. 
1 of https://arxiv.org/pdf/2002.04745v1.pdf - x = tgt - if self.norm_first: - # Norm, SA, Add, Norm, MHA, Add, Norm, FF, Add - sa, sa_weights = self._sa_block( - self.norm1(x), tgt_mask, tgt_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - x = x + sa - - mha, mha_weights = self._mha_block( - self.norm2(x), memory, memory_mask, memory_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - x = x + mha - x = x + self._ff_block(self.norm3(x)) - else: - # SA, Add, Norm, MHA, Add, Norm, FF, Add, Norm. - sa, sa_weights = self._sa_block( - x, tgt_mask, tgt_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - x = self.norm1(x + sa) - - mha, mha_weights = self._mha_block( - x, memory, memory_mask, memory_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - x = self.norm2(x + mha) - x = self.norm3(x + self._ff_block(x)) - - return x, mha_weights, sa_weights - - # self-attention block - def _sa_block(self, x: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], - return_weights=False, average_heads=False): - """ Copy-pasted from torch. Now returns weights. """ - output = self.self_attn( - x, x, x, - attn_mask=attn_mask, key_padding_mask=key_padding_mask, - need_weights=return_weights, average_attn_weights=average_heads) - - x, weights = output # If not return_weights, weights is None. - - return self.dropout1(x), weights - - # multihead attention block - def _mha_block(self, x: Tensor, mem: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], - return_weights=False, average_heads=False): - """ - Copy-pasted from torch. Can now use need_weight = True. - """ - output = self.multihead_attn( - x, mem, mem, - attn_mask=attn_mask, key_padding_mask=key_padding_mask, - need_weights=return_weights, average_attn_weights=average_heads) + logger.debug("Entering main Transformer's forward.") + log_gpu_memory_usage(logger) + memory, sa_weights_encoder = self.encoder( + src, mask=src_mask, src_key_padding_mask=src_key_padding_mask, + is_causal=src_is_causal, + # New args: + return_weights=return_weights, average_heads=average_heads) - if return_weights: - x, weights = output - else: - x, weights = output, None + output, sa_weights_decoder, mha_weights = self.decoder( + tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal, + # New args: + return_weights=return_weights, average_heads=average_heads) - return self.dropout2(x[0]), weights + return output, sa_weights_encoder, sa_weights_decoder, mha_weights diff --git a/requirements.txt b/requirements.txt index d492885c..7739110b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,10 +34,3 @@ nibabel==5.2.* numpy==1.23.* scipy==1.9.* scikit-image==0.22.* - - -# --------------- Notes to developers -# If we upgrade torch, verify if code copied in -# models.projects.transformers_from_torch has changed. -# (current code copied from torch 1.13.1) -# ----------
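
Reviewer note (not part of the patch): below is a minimal usage sketch of the modified encoder stack introduced above, meant to illustrate the new is_causal and return_weights arguments. The hyperparameters (d_model=32, nhead=4, two layers, toy tensors) and the choice enable_nested_tensor=False are illustrative assumptions; only the class names and forward arguments come from the files changed in this patch.

import torch

from dwi_ml.models.projects.transformer_sublayers import \
    ModifiedTransformerEncoderLayer
from dwi_ml.models.utils.transformers_from_torch import \
    ModifiedTransformerEncoder

d_model, nhead, n_layers, seq_len, batch_size = 32, 4, 2, 10, 3

# Layer + stack, mirroring how the models above build their encoder.
layer = ModifiedTransformerEncoderLayer(
    d_model, nhead, dim_feedforward=64, dropout=0.1, activation='relu',
    batch_first=True, norm_first=False)
encoder = ModifiedTransformerEncoder(
    layer, n_layers, norm=None,
    enable_nested_tensor=False)  # keep the plain (non-nested) path

src = torch.rand(batch_size, seq_len, d_model)
future_mask = torch.nn.Transformer.generate_square_subsequent_mask(seq_len)
padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)

# Unlike torch's TransformerEncoder, the modified forward also returns the
# per-layer self-attention weights (a list of length n_layers).
out, sa_weights = encoder(src, mask=future_mask,
                          src_key_padding_mask=padding_mask,
                          is_causal=True, return_weights=True)
print(out.shape)            # torch.Size([3, 10, 32])
print(sa_weights[0].shape)  # torch.Size([3, 4, 10, 10]); heads kept separate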