Copy Transformer model from latest version. New is_causal arg
EmmaRenauld committed May 22, 2024
1 parent 6266c2a commit d3c552f
Showing 4 changed files with 474 additions and 217 deletions.
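Background on the two changes, with a minimal sketch in plain PyTorch rather than repository code (model sizes are arbitrary, and this assumes a PyTorch version recent enough to expose enable_nested_tensor and is_causal, roughly 2.0+): when a TransformerEncoderLayer uses norm_first=True, nn.TransformerEncoder cannot use its nested-tensor fast path, and its default enable_nested_tensor=True emits exactly the UserWarning quoted in the new comments below. Passing enable_nested_tensor=False in that case, as this commit does, silences it.

    import warnings
    from torch import nn

    # Arbitrary sizes, for illustration only.
    layer = nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True,
                                       norm_first=True)

    # Default enable_nested_tensor=True + norm_first=True produces:
    #   "enable_nested_tensor is True, but self.use_nested_tensor is False
    #    because encoder_layer.norm_first was True"
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        nn.TransformerEncoder(layer, num_layers=2)
    print([str(w.message) for w in caught])

    # Passing enable_nested_tensor=False, as this commit does whenever
    # norm_first is True, silences the warning; the nested-tensor fast path
    # was unusable in that configuration anyway.
    encoder = nn.TransformerEncoder(layer, num_layers=2,
                                    enable_nested_tensor=False)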
43 changes: 31 additions & 12 deletions dwi_ml/models/projects/transformer_models.py
@@ -574,8 +574,17 @@ def __init__(self, **kw):
             self.d_model, self.nheads, dim_feedforward=self.ffnn_hidden_size,
             dropout=self.dropout_rate, activation=self.activation,
             batch_first=True, norm_first=self.norm_first)
-        self.modified_torch_transformer = ModifiedTransformerEncoder(
-            main_layer_encoder, self.n_layers_e, norm=None)
+
+        # Receiving weird warning: enable_nested_tensor is True,
+        # but self.use_nested_tensor is False because encoder_layer.norm_first
+        # was True.
+        enable_nested = False if self.norm_first else True
+
+        # Note about norm: this is a final normalization step. Not linked to
+        # the normalization decided with self.norm_first.
+        self.transformer_encoder_only = ModifiedTransformerEncoder(
+            main_layer_encoder, self.n_layers_e, norm=None,
+            enable_nested_tensor=enable_nested)

     @property
     def d_model(self):
@@ -611,9 +620,9 @@ def _run_main_layer_forward(self, inputs, masks, return_weights):
         # Encoder only.

         # mask_future, mask_padding = masks
-        outputs, sa_weights = self.modified_torch_transformer(
+        outputs, sa_weights = self.transformer_encoder_only(
             src=inputs, mask=masks[0], src_key_padding_mask=masks[1],
-            return_weights=return_weights)
+            is_causal=True, return_weights=return_weights)

         return outputs, (sa_weights,)

@@ -844,8 +853,17 @@ def __init__(self, input_embedded_size, n_layers_d: int, **kw):
             dim_feedforward=self.ffnn_hidden_size, dropout=self.dropout_rate,
             activation=self.activation, batch_first=True,
             norm_first=self.norm_first)
-        encoder = ModifiedTransformerEncoder(encoder_layer, self.n_layers_e,
-                                             norm=None)
+
+        # Receiving weird warning: enable_nested_tensor is True,
+        # but self.use_nested_tensor is False because encoder_layer.norm_first
+        # was True.
+        enable_nested = False if self.norm_first else True
+
+        # Note about norm: this is a final normalization step. Not linked to
+        # the normalization decided with self.norm_first.
+        encoder = ModifiedTransformerEncoder(
+            encoder_layer, self.n_layers_e, norm=None,
+            enable_nested_tensor=enable_nested)

         # Decoder
         decoder_layer = ModifiedTransformerDecoderLayer(
@@ -856,7 +874,7 @@ def __init__(self, input_embedded_size, n_layers_d: int, **kw):
         decoder = ModifiedTransformerDecoder(decoder_layer, n_layers_d,
                                              norm=None)

-        self.modified_torch_transformer = ModifiedTransformer(
+        self.transformer = ModifiedTransformer(
             self.d_model, self.nheads, self.n_layers_e, n_layers_d,
             self.ffnn_hidden_size, self.dropout_rate, self.activation,
             encoder, decoder, batch_first=True,
@@ -904,11 +922,12 @@ def _run_main_layer_forward(self, data, masks, return_weights):
         # embed_x, embed_t = data
         # mask_future, mask_padding = masks
         outputs, sa_weights_encoder, sa_weights_decoder, mha_weights = \
-            self.modified_torch_transformer(
+            self.transformer(
                 src=data[0], tgt=data[1],
                 src_mask=masks[0], tgt_mask=masks[0], memory_mask=masks[0],
                 src_key_padding_mask=masks[1], tgt_key_padding_mask=masks[1],
-                memory_key_padding_mask=masks[1],
+                memory_key_padding_mask=masks[1], src_is_causal=True,
+                tgt_is_causal=True, memory_is_causal=True,
                 return_weights=return_weights)
         return outputs, (sa_weights_encoder, sa_weights_decoder, mha_weights)
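For the full encoder-decoder model, the same masks[0]/masks[1] pair is reused for the source, target and memory attentions, with a matching *_is_causal hint for each. A rough standalone equivalent using stock nn.Transformer rather than the repository's ModifiedTransformer (illustrative sizes; applying a causal mask to the memory attention is a modelling choice made by this commit, not a PyTorch requirement):

    import torch
    from torch import nn

    model = nn.Transformer(d_model=32, nhead=4, num_encoder_layers=2,
                           num_decoder_layers=2, batch_first=True)

    src = torch.randn(2, 5, 32)
    tgt = torch.randn(2, 5, 32)
    causal = nn.Transformer.generate_square_subsequent_mask(5)   # masks[0] role
    padding = torch.zeros(2, 5, dtype=torch.bool)                # masks[1] role

    # Same causal mask and padding mask for the src, tgt and memory attentions,
    # each with its *_is_causal hint, as in the diff above.
    out = model(src, tgt,
                src_mask=causal, tgt_mask=causal, memory_mask=causal,
                src_key_padding_mask=padding, tgt_key_padding_mask=padding,
                memory_key_padding_mask=padding,
                src_is_causal=True, tgt_is_causal=True, memory_is_causal=True)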

@@ -955,7 +974,7 @@ def __init__(self, **kw):
             self.d_model, self.nheads, dim_feedforward=self.ffnn_hidden_size,
             dropout=self.dropout_rate, activation=self.activation,
             batch_first=True, norm_first=self.norm_first)
-        self.modified_torch_transformer = ModifiedTransformerEncoder(
+        self.transformer_encoder_only = ModifiedTransformerEncoder(
             main_layer_encoder, self.n_layers_e, norm=None)

     @property
@@ -987,9 +1006,9 @@ def _run_main_layer_forward(self, concat_s_t, masks, return_weights):
         # Encoder only.

         # mask_future, mask_padding = masks
-        outputs, sa_weights = self.modified_torch_transformer(
+        outputs, sa_weights = self.transformer_encoder_only(
             src=concat_s_t, mask=masks[0], src_key_padding_mask=masks[1],
-            return_weights=return_weights)
+            is_causal=True, return_weights=return_weights)

         return outputs, (sa_weights,)

(Diffs for the remaining three changed files are not shown.)
