Skip to content

Commit

Permalink
Simplifying argparsers. Adding histogram. Tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Jan 18, 2024
1 parent 4afa173 commit 671ff6b
Show file tree
Hide file tree
Showing 20 changed files with 372 additions and 375 deletions.
7 changes: 5 additions & 2 deletions dwi_ml/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def add_arg_existing_experiment_path(p: ArgumentParser):
help='Path to the directory containing the experiment.\n'
'(Should contain a model subdir with a file \n'
'parameters.json and a file best_model_state.pkl.)')
p.add_argument('--use_latest_epoch', action='store_true',
help="If true, use model at latest epoch rather than "
"default (best model).")


def add_memory_args(p: ArgumentParser, add_lazy_options=False,
Expand All @@ -43,8 +46,8 @@ def add_memory_args(p: ArgumentParser, add_lazy_options=False,
add_processes_arg(ram_options)
ram_options.add_argument(
'--use_gpu', action='store_true',
help="If set, use GPU for processing. Cannot be used \ntogether "
"with --processes.")
help="If set, use GPU for processing. Cannot be used together "
"with \noption --processes.")
else:
p.add_argument('--use_gpu', action='store_true',
help="If set, use GPU for processing.")
Expand Down
16 changes: 16 additions & 0 deletions dwi_ml/models/projects/transformer_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,3 +988,19 @@ def merge_batches_weights(self, weights, new_weights, device):
if weights is None:
weights = (None,)
return (merge_one_weight_type(weights[0], new_weights[0], device), )


def find_transformer_class(model_type: str):
    """
    Return the Transformer model class matching a model type name.

    Parameters
    ----------
    model_type: str
        Name of the model class (its __name__), as returned by
        verify_which_model_in_path.

    Returns
    -------
    The matching model class (the class itself, not an instance).

    Raises
    ------
    ValueError
        If model_type does not name a recognized Transformer model.
    """
    transformers_dict = {
        OriginalTransformerModel.__name__: OriginalTransformerModel,
        TransformerSrcAndTgtModel.__name__: TransformerSrcAndTgtModel,
        TransformerSrcOnlyModel.__name__: TransformerSrcOnlyModel
    }
    # Membership test directly on the dict; `.keys()` was redundant.
    if model_type not in transformers_dict:
        # Bug fix: the two adjacent string literals used to concatenate to
        # "...Transformer({})" with no space before the parenthesis.
        raise ValueError("Model type is not a recognized Transformer "
                         "({})".format(model_type))

    return transformers_dict[model_type]
28 changes: 5 additions & 23 deletions dwi_ml/models/projects/transformers_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# -*- coding: utf-8 -*-
from dwi_ml.models.positional_encoding import (
keys_to_positional_encodings)
from dwi_ml.models.projects.transformer_models import (
AbstractTransformerModel, OriginalTransformerModel,
TransformerSrcAndTgtModel, TransformerSrcOnlyModel)
from dwi_ml.models.projects.transformer_models import AbstractTransformerModel

sphere_choices = ['symmetric362', 'symmetric642', 'symmetric724',
'repulsion724', 'repulsion100', 'repulsion200']
Expand All @@ -26,8 +24,8 @@ def add_transformers_model_args(p):

gx = p.add_argument_group(
"Embedding of the input (X)",
"Input embedding size defines the d_model. The d_model must be divisible "
"by the number of heads.\n"
"Input embedding size defines the d_model. The d_model must be "
"divisible by the number of heads.\n"
"Note that for TTST, total d_model will rather be "
"input_embedded_size + target_embedded_size.\n")
AbstractTransformerModel.add_args_input_embedding(
Expand Down Expand Up @@ -56,7 +54,8 @@ def add_transformers_model_args(p):
gt.add_argument(
'--target_embedded_size', type=int, metavar='n',
help="Embedding size for targets (for TTST only). \n"
"Total d_model will be input_embedded_size + target_embedded_size.")
"Total d_model will be input_embedded_size + "
"target_embedded_size.")

gtt = p.add_argument_group(title='Transformer: main layers')
gtt.add_argument(
Expand Down Expand Up @@ -119,20 +118,3 @@ def add_transformers_model_args(p):

g = p.add_argument_group("Output")
AbstractTransformerModel.add_args_tracking_model(g)


def find_transformer_class(model_type):
    """
    Map a model type name (as returned by verify_which_model_in_path)
    to the corresponding Transformer model class.
    """
    # Dispatch table: model type name -> model class.
    known_classes = {
        'OriginalTransformerModel': OriginalTransformerModel,
        'TransformerSrcAndTgtModel': TransformerSrcAndTgtModel,
        'TransformerSrcOnlyModel': TransformerSrcOnlyModel,
    }
    if model_type not in known_classes:
        raise ValueError("Model type is not a recognized Transformer"
                         "({})".format(model_type))

    return known_classes[model_type]
6 changes: 5 additions & 1 deletion dwi_ml/testing/projects/tt_visu_bertviz.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-
from typing import List, Tuple

# Ref: # https://github.com/jessevig/bertviz
from bertviz import model_view, head_view
Expand Down Expand Up @@ -49,6 +48,11 @@ def print_neuron_view_help():
def encoder_decoder_show_head_view(
encoder_attention, decoder_attention, cross_attention,
encoder_tokens, decoder_tokens):
"""
Expecting attentions of shape:
A list: nb_layers x
[nb_streamlines, nheads, batch_max_len, batch_max_len]
"""
print_head_view_help()
head_view(encoder_attention=encoder_attention,
decoder_attention=decoder_attention,
Expand Down
Loading

0 comments on commit 671ff6b

Please sign in to comment.