Skip to content

Commit

Permalink
Simplify argparser. ToDo: change option reverse. Clarify rescale.
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Jan 17, 2024
1 parent 60443c2 commit 343f845
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 100 deletions.
18 changes: 17 additions & 1 deletion dwi_ml/models/projects/transformer_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from dwi_ml.data.processing.streamlines.sos_eos_management import \
add_label_as_last_dim, convert_dirs_to_class
from dwi_ml.data.processing.streamlines.post_processing import compute_directions
from dwi_ml.data.processing.streamlines.post_processing import \
compute_directions
from dwi_ml.data.spheres import TorchSphere
from dwi_ml.models.embeddings import keys_to_embeddings
from dwi_ml.models.main_models import (ModelWithDirectionGetter,
Expand Down Expand Up @@ -993,3 +994,18 @@ def merge_batches_weights(self, weights, new_weights, device):
weights.extend(new_weights)
return (weights,)


def find_transformer_class(model_type: str):
"""
model_type: returned by verify_which_model_in_path.
"""
transformers_dict = {
OriginalTransformerModel.__name__: OriginalTransformerModel,
TransformerSrcAndTgtModel.__name__: TransformerSrcAndTgtModel,
TransformerSrcOnlyModel.__name: TransformerSrcOnlyModel
}
if model_type not in transformers_dict.keys():
raise ValueError("Model type is not a recognized Transformer"
"({})".format(model_type))

return transformers_dict[model_type]
24 changes: 4 additions & 20 deletions dwi_ml/models/projects/transformers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def add_transformers_model_args(p):

gx = p.add_argument_group(
"Embedding of the input (X)",
"Input embedding size defines the d_model. The d_model must be divisible "
"by the number of heads.\n"
"Input embedding size defines the d_model. The d_model must be "
"divisible by the number of heads.\n"
"Note that for TTST, total d_model will rather be "
"input_embedded_size + target_embedded_size.\n")
AbstractTransformerModel.add_args_input_embedding(
Expand Down Expand Up @@ -56,7 +56,8 @@ def add_transformers_model_args(p):
gt.add_argument(
'--target_embedded_size', type=int, metavar='n',
help="Embedding size for targets (for TTST only). \n"
"Total d_model will be input_embedded_size + target_embedded_size.")
"Total d_model will be input_embedded_size + "
"target_embedded_size.")

gtt = p.add_argument_group(title='Transformer: main layers')
gtt.add_argument(
Expand Down Expand Up @@ -119,20 +120,3 @@ def add_transformers_model_args(p):

g = p.add_argument_group("Output")
AbstractTransformerModel.add_args_tracking_model(g)


def find_transformer_class(model_type):
"""
model_type: returned by verify_which_model_in_path.
"""
if model_type == 'OriginalTransformerModel':
model_cls = OriginalTransformerModel
elif model_type == 'TransformerSrcAndTgtModel':
model_cls = TransformerSrcAndTgtModel
elif model_type == 'TransformerSrcOnlyModel':
model_cls = TransformerSrcOnlyModel
else:
raise ValueError("Model type is not a recognized Transformer"
"({})".format(model_type))

return model_cls
113 changes: 50 additions & 63 deletions dwi_ml/testing/projects/tt_visu_main.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,29 @@
# -*- coding: utf-8 -*-

"""
Main part for the tt_visualize_weights, separated to be callable by the
jupyter notebook.
Main part for the tt_visualize_weights, separated to be callable by the jupyter
notebook.
"""
import argparse
import glob
import logging
import os
from typing import List

from dipy.io.streamline import save_tractogram
import numpy as np
import torch
from dipy.io.streamline import save_tractogram
from matplotlib import pyplot as plt

from scilpy.io.fetcher import get_home as get_scilpy_folder
from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist,
assert_outputs_exist, add_reference_arg)

from dwi_ml.io_utils import add_logging_arg, verify_which_model_in_path, add_memory_args, \
add_arg_existing_experiment_path
from dwi_ml.io_utils import (add_logging_arg, add_memory_args,
verify_which_model_in_path,
add_arg_existing_experiment_path)
from dwi_ml.models.projects.transformer_models import (
OriginalTransformerModel, TransformerSrcAndTgtModel)
from dwi_ml.models.projects.transformers_utils import find_transformer_class
OriginalTransformerModel, TransformerSrcAndTgtModel,
find_transformer_class)
from dwi_ml.testing.projects.tt_visu_bertviz import (
encoder_decoder_show_head_view, encoder_decoder_show_model_view,
encoder_show_model_view, encoder_show_head_view)
Expand All @@ -33,32 +32,11 @@
from dwi_ml.testing.projects.tt_visu_matrix import show_model_view_as_imshow
from dwi_ml.testing.projects.tt_visu_utils import (
prepare_encoder_tokens, prepare_decoder_tokens,
reshape_attention_to4d_tocpu, unpad_rescale_attention, resample_attention_one_line)
reshape_attention_to4d_tocpu, unpad_rescale_attention,
resample_attention_one_line)
from dwi_ml.testing.utils import add_args_testing_subj_hdf5


