diff --git a/dwi_ml/models/projects/transformer_models.py b/dwi_ml/models/projects/transformer_models.py index 9486cd39..46d1cf70 100644 --- a/dwi_ml/models/projects/transformer_models.py +++ b/dwi_ml/models/projects/transformer_models.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- import logging -from time import time from typing import Union, List, Tuple, Optional from dipy.data import get_sphere @@ -11,7 +10,8 @@ from dwi_ml.data.processing.streamlines.sos_eos_management import \ add_label_as_last_dim, convert_dirs_to_class -from dwi_ml.data.processing.streamlines.post_processing import compute_directions +from dwi_ml.data.processing.streamlines.post_processing import \ + compute_directions from dwi_ml.data.spheres import TorchSphere from dwi_ml.models.embeddings import keys_to_embeddings from dwi_ml.models.main_models import (ModelWithDirectionGetter, @@ -78,6 +78,29 @@ def pad_and_stack_batch(data: List[torch.Tensor], pad_first: bool, return torch.stack(data) +def merge_one_weight_type(weights, new_weights, device): + # Weight is a list per layer of tensors of shape + # nb_streamlines, nb_heads, batch_max_len, batch_max_len + new_weights = [layer_weight.to(device) for layer_weight in new_weights] + new_max_len = new_weights[0].shape[2] + + if weights is None: + return new_weights + else: + old_max_len = weights[0].shape[2] + + # Padding if necessary. We could pad to max_len, but probably + # heavy for no reason. + pad_w = max(0, new_max_len - old_max_len) + pad_n = max(0, old_max_len - new_max_len) + weights = [torch.cat(( + pad(w, (0, pad_w, 0, pad_w)), + pad(n, (0, pad_n, 0, pad_n))), + dim=0) for w, n in zip(weights, new_weights)] + + return weights + + class AbstractTransformerModel(ModelWithNeighborhood, ModelWithDirectionGetter, ModelOneInputWithEmbedding): """ @@ -332,8 +355,7 @@ def _prepare_masks(self, unpadded_lengths, use_padding, batch_max_len): return mask_future, mask_padding def forward(self, inputs: List[torch.tensor], - input_streamlines: List[torch.tensor] = None, - average_heads=False): + input_streamlines: List[torch.tensor] = None): """ Params ------ @@ -351,9 +373,6 @@ def forward(self, inputs: List[torch.tensor], adequately masked to hide future positions. The last direction is not used. - As target during training. The whole sequence is used. - average_heads: bool - If return_weights, you may choose to average the weights from - different heads together. Returns ------- @@ -394,12 +413,7 @@ def forward(self, inputs: List[torch.tensor], use_padding = not np.all(input_lengths == input_lengths[0]) batch_max_len = np.max(input_lengths) if CLEAR_CACHE: - now = time() - logging.debug("Transformer: Maximal length in batch is {}" - .format(batch_max_len)) torch.torch.cuda.empty_cache() - now2 = time() - logging.debug("Cleared cache in {} secs.".format(now2 - now)) # ----------- Prepare masks masks = self._prepare_masks(input_lengths, use_padding, batch_max_len) @@ -422,7 +436,7 @@ def forward(self, inputs: List[torch.tensor], # 2. Main transformer outputs, weights = self._run_main_layer_forward( - data, masks, return_weights, average_heads) + data, masks, return_weights) # Here, data = one tensor, padded. # Unpad now and either @@ -485,6 +499,9 @@ def forward(self, inputs: List[torch.tensor], outputs = list(torch.split(outputs, list(input_lengths))) if return_weights: + # Padding weights to max length, else we won't be able to stack + # outputs. 
This way, all weights are a list, per layer, of + # tensors of shape [nb_streamlines, nb_heads, max_len, max_len] return outputs, weights return outputs @@ -498,8 +515,7 @@ def _run_embeddings(self, data, use_padding, batch_max_len): def _run_position_encoding(self, data): raise NotImplementedError - def _run_main_layer_forward(self, data, masks, return_weights, - average_heads): + def _run_main_layer_forward(self, data, masks, return_weights): raise NotImplementedError def _run_input_embedding(self, inputs, use_padding, batch_max_len): @@ -522,6 +538,7 @@ def merge_batches_outputs(self, all_outputs, new_batch, device=None): outputs, weights = None, None else: outputs, weights = all_outputs + new_outputs = super().merge_batches_outputs(outputs, new_outputs, device) new_weights = self.merge_batches_weights(weights, new_weights, @@ -582,26 +599,21 @@ def _run_position_encoding(self, inputs): inputs = self.dropout(inputs) return inputs - def _run_main_layer_forward(self, inputs, masks, - return_weights, average_heads): + def _run_main_layer_forward(self, inputs, masks, return_weights): # Encoder only. # mask_future, mask_padding = masks outputs, sa_weights = self.modified_torch_transformer( src=inputs, mask=masks[0], src_key_padding_mask=masks[1], - return_weights=return_weights, average_heads=average_heads) + return_weights=return_weights) return outputs, (sa_weights,) def merge_batches_weights(self, weights, new_weights, device): - # weights is a single attention tensor (encoder): a tuple of 1. - new_weights = [a.to(device) for a in new_weights[0]] - + # Weights is a single attention tensor (encoder): a tuple of 1. if weights is None: - return (new_weights,) - else: - weights.extend(new_weights) - return (weights,) + weights = (None,) + return (merge_one_weight_type(weights[0], new_weights[0], device), ) class AbstractTransformerModelWithTarget(AbstractTransformerModel): @@ -703,8 +715,7 @@ def _run_embeddings(self, data, use_padding, batch_max_len): def _run_position_encoding(self, data): raise NotImplementedError - def _run_main_layer_forward(self, data, masks, return_weights, - average_heads): + def _run_main_layer_forward(self, data, masks, return_weights): raise NotImplementedError def format_prev_dir_(self, dirs): @@ -871,8 +882,7 @@ def _run_position_encoding(self, data): return inputs, targets - def _run_main_layer_forward(self, data, masks, - return_weights, average_heads): + def _run_main_layer_forward(self, data, masks, return_weights): """Original Main transformer Returns @@ -890,24 +900,15 @@ def _run_main_layer_forward(self, data, masks, src_mask=masks[0], tgt_mask=masks[0], memory_mask=masks[0], src_key_padding_mask=masks[1], tgt_key_padding_mask=masks[1], memory_key_padding_mask=masks[1], - return_weights=return_weights, average_heads=average_heads) + return_weights=return_weights) return outputs, (sa_weights_encoder, sa_weights_decoder, mha_weights) def merge_batches_weights(self, weights, new_weights, device): - # weights is a Tuple[encoder, decoder, cross] - new_weights_e, new_weights_d, new_weights_c = new_weights - new_weights_e = [a.to(device) for a in new_weights_e] - new_weights_d = [a.to(device) for a in new_weights_d] - new_weights_c = [a.to(device) for a in new_weights_c] - if weights is None: - return new_weights_e, new_weights_d, new_weights_c - else: - weights_e, weights_d, weights_c = weights - weights_e.extend(new_weights_e) - weights_d.extend(new_weights_d) - weights_c.extend(new_weights_c) - return weights_e, weights_d, weights_c + weights = (None, None, 
None)
+        return (merge_one_weight_type(weights[0], new_weights[0], device),
+                merge_one_weight_type(weights[1], new_weights[1], device),
+                merge_one_weight_type(weights[2], new_weights[2], device))
 
 
 class TransformerSrcAndTgtModel(AbstractTransformerModelWithTarget):
@@ -972,24 +973,18 @@ def _run_position_encoding(self, data):
         data = self.dropout(data)
         return data
 
-    def _run_main_layer_forward(self, concat_s_t, masks,
-                                return_weights, average_heads):
+    def _run_main_layer_forward(self, concat_s_t, masks, return_weights):
         # Encoder only.
         # mask_future, mask_padding = masks
         outputs, sa_weights = self.modified_torch_transformer(
             src=concat_s_t, mask=masks[0], src_key_padding_mask=masks[1],
-            return_weights=return_weights, average_heads=average_heads)
+            return_weights=return_weights)
         return outputs, (sa_weights,)
 
     def merge_batches_weights(self, weights, new_weights, device):
