Skip to content

Commit

Permalink
Simplify argparser. Adding histogram. ToDo: change option reverse. Cl…
Browse files Browse the repository at this point in the history
…arify rescale.
  • Loading branch information
EmmaRenauld committed Feb 2, 2024
1 parent f76ef69 commit fc63850
Show file tree
Hide file tree
Showing 10 changed files with 522 additions and 407 deletions.
2 changes: 1 addition & 1 deletion dwi_ml/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def add_memory_args(p: ArgumentParser, add_lazy_options=False,
add_processes_arg(ram_options)
ram_options.add_argument(
'--use_gpu', action='store_true',
help="If set, use GPU for processing. Cannot be used \ntogether "
help="If set, use GPU for processing. Cannot be used together "
"with --processes.")
else:
p.add_argument('--use_gpu', action='store_true',
Expand Down
18 changes: 17 additions & 1 deletion dwi_ml/models/projects/transformer_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from dwi_ml.data.processing.streamlines.sos_eos_management import \
add_label_as_last_dim, convert_dirs_to_class
from dwi_ml.data.processing.streamlines.post_processing import compute_directions
from dwi_ml.data.processing.streamlines.post_processing import \
compute_directions
from dwi_ml.data.spheres import TorchSphere
from dwi_ml.models.embeddings import keys_to_embeddings
from dwi_ml.models.main_models import (ModelWithDirectionGetter,
Expand Down Expand Up @@ -993,3 +994,18 @@ def merge_batches_weights(self, weights, new_weights, device):
weights.extend(new_weights)
return (weights,)


def find_transformer_class(model_type: str):
"""
model_type: returned by verify_which_model_in_path.
"""
transformers_dict = {
OriginalTransformerModel.__name__: OriginalTransformerModel,
TransformerSrcAndTgtModel.__name__: TransformerSrcAndTgtModel,
TransformerSrcOnlyModel.__name: TransformerSrcOnlyModel
}
if model_type not in transformers_dict.keys():
raise ValueError("Model type is not a recognized Transformer"
"({})".format(model_type))

return transformers_dict[model_type]
24 changes: 4 additions & 20 deletions dwi_ml/models/projects/transformers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def add_transformers_model_args(p):

gx = p.add_argument_group(
"Embedding of the input (X)",
"Input embedding size defines the d_model. The d_model must be divisible "
"by the number of heads.\n"
"Input embedding size defines the d_model. The d_model must be "
"divisible by the number of heads.\n"
"Note that for TTST, total d_model will rather be "
"input_embedded_size + target_embedded_size.\n")
AbstractTransformerModel.add_args_input_embedding(
Expand Down Expand Up @@ -56,7 +56,8 @@ def add_transformers_model_args(p):
gt.add_argument(
'--target_embedded_size', type=int, metavar='n',
help="Embedding size for targets (for TTST only). \n"
"Total d_model will be input_embedded_size + target_embedded_size.")
"Total d_model will be input_embedded_size + "
"target_embedded_size.")

gtt = p.add_argument_group(title='Transformer: main layers')
gtt.add_argument(
Expand Down Expand Up @@ -119,20 +120,3 @@ def add_transformers_model_args(p):

g = p.add_argument_group("Output")
AbstractTransformerModel.add_args_tracking_model(g)


def find_transformer_class(model_type):
"""
model_type: returned by verify_which_model_in_path.
"""
if model_type == 'OriginalTransformerModel':
model_cls = OriginalTransformerModel
elif model_type == 'TransformerSrcAndTgtModel':
model_cls = TransformerSrcAndTgtModel
elif model_type == 'TransformerSrcOnlyModel':
model_cls = TransformerSrcOnlyModel
else:
raise ValueError("Model type is not a recognized Transformer"
"({})".format(model_type))

return model_cls
113 changes: 50 additions & 63 deletions dwi_ml/testing/projects/tt_visu_main.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,29 @@
# -*- coding: utf-8 -*-

"""
Main part for the tt_visualize_weights, separated to be callable by the
jupyter notebook.
Main part for the tt_visualize_weights, separated to be callable by the jupyter
notebook.
"""
import argparse
import glob
import logging
import os
from typing import List

from dipy.io.streamline import save_tractogram
import numpy as np
import torch
from dipy.io.streamline import save_tractogram
from matplotlib import pyplot as plt

from scilpy.io.fetcher import get_home as get_scilpy_folder
from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist,
assert_outputs_exist, add_reference_arg)

from dwi_ml.io_utils import add_logging_arg, verify_which_model_in_path, add_memory_args, \
add_arg_existing_experiment_path
from dwi_ml.io_utils import (add_logging_arg, add_memory_args,
verify_which_model_in_path,
add_arg_existing_experiment_path)
from dwi_ml.models.projects.transformer_models import (
OriginalTransformerModel, TransformerSrcAndTgtModel)
from dwi_ml.models.projects.transformers_utils import find_transformer_class
OriginalTransformerModel, TransformerSrcAndTgtModel,
find_transformer_class)
from dwi_ml.testing.projects.tt_visu_bertviz import (
encoder_decoder_show_head_view, encoder_decoder_show_model_view,
encoder_show_model_view, encoder_show_head_view)
Expand All @@ -33,32 +32,11 @@
from dwi_ml.testing.projects.tt_visu_matrix import show_model_view_as_imshow
from dwi_ml.testing.projects.tt_visu_utils import (
prepare_encoder_tokens, prepare_decoder_tokens,
reshape_attention_to4d_tocpu, unpad_rescale_attention, resample_attention_one_line)
reshape_attention_to4d_tocpu, unpad_rescale_attention,
resample_attention_one_line)
from dwi_ml.testing.utils import add_args_testing_subj_hdf5


def set_out_dir_visu_weights_and_create_if_not_exists(args):
if args.out_dir is None:
args.out_dir = os.path.join(args.experiment_path, 'visu_weights')
if not os.path.isdir(args.out_dir):
os.mkdir(args.out_dir)

return args


def get_config_filename():
"""
File that will be saved by the python script with all the args. The
jupyter notebook can then load them again.
"""
# We choose to add it in the hidden .scilpy folder in our home.
# (Where our test data also is).
hidden_folder = get_scilpy_folder()
config_filename = os.path.join(
hidden_folder, 'ipynb_tt_visualize_weights.config')
return config_filename


def build_argparser_transformer_visu():
"""
This needs to be in a module, to be imported in the jupyter notebook. Do
Expand All @@ -68,27 +46,29 @@ def build_argparser_transformer_visu():
formatter_class=argparse.RawTextHelpFormatter,
description=__doc__)

# Main dataset arguments:
add_arg_existing_experiment_path(p)
add_args_testing_subj_hdf5(p, ask_input_group=True)

p.add_argument('in_sft',
help="A small tractogram; a bundle of streamlines whose "
"attention mask we will average.")
help="A small tractogram; a bundle of streamlines that "
"should be \nuniformized. Else, see option "
"--align_endpoints")
p.add_argument(
'--out_prefix', metavar='name',
help="Prefix of the all output files. Do not include a path. "
"Suffixes are: \n"
" 1) 'as_matrix': tt_matrix_[encoder|decoder|cross].png.\n"
" 2) 'bertviz': tt_bertviz.html, tt_bertviz.ipynb, "
"tt_bertviz.config.\n"
" 3) 'colored_sft': colored_sft.trk."
" 3) 'colored_sft': colored_sft.trk.\n"
" 4) 'bertviz_locally': None")
p.add_argument(
'--out_dir', metavar='d',
help="Output directory where to save the output files.\n"
"Default: experiment_path/visu_weights")

p.add_argument(
g = p.add_argument_group("Visualization options")
g.add_argument(
'--visu_type', required=True, nargs='+',
choices=['as_matrix', 'bertviz', 'colored_sft', 'bertviz_locally'],
help="Output option. Choose any number (at least one). \n"
Expand All @@ -99,34 +79,31 @@ def build_argparser_transformer_visu():
" Will create a html file that can be viewed (see "
"--out_dir)\n"
" 3) 'colored_sft': Save a colored sft.\n"
" 4) 'bertviz_locally': Run the bertviz without using jupyter "
"(debug purposes).\n"
" Output will not not show, but html stuff will print in "
"the terminal.")
p.add_argument(
'--rescale', action='store_true',
help="If true, rescale to max 1 per row.")

g = p.add_mutually_exclusive_group()
" 4) 'bertviz_locally': Run the bertviz without using jupyter\n"
" (Debugging purposes. Output will not not show, but html\n"
" stuff will print in the terminal.")
g.add_argument('--rescale', action='store_true',
help="If true, rescale to max 1 per row.")
g.add_argument('--align_endpoints', action='store_true',
help="If set, align endpoints of the batch. Either this "
"or --inverse_align_endpoints. \nProbably helps"
"visualisation with option --visu_type 'colored_sft'.")
g.add_argument('--inverse_align_endpoints', action='store_true',
help="If set, aligns endpoints and then reverses the "
"bundle.")
p.add_argument('--resample_attention', type=int,
help="Streamline will be sampled as decided by the model."
"However, \nattention will be resampled to fit better "
"in the html page. \n (Resampling is done by "
"averaging the attention every N points).\n"
"(only for bertviz and as_matrix")
p.add_argument('--average_heads', action='store_true',
help="If set, try aligning endpoints of the sft. Will use "
"the automatic \nalignment. For more options, align "
"your streamlines first, using \n"
" >> scil_tractogram_uniformize_endpoints.py.\n")
g.add_argument('--reverse', action='store_true',
help="If set, reverses all streamlines first.\n"
"(With option --align_endpoints, reversing is done "
"after.)")
g.add_argument('--resample_matrix', type=int, metavar='n',
help="Streamlines will be sampled (nb points) as decided "
"by the model. \nHowever, attention shown as a matrix "
"can be resampled \nto better fit in the html page.")
g.add_argument('--average_heads', action='store_true',
help="If true, resample all heads (per layer per "
"attention type).")

p.add_argument('--batch_size', type=int)
add_memory_args(p)
g = add_memory_args(p)
g.add_argument('--batch_size', metavar='s', type=int,
help="The batch size in number of streamlines.")

p.add_argument('--show_now', action='store_true',
help="If set, shows the matrices on screen. Else, only "
Expand All @@ -138,6 +115,16 @@ def build_argparser_transformer_visu():
return p


def create_out_dir_visu_weights(args):
# Define out_dir as experiment_path/visu_weights if not defined.
# Create it if it does not exist.
if args.out_dir is None:
args.out_dir = os.path.join(args.experiment_path, 'visu_weights')
if not os.path.isdir(args.out_dir):
os.mkdir(args.out_dir)
return args


def tt_visualize_weights_main(args, parser):
"""
Main part of the script: verifies with type of Transformer we have,
Expand Down Expand Up @@ -166,7 +153,7 @@ def tt_visualize_weights_main(args, parser):
parser.error("Experiment {} not found.".format(args.experiment_path))

# Out files: jupyter stuff already managed in main script. Remains the sft.
args = set_out_dir_visu_weights_and_create_if_not_exists(args)
args = create_out_dir_visu_weights(args)
out_files = []
out_sft = None
prefix_total = os.path.join(args.out_dir, args.out_prefix)
Expand Down Expand Up @@ -244,7 +231,7 @@ def tt_visualize_weights_main(args, parser):
else: # TransformerSrcOnlyModel
visu_fct = visu_encoder_only

visu_fct(weights, sft, args.resample_attention, args.rescale,
visu_fct(weights, sft, args.resample_matrix, args.rescale,
model.direction_getter.add_eos, save_colored_sft,
run_bertviz, show_as_matrices, colored_sft_name=out_sft,
matrices_prefix=prefix_total)
Expand Down
Loading

0 comments on commit fc63850

Please sign in to comment.