def set_out_dir_visu_weights_and_create_if_not_exists(args):
if args.out_dir is None:
args.out_dir = os.path.join(args.experiment_path, 'visu_weights')
if not os.path.isdir(args.out_dir):
os.mkdir(args.out_dir)

return args


def get_config_filename():
"""
File that will be saved by the python script with all the args. The
jupyter notebook can then load them again.
"""
# We choose to add it in the hidden .scilpy folder in our home.
# (Where our test data also is).
hidden_folder = get_scilpy_folder()
config_filename = os.path.join(
hidden_folder, 'ipynb_tt_visualize_weights.config')
return config_filename


def build_argparser_transformer_visu():
"""
This needs to be in a module, to be imported in the jupyter notebook. Do
Expand All @@ -68,27 +46,29 @@ def build_argparser_transformer_visu():
formatter_class=argparse.RawTextHelpFormatter,
description=__doc__)

# Main dataset arguments:
add_arg_existing_experiment_path(p)
add_args_testing_subj_hdf5(p, ask_input_group=True)

p.add_argument('in_sft',
help="A small tractogram; a bundle of streamlines whose "
"attention mask we will average.")
help="A small tractogram; a bundle of streamlines that "
"should be \nuniformized. Else, see option "
"--align_endpoints")
p.add_argument(
'--out_prefix', metavar='name',
help="Prefix of the all output files. Do not include a path. "
"Suffixes are: \n"
" 1) 'as_matrix': tt_matrix_[encoder|decoder|cross].png.\n"
" 2) 'bertviz': tt_bertviz.html, tt_bertviz.ipynb, "
"tt_bertviz.config.\n"
" 3) 'colored_sft': colored_sft.trk."
" 3) 'colored_sft': colored_sft.trk.\n"
" 4) 'bertviz_locally': None")
p.add_argument(
'--out_dir', metavar='d',
help="Output directory where to save the output files.\n"
"Default: experiment_path/visu_weights")

p.add_argument(
g = p.add_argument_group("Visualization options")
g.add_argument(
'--visu_type', required=True, nargs='+',
choices=['as_matrix', 'bertviz', 'colored_sft', 'bertviz_locally'],
help="Output option. Choose any number (at least one). \n"
Expand All @@ -99,34 +79,31 @@ def build_argparser_transformer_visu():
" Will create a html file that can be viewed (see "
"--out_dir)\n"
" 3) 'colored_sft': Save a colored sft.\n"
" 4) 'bertviz_locally': Run the bertviz without using jupyter "
"(debug purposes).\n"
" Output will not not show, but html stuff will print in "
"the terminal.")
p.add_argument(
'--rescale', action='store_true',
help="If true, rescale to max 1 per row.")

g = p.add_mutually_exclusive_group()
" 4) 'bertviz_locally': Run the bertviz without using jupyter\n"
" (Debugging purposes. Output will not not show, but html\n"
" stuff will print in the terminal.")
g.add_argument('--rescale', action='store_true',
help="If true, rescale to max 1 per row.")
g.add_argument('--align_endpoints', action='store_true',
help="If set, align endpoints of the batch. Either this "
"or --inverse_align_endpoints. \nProbably helps"
"visualisation with option --visu_type 'colored_sft'.")
g.add_argument('--inverse_align_endpoints', action='store_true',
help="If set, aligns endpoints and then reverses the "
"bundle.")
p.add_argument('--resample_attention', type=int,
help="Streamline will be sampled as decided by the model."
"However, \nattention will be resampled to fit better "
"in the html page. \n (Resampling is done by "
"averaging the attention every N points).\n"
"(only for bertviz and as_matrix")
p.add_argument('--average_heads', action='store_true',
help="If set, try aligning endpoints of the sft. Will use "
"the automatic \nalignment. For more options, align "
"your streamlines first, using \n"
" >> scil_tractogram_uniformize_endpoints.py.\n")
g.add_argument('--reverse', action='store_true',
help="If set, reverses all streamlines first.\n"
"(With option --align_endpoints, reversing is done "
"after.)")
g.add_argument('--resample_matrix', type=int, metavar='n',
help="Streamlines will be sampled (nb points) as decided "
"by the model. \nHowever, attention shown as a matrix "
"can be resampled \nto better fit in the html page.")
g.add_argument('--average_heads', action='store_true',
help="If true, resample all heads (per layer per "
"attention type).")

p.add_argument('--batch_size', type=int)
add_memory_args(p)
g = add_memory_args(p)
g.add_argument('--batch_size', metavar='s', type=int,
help="The batch size in number of streamlines.")

p.add_argument('--show_now', action='store_true',
help="If set, shows the matrices on screen. Else, only "
Expand All @@ -138,6 +115,16 @@ def build_argparser_transformer_visu():
return p


def create_out_dir_visu_weights(args):
# Define out_dir as experiment_path/visu_weights if not defined.
# Create it if it does not exist.
if args.out_dir is None:
args.out_dir = os.path.join(args.experiment_path, 'visu_weights')
if not os.path.isdir(args.out_dir):
os.mkdir(args.out_dir)
return args


def tt_visualize_weights_main(args, parser):
"""
Main part of the script: verifies with type of Transformer we have,
Expand Down Expand Up @@ -166,7 +153,7 @@ def tt_visualize_weights_main(args, parser):
parser.error("Experiment {} not found.".format(args.experiment_path))