-        # weights is a single attention tensor (encoder): a tuple of 1.
-        new_weights = [a.to(device) for a in new_weights[0]]
-
+        # Weights is a single attention tensor (encoder): a tuple of 1.
         if weights is None:
-            return (new_weights,)
-        else:
-            weights.extend(new_weights)
-            return (weights,)
-
+            weights = (None,)
+        return (merge_one_weight_type(weights[0], new_weights[0], device), )
diff --git a/dwi_ml/testing/projects/tt_visu_colored_sft.py b/dwi_ml/testing/projects/tt_visu_colored_sft.py
index c009eb9b..abd4040b 100644
--- a/dwi_ml/testing/projects/tt_visu_colored_sft.py
+++ b/dwi_ml/testing/projects/tt_visu_colored_sft.py
@@ -1,13 +1,16 @@
 # -*- coding: utf-8 -*-
-import logging
 from typing import Tuple
 
 import numpy as np
 from dipy.io.stateful_tractogram import StatefulTractogram
+from dipy.io.streamline import save_tractogram
+from scilpy.viz.utils import get_colormap
 
 
-def add_attention_as_dpp(sft: StatefulTractogram, lengths,
-                         attentions_per_line: Tuple, attention_name: Tuple):
+def save_sft_with_attention_as_dpp(
+        sft: StatefulTractogram, lengths, prefix_name: str,
+        attentions_per_line: Tuple, attention_names: Tuple,
+        sft_delta: bool = True, sft_all_info: bool = True):
     """
     Adds the attention's value to the data per point.
 
@@ -15,11 +18,12 @@ def add_attention_as_dpp(sft: StatefulTractogram, lengths,
     ----------
     sft: StatefulTractogram
     lengths: list
+    prefix_name: str
     attentions_per_line: Tuple. For each attention:
-        List of length nb_streamlines
-        Such as received by unpad_rescale_attention.
-        Each is: List[np.array] of length nb lines
-    attention_name: Tuple[str]
+        List of length nb_streamlines (after unpad_rescale_attention).
+        Each is: List[np.array] of length nb layers.
+ Each attention is of shape: [nb_heads, line_length, line_length] + attention_names: Tuple[str] """ assert len(attentions_per_line[0]) == len(sft), \ ("Expecting attention to be one list per line for {} streamlines, " @@ -37,38 +41,97 @@ def add_attention_as_dpp(sft: StatefulTractogram, lengths, nb_layers = len(attentions_per_line[0][0]) nb_heads = attentions_per_line[0][0][0].shape[0] - remaining_streamlines = sft.streamlines - whole_sft = None - # Converting Tuple to list for easier management attentions_per_line = list(attentions_per_line) + if sft_delta: + _save_sft_delta(sft, prefix_name, attentions_per_line, + attention_names, nb_layers, nb_heads) + + if sft_all_info: + _save_sft_all_info(sft, lengths, prefix_name, attentions_per_line, + attention_names, nb_layers, nb_heads) + + +def _save_sft_delta(sft: StatefulTractogram, prefix_name: str, + attentions_per_line: list, attention_names: Tuple[str], + nb_layers, nb_heads): + + for i, att_type in enumerate(attentions_per_line): + for l in range(nb_layers): + for h in range(nb_heads): + dpp_max = [] + dpp_nb = [] + for s in range(len(sft.streamlines)): + weights = att_type[s][l][h, :, :] + max_ = np.argmax(weights, axis=1, + keepdims=True).astype(float) + + # Currently max_ is the index of the point with maximal + # attention. Changing to be a ratio of the streamline + # length, independent of the length. + max_ /= len(sft.streamlines[s]) + dpp_max.append(max_) + + # Mean, mean weighted, etc.: Do not seem to represent much. + THRESH = 0.5 + nb_points_above_thresh = np.sum(weights > THRESH, axis=1, + keepdims=True).astype(float) + dpp_nb.append(nb_points_above_thresh) + + prefix = attention_names[i] + 'l{}_h{}'.format(l, h) + + dpp = prefix + '_max' + tractogram_name = prefix_name + dpp + '.trk' + sft.data_per_point[dpp] = dpp_max + sft = color_sft_from_dpp(sft, dpp) + print("Saving tractogram {}".format(tractogram_name)) + save_tractogram(sft, tractogram_name) + del sft.data_per_point[dpp] + del sft.data_per_point['color'] + + dpp = prefix + '_nb' + tractogram_name = prefix_name + dpp + '.trk' + sft.data_per_point[dpp] = dpp_nb + sft = color_sft_from_dpp(sft, dpp) + print("Saving tractogram {}".format(tractogram_name)) + save_tractogram(sft, tractogram_name) + del sft.data_per_point[dpp] + del sft.data_per_point['color'] + + +def _save_sft_all_info(sft: StatefulTractogram, lengths, prefix_name: str, + attentions_per_line: list, attention_names: Tuple, + nb_layers, nb_heads): + remaining_streamlines = sft.streamlines + whole_sft = None + # Starting current point at length 2. At length 1, we know that it only # looked at the first point. for current_point in range(2, max(lengths)): # The nth point of each streamline, if long enough - # Removing shorter streamlines from attention - for i, att in enumerate(attentions_per_line): - attentions_per_line[i] = \ - [a for a, s in zip(att, remaining_streamlines) - if len(s) > current_point] + # Removing shorter streamlines from each type of attention + # (encoder, decoder, cross) + for i, att_type in enumerate(attentions_per_line): + attentions_per_line[i] = [line_att for line_att, s in + zip(att_type, remaining_streamlines) + if len(s) > current_point] # Removing shorter streamlines for list of streamlines remaining_streamlines = [s for s in remaining_streamlines if len(s) > current_point] - # Saving first part of streamlines, up to current_point. - # At current_point: which point did we look at? - # Saving many ddp key for these streamlines: per layer, per head. 
+ # Saving first part of streamlines, up to current_point: + # = "At current_point: which point did we look at?" tmp_sft = sft.from_sft([s[0:current_point] for s in remaining_streamlines], sft) - logging.debug("Adding the first {} points of the remaining {} " - "streamlines" - .format(current_point+1, len(remaining_streamlines))) - for layer in range(nb_layers): - for head in range(nb_heads): - for att, name in zip(attentions_per_line, attention_name): + + # Saving many ddp key for these streamlines: per layer, per head. + for att_nb, att_type in enumerate(attentions_per_line): + name = attention_names[att_nb] + for layer in range(nb_layers): + for head in range(nb_heads): # Adding data per point: attention_name_layerX_headX # (Careful. Nibabel force names to be <18 character) # attention_lX_hX @@ -79,13 +142,45 @@ def add_attention_as_dpp(sft: StatefulTractogram, lengths, # Nibabel required data_per_point to have the same number of # dimensions as the streamlines (N, 3) = 2D. Adding a fake # second dimension. - ddp = [a[layer][head, current_point, :current_point][:, None] - for a in att] - tmp_sft.data_per_point[name + suffix] = ddp + dpp = [a[layer][head, current_point, :current_point][:, None] + for a in att_type] + + tmp_sft.data_per_point[name + suffix] = dpp if whole_sft is None: whole_sft = tmp_sft else: whole_sft = whole_sft + tmp_sft - return whole_sft + print(" **The initial {} streamlines were transformed into {} " + "streamlines of \n" + " variable lengths. Color for streamline i of length N is the " + "attention's value \n" + " at each point when deciding the next direction at point N." + .format(len(sft), len(whole_sft.streamlines))) + del sft + del tmp_sft + + dpp_keys = list(whole_sft.data_per_point.keys()) + for key in dpp_keys: + name = prefix_name + '_colored_sft_' + key + '.trk' + # Keep only current key + colored_sft = whole_sft.from_sft( + whole_sft.streamlines, whole_sft, + data_per_point={key: whole_sft.data_per_point[key]}) + colored_sft = color_sft_from_dpp(colored_sft, key) + + print("Saving {} with dpp: {}" + .format(name, list(colored_sft.data_per_point.keys()))) + save_tractogram(colored_sft, name) + + +def color_sft_from_dpp(sft, key): + cmap = get_colormap('jet') + data = np.squeeze(sft.data_per_point[key]._data) + data = data - np.min(data) + data = data / np.max(data) + color = cmap(data)[:, 0:3] * 255 + sft.data_per_point['color'] = sft.streamlines + sft.data_per_point['color']._data = color + return sft diff --git a/dwi_ml/testing/projects/tt_visu_main.py b/dwi_ml/testing/projects/tt_visu_main.py index c6cadc37..ae8d076e 100644 --- a/dwi_ml/testing/projects/tt_visu_main.py +++ b/dwi_ml/testing/projects/tt_visu_main.py @@ -8,32 +8,33 @@ import glob import logging import os -from typing import List import numpy as np import torch -from dipy.io.streamline import save_tractogram from matplotlib import pyplot as plt from scilpy.io.fetcher import get_home as get_scilpy_folder from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_reference_arg) +from scilpy.utils.streamlines import uniformize_bundle_sft -from dwi_ml.io_utils import add_logging_arg, verify_which_model_in_path, add_memory_args, \ - add_arg_existing_experiment_path +from dwi_ml.io_utils import (add_arg_existing_experiment_path, + add_logging_arg, add_memory_args, + verify_which_model_in_path) from dwi_ml.models.projects.transformer_models import ( OriginalTransformerModel, 
TransformerSrcAndTgtModel) from dwi_ml.models.projects.transformers_utils import find_transformer_class from dwi_ml.testing.projects.tt_visu_bertviz import ( encoder_decoder_show_head_view, encoder_decoder_show_model_view, encoder_show_model_view, encoder_show_head_view) -from dwi_ml.testing.testers import TesterOneInput -from dwi_ml.testing.projects.tt_visu_colored_sft import add_attention_as_dpp +from dwi_ml.testing.projects.tt_visu_colored_sft import ( + save_sft_with_attention_as_dpp) from dwi_ml.testing.projects.tt_visu_matrix import show_model_view_as_imshow from dwi_ml.testing.projects.tt_visu_utils import ( prepare_encoder_tokens, prepare_decoder_tokens, - reshape_attention_to4d_tocpu, unpad_rescale_attention, resample_attention_one_line) + reshape_unpad_rescale_attention) +from dwi_ml.testing.testers import TesterOneInput from dwi_ml.testing.utils import add_args_testing_subj_hdf5 @@ -78,7 +79,7 @@ def build_argparser_transformer_visu(): '--out_prefix', metavar='name', help="Prefix of the all output files. Do not include a path. " "Suffixes are: \n" - " 1) 'as_matrix': tt_matrix_[encoder|decoder|cross].png.\n" + " 1) 'as_matrices': tt_matrix_[encoder|decoder|cross].png.\n" " 2) 'bertviz': tt_bertviz.html, tt_bertviz.ipynb, " "tt_bertviz.config.\n" " 3) 'colored_sft': colored_sft.trk." @@ -90,9 +91,9 @@ def build_argparser_transformer_visu(): p.add_argument( '--visu_type', required=True, nargs='+', - choices=['as_matrix', 'bertviz', 'colored_sft', 'bertviz_locally'], + choices=['as_matrices', 'bertviz', 'colored_sft', 'bertviz_locally'], help="Output option. Choose any number (at least one). \n" - " 1) 'as_matrix': Show attention as matrices. \n" + " 1) 'as_matrices': Show attention as matrices. \n" " If bertviz is also chosen, matrices will show in the " "html.\n" " 2) 'bertviz': Show using bertviz head_view visualization.\n" @@ -108,24 +109,35 @@ def build_argparser_transformer_visu(): help="If true, rescale to max 1 per row.") g = p.add_mutually_exclusive_group() - g.add_argument('--align_endpoints', action='store_true', + g.add_argument('--uniformize_endpoints', action='store_true', help="If set, align endpoints of the batch. Either this " - "or --inverse_align_endpoints. \nProbably helps" + "or --inverse_uniformize_endpoints. \nProbably helps" "visualisation with option --visu_type 'colored_sft'.") - g.add_argument('--inverse_align_endpoints', action='store_true', + g.add_argument('--inverse_uniformize_endpoints', action='store_true', help="If set, aligns endpoints and then reverses the " "bundle.") - p.add_argument('--resample_attention', type=int, + g.add_argument('--flip', action='store_true', + help="If set, flip all streamlines.") + p.add_argument('--axis', choices=['x', 'y', 'z'], + help='When uniformizing endpoints, match endpoints of the ' + 'streamlines along this axis. If not set: discover ' + 'the best axis (auto).' + '\nSUGGESTION: Commissural = x, Association = y, ' + 'Projection = z') + + p.add_argument('--resample_attention', type=int, metavar='nb', help="Streamline will be sampled as decided by the model." "However, \nattention will be resampled to fit better " "in the html page. 
\n (Resampling is done by " "averaging the attention every N points).\n" - "(only for bertviz and as_matrix") + "(only for bertviz and as_matrices") p.add_argument('--average_heads', action='store_true', help="If true, resample all heads (per layer per " "attention type).") - p.add_argument('--batch_size', type=int) + p.add_argument('--batch_size', type=int, metavar='n', + help="Batch size in number of streamlines. If not set, " + "uses all streamlines in one batch.") add_memory_args(p) p.add_argument('--show_now', action='store_true', @@ -150,15 +162,15 @@ def tt_visualize_weights_main(args, parser): show_as_matrices = False if 'colored_sft' in args.visu_type: save_colored_sft = True - if 'as_matrix' in args.visu_type: + if 'as_matrices' in args.visu_type: show_as_matrices = True if 'bertviz' in args.visu_type or 'bertviz_locally' in args.visu_type: run_bertviz = True if save_colored_sft and not (show_as_matrices or run_bertviz) and \ - args.resample_attention_one_line is not None: + args.resample_attention is not None: logging.warning("We only resample attention when visualizing matrices " - "or bertviz. Not required with colored_sft.") + "or bertviz. Not required with colored_sft. Ignoring.") # -------- Verify inputs and outputs assert_inputs_exist(parser, [args.hdf5_file, args.in_sft]) @@ -168,18 +180,23 @@ def tt_visualize_weights_main(args, parser): # Out files: jupyter stuff already managed in main script. Remains the sft. args = set_out_dir_visu_weights_and_create_if_not_exists(args) out_files = [] - out_sft = None prefix_total = os.path.join(args.out_dir, args.out_prefix) if save_colored_sft: - out_sft = prefix_total + '_colored_sft.trk' - out_files.append(out_sft) + # Total sft names will be, ex: + # prefix_total + _colored_sft_encoder_layerX_headX.trk + any_existing = glob.glob(prefix_total + '_colored_sft_*.trk') + out_files.extend(any_existing) if show_as_matrices: # Total matrices names will be, ex: - # prefix_total + encoder_matrix_layerX.png - any_existing = glob.glob(args.out_dir + '/*_matrix_layer*.png') + # prefix_total + _matrix_encoder_layerX_headX.png + any_existing = glob.glob(prefix_total + '_matrix_*.png') out_files.extend(any_existing) assert_outputs_exist(parser, args, out_files) + if args.overwrite: + for f in out_files: + if os.path.isfile(f): + os.remove(f) sub_logger_level = 'WARNING' logging.getLogger().setLevel(level=args.logging) @@ -195,12 +212,6 @@ def tt_visualize_weights_main(args, parser): else: device = torch.device('cpu') - # Load SFT - logging.info("Loading analysed bundle. Note that space comptability " - "with training data will NOT be verified.") - sft = load_tractogram_with_reference(parser, args, args.in_sft) - logging.debug(" Got {} streamlines.".format(len(sft))) - # Load model logging.debug("Loading the model") model_dir = os.path.join(args.experiment_path, 'best_model') @@ -210,6 +221,16 @@ def tt_visualize_weights_main(args, parser): model = model_cls.load_model_from_params_and_state( model_dir, log_level=sub_logger_level) + # Load SFT + logging.info("Loading analysed bundle. 
Note that space compatibility "
+                 "with training data will NOT be verified.")
+    sft = load_tractogram_with_reference(parser, args, args.in_sft)
+    sft.to_vox()
+    sft.to_corner()
+
+    logging.debug("        Got {} streamlines.".format(len(sft)))
+
     if len(sft) > 1 and not save_colored_sft:
         # Taking only one streamline
         line_id = np.random.randint(0, len(sft), size=1)[0]
@@ -218,11 +239,12 @@
                       .format(line_id, len(sft)))
         sft = sft[[line_id]]
 
-    if args.align_endpoints:
-        raise NotImplementedError
-
-    if args.inverse_align_endpoints:
-        raise NotImplementedError
+    if args.uniformize_endpoints or args.inverse_uniformize_endpoints:
+        # Done in-place
+        uniformize_bundle_sft(sft, args.axis,
+                              swap=args.inverse_uniformize_endpoints)
+    elif args.flip:
+        sft.streamlines = [np.flip(line, axis=0) for line in sft.streamlines]
 
     logging.debug("Loading the data...")
     tester = TesterOneInput(
@@ -234,9 +256,10 @@
     model.set_context('visu_weights')
     sft, outputs, _, _ = tester.run_model_on_sft(sft)
+    # Resulting weights is a tuple of one list per attention type.
+    # Each list is: one tensor per layer.
     _, weights = outputs
 
-    logging.info("Preparing visu...")
     if isinstance(model, OriginalTransformerModel):
         visu_fct = visu_encoder_decoder
     elif isinstance(model, TransformerSrcAndTgtModel):
@@ -244,34 +267,21 @@
     else:  # TransformerSrcOnlyModel
         visu_fct = visu_encoder_only
 
-    visu_fct(weights, sft, args.resample_attention, args.rescale,
-             model.direction_getter.add_eos, save_colored_sft,
-             run_bertviz, show_as_matrices, colored_sft_name=out_sft,
-             matrices_prefix=prefix_total)
+    visu_fct(weights, sft, model.direction_getter.add_eos,
+             args.average_heads, args.resample_attention, args.rescale,
+             save_colored_sft, run_bertviz, show_as_matrices, prefix_total)
 
     if args.show_now:
         plt.show()
 
 
 def visu_encoder_decoder(
-        weights, sft, resample_nb: int, rescale: bool, has_eos: bool,
+        weights, sft, has_eos: bool,
+        average_heads: bool, resample_nb: int, rescale: bool,
         save_colored_sft: bool, run_bertviz: bool, show_as_matrices: bool,
-        colored_sft_name: str, matrices_prefix: str):
+        prefix_name: str):
     """
     Visualizing the 3 attentions.
- - Parameters - ---------- - weights: Tuple of length 3 - sft: StatefulTractogram - resample_nb: int - rescale: bool - has_eos: bool - save_colored_sft: bool - run_bertviz: bool - show_as_matrices: bool - colored_sft_name: str, or None if not save_colored_sft - matrices_prefix: str, or None if not show_as_matrices """ encoder_attention, decoder_attention, cross_attention = weights @@ -281,35 +291,21 @@ def visu_encoder_decoder( sft.streamlines = [s[:-1, :] for s in sft.streamlines] lengths = [len(s) for s in sft.streamlines] - encoder_attention = reshape_attention_to4d_tocpu(encoder_attention) - decoder_attention = reshape_attention_to4d_tocpu(decoder_attention) - cross_attention = reshape_attention_to4d_tocpu(cross_attention) - - encoder_attention = unpad_rescale_attention(encoder_attention, lengths, - rescale) - decoder_attention = unpad_rescale_attention(decoder_attention, lengths, - rescale) - cross_attention = unpad_rescale_attention(cross_attention, lengths, - rescale) + encoder_attention = reshape_unpad_rescale_attention( + encoder_attention, average_heads, lengths, rescale) + decoder_attention = reshape_unpad_rescale_attention( + decoder_attention, average_heads, lengths, rescale) + cross_attention = reshape_unpad_rescale_attention( + cross_attention, average_heads, lengths, rescale) if save_colored_sft: print("\n\n-------------- Preparing the data_per_point to color sft " "--------------") - colored_sft = add_attention_as_dpp( - sft, lengths, + save_sft_with_attention_as_dpp( + sft, lengths, prefix_name, (encoder_attention, decoder_attention, cross_attention), ('encoder', 'decoder', 'cross')) - print(" **Saving tractogram with color as ddp.** \n" - "The initial {} streamlines were transformed into {} " - "streamlines of \nvariable length. Color for streamline i of " - "length N is the attention's\n value at each point when " - "deciding the next direction at point N." - .format(len(sft), len(colored_sft))) - save_tractogram(colored_sft, colored_sft_name) - - del colored_sft - if run_bertviz or show_as_matrices: print("\n\n-------------- Preparing the attention as a matrix for one " "streamline --------------") @@ -346,35 +342,26 @@ def visu_encoder_decoder( if show_as_matrices: print("ENCODER ATTENTION: ") show_model_view_as_imshow(encoder_attention, - matrices_prefix + '_encoder_matrix', + prefix_name + '_matrix_encoder', encoder_tokens, encoder_tokens) print("DECODER ATTENTION: ") show_model_view_as_imshow(decoder_attention, - matrices_prefix + '_decoder_matrix', + prefix_name + '_matrix_decoder', decoder_tokens, decoder_tokens) print("CROSS ATTENTION: ") show_model_view_as_imshow(cross_attention, - matrices_prefix + 'cross_attention_matrix', + prefix_name + '_matrix_cross_attention', encoder_tokens, decoder_tokens) def visu_encoder_only( - weights, sft, resample_nb: int, rescale: bool, has_eos: bool, + weights, sft, has_eos: bool, + average_heads: bool, resample_nb: int, rescale: bool, save_colored_sft: bool, run_bertviz: bool, show_as_matrices: bool, - colored_sft_name: str, matrices_prefix: str): + prefix_name: str): """ Visualizing one attention. 
- Parameters - ---------- - weights: Tuple of length 1 - sft: StatefulTractogram - resample_nb: int - rescale: bool - has_eos: bool - save_colored_sft: bool - run_bertviz: bool - show_as_matrices: bool colored_sft_name: str, or None if not save_colored_sft matrices_prefix: str, or None if not show_as_matrices """ @@ -392,25 +379,19 @@ def visu_encoder_only( sft.streamlines = [s[:-1, :] for s in sft.streamlines] lengths = [len(s) for s in sft.streamlines] - encoder_attention = reshape_attention_to4d_tocpu(encoder_attention) - encoder_attention = unpad_rescale_attention(encoder_attention, lengths, - rescale) + # Reshaping all + encoder_attention = reshape_unpad_rescale_attention( + encoder_attention, average_heads, lengths, rescale, resample_nb) + if resample_nb: + sft = resample_streamlines_num_points(sft, num_points=resample_nb) + # Right now attention = list per streamline. + # (of a list per layer) if save_colored_sft: print("\n\n-------------- Preparing the data_per_point to color sft " "--------------") - colored_sft = add_attention_as_dpp(sft, lengths, (encoder_attention,), - ('encoder',)) - - print(" **Saving tractogram with color as ddp.** \n" - "The initial {} streamlines were transformed into {} " - "streamlines of \nvariable length. Color for streamline i of " - "length N is the attention's\n value at each point when " - "deciding the next direction at point N." - .format(len(sft), len(colored_sft))) - save_tractogram(colored_sft, colored_sft_name) - - del colored_sft + save_sft_with_attention_as_dpp(sft, lengths, prefix_name, + (encoder_attention,), ('encoder',)) if run_bertviz or show_as_matrices: print("\n\n-------------- Preparing the attention as a matrix for one " @@ -424,17 +405,20 @@ def visu_encoder_only( encoder_attention = encoder_attention[0] this_seq_len = lengths[0] - encoder_attention, inds = resample_attention_one_line( - encoder_attention, this_seq_len, resample_nb=resample_nb) - encoder_tokens = prepare_encoder_tokens(this_seq_len, has_eos, inds) - - if run_bertviz: - encoder_show_head_view(encoder_attention, encoder_tokens) - encoder_show_model_view(encoder_attention, encoder_tokens) + encoder_tokens = prepare_encoder_tokens(this_seq_len, has_eos) if show_as_matrices: print("ENCODER ATTENTION: ") show_model_view_as_imshow(encoder_attention, - matrices_prefix + '_encoder_matrix', + prefix_name + '_matrix_encoder', encoder_tokens) + + # Sending to 4D torch for Bertviz (each layer) + encoder_attention = [torch.as_tensor(att)[None, :, :, :] + for att in encoder_attention] + + if run_bertviz: + encoder_show_head_view(encoder_attention, encoder_tokens) + encoder_show_model_view(encoder_attention, encoder_tokens) + diff --git a/dwi_ml/testing/projects/tt_visu_matrix.py b/dwi_ml/testing/projects/tt_visu_matrix.py index f5000fd6..03427753 100644 --- a/dwi_ml/testing/projects/tt_visu_matrix.py +++ b/dwi_ml/testing/projects/tt_visu_matrix.py @@ -5,10 +5,10 @@ from matplotlib import pyplot as plt -def show_model_view_as_imshow(attention, fig_prefix, tokens_x, tokens_y=None): - torch.set_printoptions(precision=2, sci_mode=False, linewidth=150) +def show_model_view_as_imshow(attention_one_line, fig_prefix, + tokens_x, tokens_y=None): - nb_layers = len(attention) + nb_layers = len(attention_one_line) size_x = len(tokens_x) if tokens_y is None: @@ -19,15 +19,19 @@ def show_model_view_as_imshow(attention, fig_prefix, tokens_x, tokens_y=None): cmap.set_bad(color='black') for i in range(nb_layers): - att = attention[i].numpy() - nb_heads = att.shape[1] - fig, axs = 
plt.subplots(1, nb_heads, figsize=(20, 8)) + att = attention_one_line[i] + nb_heads = att.shape[0] + fig, axs = plt.subplots(1, nb_heads, figsize=(20, 8), + layout='compressed') + if nb_heads == 1: + axs = [axs] for h in range(nb_heads): - a = np.squeeze(att[0, h, :, :]) + a = np.squeeze(att[h, :, :]) a = np.ma.masked_where(a == 0, a) im = axs[h].imshow(a, interpolation='None') - axs[h].set_title("Head {}".format(h)) + if nb_heads > 1: + axs[h].set_title("Head {}".format(h)) axs[h].set_xticks(np.arange(size_x), fontsize=10) axs[h].set_yticks(np.arange(size_y), fontsize=10) axs[h].set_xticklabels(tokens_x, rotation=-90) @@ -42,7 +46,6 @@ def show_model_view_as_imshow(attention, fig_prefix, tokens_x, tokens_y=None): "directions?\n" "DATA IS NORMALIZED TO [0-1] RANGE PER ROW".format(i)) - fig.tight_layout() name = fig_prefix + '_layer{}.png'.format(i) print("Saving matrix : {}".format(name)) plt.savefig(name) diff --git a/dwi_ml/testing/projects/tt_visu_utils.py b/dwi_ml/testing/projects/tt_visu_utils.py index b17c202c..efceedf2 100644 --- a/dwi_ml/testing/projects/tt_visu_utils.py +++ b/dwi_ml/testing/projects/tt_visu_utils.py @@ -3,91 +3,94 @@ from typing import List, Tuple import numpy as np -from skimage.measure import block_reduce import torch +from scipy.ndimage import zoom +from tqdm import tqdm -def reshape_attention_to4d_tocpu(attention): +def reshape_unpad_rescale_attention(attention_per_layer, average_heads: bool, + lengths, rescale, resample_nb): """ Also sends to CPU. Parameters ---------- - attention: List[Tensor] + attention_per_layer: List[Tensor] Attention such as received directly from the Transformer. A list of len nb_layers with tensors: [nb_streamlines, batch_max_len, batch_max_len] --> If averaged heads [nb_streamlines, nheads, batch_max_len, batch_max_len] --> Else. - - Returns - ------- - attention: List[Tensor] - A list of len nb_layers with tensors: - [nb_streamlines, nheads, batch_max_len, batch_max_len] - Where nheads=1 if average_head. - """ - for ll in range(len(attention)): # Per layer - if len(attention[ll].shape) == 3: - # No head dimension if heads were averaged. Bertviz requires 4D. - attention[ll] = attention[ll][:, None, :, :] - attention[ll] = attention[ll].cpu() - - return attention - - -def unpad_rescale_attention(attention, lengths, rescale): - """ - Reformats the attention to have always the same dimension, regardless of - the model. Unpads the result. Possibly, downsample the attention for nicer - visualisation. - - Parameters - ---------- - attention: List - Attention after running reshape_attention. - List of len nb_layers of tensors: - [nb_streamlines, nb_head, max_length, max_length] + average_heads: bool + If true, average heads. lengths: List[int] Unpadded lengths of the streamlines. rescale: bool + resample_nb: int or None Returns ------- - attention: List[List[np.array]] - Length: [nb_streamlines x [nb_layers x array]] - Arrays are of shape [nb_heads, this_s_len, this_s_len] - (nb_heads = 1 if average_heads). + attention: List[np.ndarray] + A list of len nb_streamlines with, each: + A list of len nb_layers of np.ndarray of shape: + [nheads, batch_max_len, batch_max_len] + Where nheads=1 if average_head. """ - assert attention[0].shape[0] == len(lengths), \ - ("Expecting attention to be, for each layer, a tensor of shape " - "[nb_streamlines, ...] 
but got shape 0={} (expected {})" - .format(attention[0].shape[0], len(lengths))) + nb_layers = len(attention_per_layer) + nb_streamlines = len(lengths) if rescale: logging.info("We will normalize the attention: per row, to the range " - "[0, 1]: \nthe attention when deciding the next " - "direction at point N is distributed in the N first " - "points of the streamline such \nthat the point with " - "most attention has value 1. (att = att/max)") - - nb_layers = len(attention) - - # A list of (one 4D attention per layer) per line + "[0, 1]: \n" + " The attention when deciding the next direction at " + "point N is \n" + " distributed in the N first points of the streamline " + "such that \n" + " the point with most attention has value 1. " + "(att = att/max)") + + # 1. Rearrange attention per layer to 4D + for ll in range(nb_layers): + # To numpy arrays + attention_per_layer[ll] = attention_per_layer[ll].cpu().numpy() + + # Averaging heads (but keeping 4D). + if average_heads: + attention_per_layer[ll] = np.mean(attention_per_layer[ll], + axis=1, keepdims=True) + + assert attention_per_layer[ll].shape[0] == nb_streamlines, \ + ("Expecting attention to be, for each layer, a tensor of shape " + "[nb_streamlines={}, nb_heads, max_len, max_len] but got " + "shape[0] = {}." + .format(len(lengths), attention_per_layer[0].shape[0])) + + nb_heads = attention_per_layer[0].shape[1] + + # 2. Rearrange attention into one list per line, unpadded, rescaled. attention_per_line = [] - for line in range(len(lengths)): + for line in tqdm(range(len(lengths)), total=len(lengths), + desc="Rearranging, unpadding, rescaling (if asked)", + maxinterval=3): this_seq_len = lengths[line] - # Unpad, rescale attention_per_line.append([None] * nb_layers) for i in range(nb_layers): - # Easier to work with numpy. Will put back to tensor after. - att = attention[i].numpy() - assert len(att.shape) == 4 - # 1. Unpadding. Taking one streamline. - att = att[line, :, 0:this_seq_len, 0:this_seq_len] + att = attention_per_layer[i][line, :, 0:this_seq_len, 0:this_seq_len] + + # 2. Resampling + if resample_nb and this_seq_len > resample_nb: + new_att = np.zeros((nb_heads, resample_nb, resample_nb)) + for h in range(nb_heads): + ratio = resample_nb / this_seq_len + result = zoom(att[h, :, :], zoom=ratio, + order=3, mode='nearest') + + # Verifying that the future is still 0. + result = np.tril(result) + new_att[h, :, :] = result - # 2. Normalizing weight. Without it, we rapidly see nothing! + # 3. Normalizing weight. Without it, we rapidly see nothing! # Easier to see when we normalize on the x axis. # Normalizing each row so that max value = 1. if rescale: @@ -99,89 +102,6 @@ def unpad_rescale_attention(attention, lengths, rescale): return attention_per_line -def _verify_resampling(resample_nb, this_seq_len): - ind = None - this_resample_nb = resample_nb - nb_together = None - - if resample_nb < this_seq_len: - tmp = this_seq_len / resample_nb - nb_together = np.round(tmp) - real_resample_attention = int(np.ceil(this_seq_len / nb_together)) - nb_together = int(nb_together) - - if tmp < nb_together: - alt = np.floor(tmp) - else: - alt = np.ceil(tmp) - logging.debug( - "You asked to resample the attention from {} to {}.\n" - " --> By combining every {} points: matrix of size {} " - "(chosen)\n" - " --> By combining every {} points: matrix of size {}.\n" - " (We have not yet implemented an irregular resampling " - "of the attention.) 
" - .format(this_seq_len, resample_nb, nb_together, - real_resample_attention, int(alt), - int(np.ceil(this_seq_len / alt)))) - - if nb_together > 1: - ind1 = np.arange(0, this_seq_len, nb_together) - ind = [(i, min(i + nb_together, this_seq_len - 1)) - for i in ind1] - else: - this_resample_nb = this_seq_len - - return this_resample_nb, ind, nb_together - - -def resample_attention_one_line(att, this_seq_len, - resample_nb: int = None): - """ - Parameters - ---------- - att: - Such as received by unpad_rescale_attention, one line only. - this_seq_len: int - Unpadded lengths of the streamlines. - resample_nb: int - The final number of points of the attention. - """ - assert isinstance(att[0], np.ndarray), \ - ("Expecting attention to be a list, per layer, of np.ndarray, got {}" - .format(type(att[0]))) - assert att[0].shape[1] == this_seq_len, \ - ("Expecting attention to be unpadded. For each layer, should be " - "of shape [nb_heads, seq_len, seq_len], but got shape[1] = {} " - "(expecting {})".format(att[0].shape[1], this_seq_len)) - - if resample_nb is None: - resample_nb = 100000 - - nb_layers = len(att) - - # 1. Verifying if we need to resample for this streamline - this_resample_nb, inds, nb_together = _verify_resampling( - resample_nb, this_seq_len) - - # 2. Resample - for i in range(nb_layers): - if this_resample_nb < this_seq_len: - # No option to pad to edge value in block_reduce. Doing manually. - missing = (len(inds) * nb_together) - att[i].shape[2] - att[i] = np.pad(att[i], ((0, 0), (0, missing), (0, missing)), - mode='edge') - - att[i] = block_reduce( - att[i], block_size=(1, nb_together, nb_together), - func=np.max, cval=1000.0) # 1000: to see if bug. - - # Sending to 4D torch for Bertviz - att[i] = torch.as_tensor(att[i])[None, :, :, :] - - return att, inds - - def prepare_decoder_tokens(this_seq_len, ind: List[Tuple]): if ind is not None: @@ -196,17 +116,12 @@ def prepare_decoder_tokens(this_seq_len, ind: List[Tuple]): return decoder_tokens -def prepare_encoder_tokens(this_seq_len, add_eos: bool, ind: List[Tuple]): +def prepare_encoder_tokens(this_seq_len, add_eos: bool): # If encoder = concat X | Y, then , point0 = point0 | SOS # point1 = point1 | dir0 # etc. But ok. We will understand. 
- if ind is not None: - # Used resample_attention - encoder_tokens = ['points {}-{}'.format(i[0], i[1] - 1) - for i in ind] - else: - encoder_tokens = ['point {}'.format(i) for i in range(this_seq_len)] + encoder_tokens = ['point {}'.format(i) for i in range(this_seq_len)] if add_eos: encoder_tokens[-1] += '(SOS)' diff --git a/dwi_ml/testing/projects/tt_visualize_weights.ipynb b/dwi_ml/testing/projects/tt_visualize_weights.ipynb index f4e497b6..4e9a39de 100644 --- a/dwi_ml/testing/projects/tt_visualize_weights.ipynb +++ b/dwi_ml/testing/projects/tt_visualize_weights.ipynb @@ -58,7 +58,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"RUNNING WEIGHTS VISUALIZATION FOR TRANSFORMING TRACTOGRAPHY'S ORIGINAL MODEL\")\n", + "print(\"RUNNING WEIGHTS VISUALIZATION FOR TRANSFORMING TRACTOGRAPHY'S MODEL\")\n", "parser = build_argparser_transformer_visu()\n", "args = parser.parse_args()\n", "tt_visualize_weights_main(args, parser)" diff --git a/dwi_ml/testing/testers.py b/dwi_ml/testing/testers.py index d871b5c1..f6eb0e56 100644 --- a/dwi_ml/testing/testers.py +++ b/dwi_ml/testing/testers.py @@ -6,7 +6,8 @@ from dwi_ml.data.processing.streamlines.data_augmentation import \ resample_or_compress -from dwi_ml.models.main_models import MainModelOneInput, MainModelAbstract, ModelWithDirectionGetter +from dwi_ml.models.main_models import (MainModelOneInput, MainModelAbstract, + ModelWithDirectionGetter) from dwi_ml.testing.utils import prepare_dataset_one_subj logger = logging.getLogger('tester_logger') @@ -87,11 +88,13 @@ def run_model_on_sft(self, sft, compute_loss=False): Returns ------- - outputs: - - Gaussian model: outputs = ([], []) - - Fisher Von mises: not Implemented - - Other: [] + sft: StatefulTractogram + The tractogram, formatted as required by your model: + to_vox, to_corner, possibly resampled or compressed. + outputs: Any + Your model output. losses: + """ sft = resample_or_compress(sft, self.model.step_size, self.model.compress_lines) diff --git a/scripts_python/tt_visualize_weights.py b/scripts_python/tt_visualize_weights.py index 14488bbe..8da9a748 100644 --- a/scripts_python/tt_visualize_weights.py +++ b/scripts_python/tt_visualize_weights.py @@ -9,7 +9,8 @@ from dwi_ml.testing.projects.tt_visu_main import ( build_argparser_transformer_visu, get_config_filename, - tt_visualize_weights_main, set_out_dir_visu_weights_and_create_if_not_exists) + set_out_dir_visu_weights_and_create_if_not_exists, + tt_visualize_weights_main) # Note. To use through jupyter, the file