# Out files: jupyter stuff already managed in main script. Remains the sft.
args = set_out_dir_visu_weights_and_create_if_not_exists(args)
args = create_out_dir_visu_weights(args)
out_files = []
out_sft = None
prefix_total = os.path.join(args.out_dir, args.out_prefix)
Expand Down Expand Up @@ -244,7 +231,7 @@ def tt_visualize_weights_main(args, parser):
else: # TransformerSrcOnlyModel
visu_fct = visu_encoder_only

visu_fct(weights, sft, args.resample_attention, args.rescale,
visu_fct(weights, sft, args.resample_matrix, args.rescale,
model.direction_getter.add_eos, save_colored_sft,
run_bertviz, show_as_matrices, colored_sft_name=out_sft,
matrices_prefix=prefix_total)
Expand Down
2 changes: 1 addition & 1 deletion dwi_ml/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def add_args_testing_subj_hdf5(p, ask_input_group=False,
p.add_argument('--subset', default='testing',
choices=['training', 'validation', 'testing'],
help="Subject id should probably come come the "
"'testing' set but you can\n modify this to "
"'testing' set but you can \nmodify this to "
"'training' or 'validation'.")


Expand Down
45 changes: 30 additions & 15 deletions scripts_python/tt_visualize_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,55 @@
import sys
from os.path import dirname

from scilpy.io.fetcher import get_home as get_scilpy_folder
from scilpy.io.utils import assert_outputs_exist

from dwi_ml.testing.projects.tt_visu_main import (
build_argparser_transformer_visu, get_config_filename,
tt_visualize_weights_main, set_out_dir_visu_weights_and_create_if_not_exists)
build_argparser_transformer_visu, create_out_dir_visu_weights,
tt_visualize_weights_main)


# Note. To use through jupyter, the file
# must also be up-to-date.
# Note. To use through jupyter, the file must also be up-to-date.
# To modify it as developers:
# jupyter notebook dwi_ml/testing/projects/tt_visualise_weights.ipynb
# Do not run it there though! Else, saving it will save the outputs too!


def _get_config_filename():
"""
File that will be saved by the python script with all the args. The
jupyter notebook can then load them again.
"""
# We choose to add it in the hidden .scilpy folder in our home.
# (Where our test data also is).
hidden_folder = get_scilpy_folder()
config_filename = os.path.join(
hidden_folder, 'ipynb_tt_visualize_weights.config')
return config_filename


def main():
# Getting the raw args. Will be sent to jupyter.
argv = sys.argv

parser = build_argparser_transformer_visu()
args = parser.parse_args()

run_locally = False
if 'bertviz_locally' in args.visu_type:
if 'bertviz' in args.visu_type:
parser.error("Please choose either 'bertviz' or "
"'bertviz_locally', not both.")

# Verifying if jupyter is required.
if 'bertviz' in args.visu_type and 'bertviz_locally' in args.visu_type:
raise ValueError("Please only select 'bertviz' or 'bertviz_locally', "
"not both.")
elif 'bertviz_locally' in args.visu_type:
print("--DEBUGGING MODE--\n"
"We will run the bertviz but it will not save any output!")
run_locally = True
if 'bertviz' not in args.visu_type:
elif 'bertviz' in args.visu_type:
print("Preparing to run through jupyter.")
run_locally = False
else:
run_locally = True

# Running.
if run_locally:
tt_visualize_weights_main(args, parser)
else:
Expand All @@ -54,7 +70,7 @@ def main():
.format(raw_ipynb_filename))

# 2) Verify that output dir exists but not the html output files.
args = set_out_dir_visu_weights_and_create_if_not_exists(args)
args = create_out_dir_visu_weights(args)

out_html_filename = args.out_prefix + 'tt_bertviz.html'
out_html_file = os.path.join(args.out_dir, out_html_filename)
Expand All @@ -69,9 +85,8 @@ def main():
# by the notebook in a new argparse instance.
# Jupyter notebook needs to know where to load the config file.
# Needs to be always at the same place because we cannot send an
# argument to jupyter. Or, we could ask it here and tell user to
# add it manually inside the jupyter notebook... complicated.
hidden_config_filename = get_config_filename()
# argument to jupyter.
hidden_config_filename = _get_config_filename()
if os.path.isfile(hidden_config_filename):
# In case a previous call failed.
os.remove(hidden_config_filename)
Expand Down

0 comments on commit 343f845

Please sign in to comment.