From 5149706fa48aa6382c7d7a09d5a01b496686da04 Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Fri, 20 Sep 2024 12:51:58 -0400 Subject: [PATCH 01/28] Latent space visualization integration --- dwi_ml/models/projects/ae_models.py | 34 ++-- dwi_ml/viz/latent_streamlines.py | 183 +++++++++++++++++++++ scripts_python/ae_visualize_bundles.py | 106 ++++++++++++ scripts_python/ae_visualize_streamlines.py | 44 ++--- 4 files changed, 322 insertions(+), 45 deletions(-) create mode 100644 dwi_ml/viz/latent_streamlines.py create mode 100644 scripts_python/ae_visualize_bundles.py diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 0818b6ad..50aace59 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -31,6 +31,7 @@ def __init__(self, kernel_size, latent_space_dims, self.latent_space_dims = latent_space_dims self.pad = torch.nn.ReflectionPad1d(1) + self.post_encoding_hooks = [] def pre_pad(m): return torch.nn.Sequential(self.pad, m) @@ -137,13 +138,20 @@ def forward(self, `get_tracking_directions()`. """ - x = self.decode(self.encode(input_streamlines)) + encoded = self.encode(input_streamlines) + + for hook in self.post_encoding_hooks: + hook(encoded) + + x = self.decode(encoded) return x def encode(self, x): - # x: list of tensors - x = torch.stack(x) - x = torch.swapaxes(x, 1, 2) + # X input shape is (batch_size, nb_points, 3) + if isinstance(x, list): + x = torch.stack(x) + + x = torch.swapaxes(x, 1, 2) # input of the network should be (N, 3, nb_points) h1 = F.relu(self.encod_conv1(x)) h2 = F.relu(self.encod_conv2(h1)) @@ -181,21 +189,13 @@ def decode(self, z): return h11 def compute_loss(self, model_outputs, targets, average_results=True): - print("COMPARISON\n") targets = torch.stack(targets) targets = torch.swapaxes(targets, 1, 2) - print(targets[0, :, 0:5]) - print(model_outputs[0, :, 0:5]) - reconstruction_loss = torch.nn.MSELoss(reduction="sum") + reconstruction_loss = torch.nn.MSELoss() mse = reconstruction_loss(model_outputs, targets) - # loss_function_vae - # See Appendix B from VAE paper: - # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 - # https://arxiv.org/abs/1312.6114 - # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) - # kld = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) - # kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) - # kld = torch.sum(kld_element).__mul__(-0.5) - return mse, 1 + + def register_hook_post_encoding(self, hook): + self.post_encoding_hooks.append(hook) + diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py new file mode 100644 index 00000000..297bced6 --- /dev/null +++ b/dwi_ml/viz/latent_streamlines.py @@ -0,0 +1,183 @@ +import logging + +from typing import Union, List, Tuple +from sklearn.manifold import TSNE +import numpy as np +import torch + +import matplotlib.pyplot as plt + +def plot_latent_streamlines( + encoded_streamlines: Union[np.ndarray, torch.Tensor], + save_path: str = None, + fig_size: Union[List, Tuple] = None, + random_state: int = 42, + max_subset_size: int = None + ): + """ + Projects and plots the latent space representation + of the streamlines using t-SNE dimensionality reduction. + + Parameters + ---------- + encoded_streamlines: Union[np.ndarray, torch.Tensor] + Latent space streamlines to plot of shape (N, latent_space_dim). + save_path: str + Path to save the figure. If not specified, the figure will be shown. + fig_size: List[int] or Tuple[int] + 2-valued figure size (x, y) + random_state: int + Random state for t-SNE. + max_subset_size: int: + In case of performance issues, you can limit the number of streamlines to plot. + """ + + if isinstance(encoded_streamlines, torch.Tensor): + latent_space_streamlines = encoded_streamlines.cpu().numpy() + else: + latent_space_streamlines = encoded_streamlines + + if max_subset_size is not None: + if not (max_subset_size > 0): + raise ValueError("A max_subset_size of an integer value greater than 0 is required.") + + # Only sample if we need to reduce the number of latent streamlines + # to show on the plot. + if (len(latent_space_streamlines) > max_subset_size): + sample_indices = np.random.choice(len(latent_space_streamlines), max_subset_size, replace=False) + latent_space_streamlines = latent_space_streamlines[sample_indices] # (max_subset_size, 2) + + # Project the data into 2 dimensions. + tsne = TSNE(n_components=2, random_state=random_state) + X_tsne = tsne.fit_transform(latent_space_streamlines) # Output (N, 2) + + + logging.info("New figure for t-SNE visualisation.") + fig, ax = plt.subplots() + if fig_size is not None: + fig.set_figheight(fig_size[0]) + fig.set_figwidth(fig_size[1]) + + ax.scatter(X_tsne[:, 0], X_tsne[:, 1], alpha=0.9, edgecolors='black', linewidths=0.5) + + if save_path is not None: + fig.savefig(save_path) + else: + plt.show() + + +class BundlesLatentSpaceVisualizer(object): + """ + Utility class that wraps a t-SNE projection of the latent space for multiple bundles. + The usage of this class is intented as follows: + 1. Create an instance of this class, + 2. Add the latent space streamlines for each bundle using "add_data_to_plot" + with its corresponding label. + 3. Fit and plot the t-SNE projection using the "plot" method. + + t-SNE projection can only leverage the fit_transform() with all the data that needs to + be projected at the same time since it aims to preserve the local structure of the data. + """ + def __init__(self, + save_path: str = None, + fig_size: Union[List, Tuple] = None, + random_state: int = 42, + max_subset_size: int = None + ): + """ + Parameters + ---------- + save_path: str + Path to save the figure. If not specified, the figure will be shown. + fig_size: List[int] or Tuple[int] + 2-valued figure size (x, y) + random_state: List + Random state for t-SNE. + max_subset_size: + In case of performance issues, you can limit the number of streamlines to plot + for each bundle. + """ + self.save_path = save_path + self.fig_size = fig_size + self.random_state = random_state + self.max_subset_size = max_subset_size + + self.tsne = TSNE(n_components=2, random_state=self.random_state) + self.bundles = {} + + + def add_data_to_plot(self, data: np.ndarray, label: str = '_'): + """ + Add unprojected data (no t-SNE, no PCA, etc.). + This should be directly the output of the model as a numpy array. + + Parameters + ---------- + data: str + Unprojected latent space streamlines (N, latent_space_dim). + label: str + Name of the bundle. Used for the legend. + """ + if isinstance(data, torch.Tensor): + latent_space_streamlines = data.cpu().numpy() + else: + latent_space_streamlines = data + + if self.max_subset_size is not None: + if not (self.max_subset_size > 0): + raise ValueError("A max_subset_size of an integer value greater than 0 is required.") + + # Only sample if we need to reduce the number of latent streamlines + # to show on the plot. + if (len(latent_space_streamlines) > self.max_subset_size): + sample_indices = np.random.choice(len(latent_space_streamlines), self.max_subset_size, replace=False) + latent_space_streamlines = latent_space_streamlines[sample_indices] # (max_subset_size, 2) + + self.bundles[label] = latent_space_streamlines + + def plot(self): + """ + Fit and plot the t-SNE projection of the latent space streamlines. + This should be called once after adding all the data to plot using "add_data_to_plot". + """ + nb_streamlines = sum(b.shape[0] for b in self.bundles.values()) + logging.info("Plotting a total of {} streamlines".format(nb_streamlines)) + + bundles_indices = {} + current_start = 0 + for (bname, bdata) in self.bundles.items(): + bundles_indices[bname] = np.arange(current_start, current_start + bdata.shape[0]) + current_start += bdata.shape[0] + + assert current_start == nb_streamlines + + all_streamlines = np.concatenate(list(self.bundles.values()), axis=0) + + logging.info("Fitting TSNE projection.") + all_projected_streamlines = self.tsne.fit_transform(all_streamlines) + + logging.info("New figure for t-SNE visualisation.") + fig, ax = plt.subplots() + if self.fig_size is not None: + fig.set_figheight(self.fig_size[0]) + fig.set_figwidth(self.fig_size[1]) + + for (bname, bdata) in self.bundles.items(): + bindices = bundles_indices[bname] + proj_data = all_projected_streamlines[bindices] + ax.scatter( + proj_data[:, 0], + proj_data[:, 1], + label=bname, + alpha=0.9, + edgecolors='black', + linewidths=0.5, + ) + + ax.legend() + + if self.save_path is not None: + fig.savefig(self.save_path) + else: + plt.show() + diff --git a/scripts_python/ae_visualize_bundles.py b/scripts_python/ae_visualize_bundles.py new file mode 100644 index 00000000..c4b7c0a4 --- /dev/null +++ b/scripts_python/ae_visualize_bundles.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import argparse +import logging +import pathlib +import torch +import numpy as np +from glob import glob +from os.path import expanduser +from dipy.tracking.streamline import set_number_of_points + +from scilpy.io.utils import (add_overwrite_arg, + assert_outputs_exist, + add_reference_arg, + add_verbose_arg) +from scilpy.io.streamlines import load_tractogram_with_reference +from dwi_ml.io_utils import (add_arg_existing_experiment_path, + add_memory_args) +from dwi_ml.models.projects.ae_models import ModelAE +from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer + + +def _build_arg_parser(): + p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, + description=__doc__) + # Mandatory + # Should only be False for debugging tests. + add_arg_existing_experiment_path(p) + # Add_args_testing_subj_hdf5(p) + + p.add_argument('in_bundles', + help="The 'glob' path to several bundles identified by their file name." + "e.g. FiberCupGroundTruth_filtered_bundle_0.tck") + + # Options + p.add_argument('--batch_size', type=int) + add_memory_args(p) + + p.add_argument('--pick_at_random', action='store_true') + add_reference_arg(p) + add_overwrite_arg(p) + add_verbose_arg(p) + return p + +def load_bundles(p, args, files_list: list): + bundles = [] + for bundle_file in files_list: + bundle_sft = load_tractogram_with_reference(p, args, bundle_file) + bundle_sft.to_vox() + bundle_sft.to_corner() + bundles.append(bundle_sft) + return bundles + +def main(): + p = _build_arg_parser() + args = p.parse_args() + + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, + # but we will set trainer to user-defined level. + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' + + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) + + # Verify output names + # Check experiment_path exists and best_model folder exists + # Assert_inputs_exist(p, args.hdf5_file) + assert_outputs_exist(p, args, []) + + # Device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # 1. Load model + logging.debug("Loading model.") + model = ModelAE.load_model_from_params_and_state( + args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) + + expanded = expanduser(args.in_bundles) + bundles_files = glob(expanded) + if isinstance(bundles_files, str): + bundles_files = [bundles_files] + + bundles_label = [pathlib.Path(l).stem for l in bundles_files] + bundles_sft = load_bundles(p, args, bundles_files) + + logging.info("Running model to compute loss") + + ls_viz = BundlesLatentSpaceVisualizer( + save_path="/home/local/USHERBROOKE/levj1404/Documents/dwi_ml/data/out.png" + ) + + with torch.no_grad(): + for i, bundle_sft in enumerate(bundles_sft): + + # Resample + streamlines = torch.as_tensor(np.asarray(set_number_of_points(bundle_sft.streamlines, 256)), + dtype=torch.float32, device=device) + + latent_streamlines = model.encode(streamlines).cpu().numpy() # output of (N, 32) + ls_viz.add_data_to_plot(latent_streamlines, label=bundles_label[i]) + + ls_viz.plot() + + +if __name__ == '__main__': + main() diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index 31526453..79542aba 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -4,7 +4,7 @@ import logging import torch - +import numpy as np from scilpy.io.utils import (add_overwrite_arg, assert_outputs_exist, add_reference_arg, @@ -14,6 +14,8 @@ from dwi_ml.io_utils import (add_arg_existing_experiment_path, add_memory_args) from dwi_ml.models.projects.ae_models import ModelAE +from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer +from dipy.tracking.streamline import set_number_of_points def _build_arg_parser(): @@ -31,6 +33,8 @@ def _build_arg_parser(): p.add_argument('out_tractogram', help="If set, saves the tractogram with the loss per point " "as a data per point (color)") + p.add_argument('--viz_save_path', type=str, default=None, + help="Path to save the figure. If not specified, the figure will be shown.") # Options p.add_argument('--batch_size', type=int) @@ -60,23 +64,16 @@ def main(): assert_outputs_exist(p, args, args.out_tractogram) # Device - device = (torch.device('cuda') if torch.cuda.is_available() and - args.use_gpu else None) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 1. Load model logging.debug("Loading model.") model = ModelAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_loggers_level) - # model.set_context('training') - # 2. Compute loss - # tester = TesterOneInput(args.experiment_path, - # model, - # args.batch_size, - # device) - # tester = Tester(args.experiment_path, model, args.batch_size, device) - # sft = tester.load_and_format_data(args.subj_id, - # args.hdf5_file, - # args.subset) + args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) + + # Setup vizualisation + ls_viz = BundlesLatentSpaceVisualizer(save_path=args.viz_save_path) + model.register_hook_post_encoding(lambda encoded_data: ls_viz.add_data_to_plot(encoded_data)) sft = load_tractogram_with_reference(p, args, args.in_tractogram) sft.to_vox() @@ -89,24 +86,15 @@ def main(): save_tractogram(new_sft, 'orig_5000.trk') with torch.no_grad(): - streamlines = [ - torch.as_tensor(s, dtype=torch.float32, device=device) - for s in bundle] + streamlines = torch.as_tensor(np.asarray(set_number_of_points(bundle, 256)), + dtype=torch.float32, device=device) tmp_outputs = model(streamlines) - # latent = model.encode(streamlines) - streamlines_output = [tmp_outputs[i, :, :].transpose(0, 1).cpu().numpy() for i in range(len(bundle))] + ls_viz.plot() - # print(streamlines_output[0].shape) + streamlines_output = [tmp_outputs[i, :, :].transpose(0, 1).cpu().numpy() for i in range(len(bundle))] new_sft = sft.from_sft(streamlines_output, sft) - save_tractogram(new_sft, args.out_tractogram) - - # latent_output = [s.cpu().numpy() for s in latent] - - # outputs, losses = tester.run_model_on_sft( - # sft, uncompress_loss=args.uncompress_loss, - # force_compress_loss=args.force_compress_loss, - # weight_with_angle=args.weight_with_angle) + save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False) if __name__ == '__main__': From c665cf2298030cea0a62d30ab0b6bd1af72ae42e Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Fri, 27 Sep 2024 12:16:36 -0400 Subject: [PATCH 02/28] Viz latent space each n epochs --- dwi_ml/viz/latent_streamlines.py | 138 +++++++++++++++---------------- scripts_python/ae_train_model.py | 40 +++++++++ 2 files changed, 105 insertions(+), 73 deletions(-) diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index 297bced6..7e260c3a 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -1,5 +1,5 @@ +import os import logging - from typing import Union, List, Tuple from sklearn.manifold import TSNE import numpy as np @@ -7,64 +7,7 @@ import matplotlib.pyplot as plt -def plot_latent_streamlines( - encoded_streamlines: Union[np.ndarray, torch.Tensor], - save_path: str = None, - fig_size: Union[List, Tuple] = None, - random_state: int = 42, - max_subset_size: int = None - ): - """ - Projects and plots the latent space representation - of the streamlines using t-SNE dimensionality reduction. - - Parameters - ---------- - encoded_streamlines: Union[np.ndarray, torch.Tensor] - Latent space streamlines to plot of shape (N, latent_space_dim). - save_path: str - Path to save the figure. If not specified, the figure will be shown. - fig_size: List[int] or Tuple[int] - 2-valued figure size (x, y) - random_state: int - Random state for t-SNE. - max_subset_size: int: - In case of performance issues, you can limit the number of streamlines to plot. - """ - - if isinstance(encoded_streamlines, torch.Tensor): - latent_space_streamlines = encoded_streamlines.cpu().numpy() - else: - latent_space_streamlines = encoded_streamlines - - if max_subset_size is not None: - if not (max_subset_size > 0): - raise ValueError("A max_subset_size of an integer value greater than 0 is required.") - - # Only sample if we need to reduce the number of latent streamlines - # to show on the plot. - if (len(latent_space_streamlines) > max_subset_size): - sample_indices = np.random.choice(len(latent_space_streamlines), max_subset_size, replace=False) - latent_space_streamlines = latent_space_streamlines[sample_indices] # (max_subset_size, 2) - - # Project the data into 2 dimensions. - tsne = TSNE(n_components=2, random_state=random_state) - X_tsne = tsne.fit_transform(latent_space_streamlines) # Output (N, 2) - - - logging.info("New figure for t-SNE visualisation.") - fig, ax = plt.subplots() - if fig_size is not None: - fig.set_figheight(fig_size[0]) - fig.set_figwidth(fig_size[1]) - - ax.scatter(X_tsne[:, 0], X_tsne[:, 1], alpha=0.9, edgecolors='black', linewidths=0.5) - - if save_path is not None: - fig.savefig(save_path) - else: - plt.show() - +LOGGER = logging.getLogger(__name__) class BundlesLatentSpaceVisualizer(object): """ @@ -79,10 +22,12 @@ class BundlesLatentSpaceVisualizer(object): be projected at the same time since it aims to preserve the local structure of the data. """ def __init__(self, - save_path: str = None, + save_dir: str = None, fig_size: Union[List, Tuple] = None, random_state: int = 42, - max_subset_size: int = None + max_subset_size: int = None, + prefix_numbering: bool = False, + reset_warning: bool = True ): """ Parameters @@ -93,18 +38,43 @@ def __init__(self, 2-valued figure size (x, y) random_state: List Random state for t-SNE. - max_subset_size: + max_subset_size: int In case of performance issues, you can limit the number of streamlines to plot for each bundle. + prefix_numbering: bool + If True, the saved figures will be numbered with the current plot number. + The plot number is incremented after each call to the "plot" method. + reset_warning: bool + If True, a warning will be displayed when calling "plot" several times + without calling "reset_data" in between to clear the data. """ - self.save_path = save_path + self.save_dir = save_dir + + # Make sure that self.save_dir is a directory and exists. + if self.save_dir is not None: + if not os.path.isdir(self.save_dir): + raise ValueError("The save_dir should be a directory.") + self.fig_size = fig_size self.random_state = random_state self.max_subset_size = max_subset_size + self.prefix_numbering = prefix_numbering + self.reset_warning = reset_warning + + self.current_plot_number = 0 + self.should_call_reset_before_plot = False self.tsne = TSNE(n_components=2, random_state=self.random_state) self.bundles = {} - + + def reset_data(self): + """ + Reset the data to plot. If you call plot several times without + calling this method, the data will be accumulated. + """ + # Not sure if resetting the TSNE object is necessary. + self.tsne = TSNE(n_components=2, random_state=self.random_state) + self.bundles = {} def add_data_to_plot(self, data: np.ndarray, label: str = '_'): """ @@ -119,7 +89,7 @@ def add_data_to_plot(self, data: np.ndarray, label: str = '_'): Name of the bundle. Used for the legend. """ if isinstance(data, torch.Tensor): - latent_space_streamlines = data.cpu().numpy() + latent_space_streamlines = data.detach().numpy() else: latent_space_streamlines = data @@ -135,13 +105,28 @@ def add_data_to_plot(self, data: np.ndarray, label: str = '_'): self.bundles[label] = latent_space_streamlines - def plot(self): + def plot(self, title: str = "", figure_name_prefix: str = 'lt_space'): """ Fit and plot the t-SNE projection of the latent space streamlines. This should be called once after adding all the data to plot using "add_data_to_plot". + + Parameters + ---------- + figure_name_prefix: str + Name of the figure to be saved. This is just the prefix of the full file + name as it will be suffixed with the current plot number if enabled. """ + if self.should_call_reset_before_plot and self.reset_warning: + LOGGER.warning("You plotted another time without resetting the data. " + "The data will be accumulated, which might lead to unexpected results.") + self.should_call_reset_before_plot = False + elif not self.current_plot_number > 0: + # Only enable the flag for the first plot. + # So that the warning above is only displayed once. + self.should_call_reset_before_plot = True + nb_streamlines = sum(b.shape[0] for b in self.bundles.values()) - logging.info("Plotting a total of {} streamlines".format(nb_streamlines)) + LOGGER.info("Plotting a total of {} streamlines".format(nb_streamlines)) bundles_indices = {} current_start = 0 @@ -153,11 +138,12 @@ def plot(self): all_streamlines = np.concatenate(list(self.bundles.values()), axis=0) - logging.info("Fitting TSNE projection.") + LOGGER.info("Fitting TSNE projection.") all_projected_streamlines = self.tsne.fit_transform(all_streamlines) - logging.info("New figure for t-SNE visualisation.") + LOGGER.info("New figure for t-SNE visualisation.") fig, ax = plt.subplots() + ax.set_title(title) if self.fig_size is not None: fig.set_figheight(self.fig_size[0]) fig.set_figwidth(self.fig_size[1]) @@ -174,10 +160,16 @@ def plot(self): linewidths=0.5, ) - ax.legend() - - if self.save_path is not None: - fig.savefig(self.save_path) + if len(self.bundles) > 1: + ax.legend() + + if self.save_dir is not None: + filename = '{}_{}.png'.format(figure_name_prefix, self.current_plot_number) \ + if self.prefix_numbering else '{}.png'.format(figure_name_prefix) + filename = os.path.join(self.save_dir, filename) + fig.savefig(filename) else: plt.show() + self.current_plot_number += 1 + diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 20d416b1..9ea9eadc 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -27,6 +27,7 @@ from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader) from dwi_ml.training.utils.trainer import add_training_args from dwi_ml.training.batch_loaders import DWIMLStreamlinesBatchLoader +from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer from dwi_ml.training.utils.experiment import ( add_mandatory_args_experiment_and_hdf5_path) from dwi_ml.training.utils.trainer import run_experiment, add_training_args, \ @@ -43,6 +44,9 @@ def prepare_arg_parser(): add_training_args(p) p.add_argument('streamline_group_name', help="Name of the group in hdf5") + p.add_argument('--viz_latent_space_freq', type=int, default=None, + help="Frequency at which to visualize latent space.\n" + "This is expressed in number of epochs.") add_memory_args(p, add_lazy_options=True, add_rng=True) add_verbose_arg(p) @@ -116,6 +120,42 @@ def init_from_args(args, sub_loggers_level): logging.info("Trainer params : " + format_dict_to_str(trainer.params_for_checkpoint)) + if args.viz_latent_space_freq is not None: + # Setup to visualize latent space + save_dir = os.path.join(args.experiments_path, args.experiment_name, 'latent_space_plots') + os.makedirs(save_dir, exist_ok=True) + + ls_viz = BundlesLatentSpaceVisualizer(save_dir, + prefix_numbering=True, + max_subset_size=1000) + current_epoch = 0 + def visualize_latent_space(encoding): + """ + This is not a clean way to do it. This would require changes in the + trainer to allow for a callback system where we could register a + function to be called at the end of each epoch to plot the latent space + of the data accumulated during the epoch (at each batch). + + Also, using this method, the latent space of the last epoch will not be + plotted. We would need to calculate which batch step would be the last in + the epoch and then plot accordingly. + """ + nonlocal current_epoch, trainer, ls_viz + + # Only execute the following if we are in training + if not trainer.model.context == 'training': + return + + changed_epoch = current_epoch != trainer.current_epoch + if not changed_epoch: + ls_viz.add_data_to_plot(encoding) + elif changed_epoch and trainer.current_epoch % args.viz_latent_space_freq == 0: + current_epoch = trainer.current_epoch + ls_viz.plot(title="Latent space at epoch {}".format(current_epoch)) + ls_viz.reset_data() + ls_viz.add_data_to_plot(encoding) + model.register_hook_post_encoding(visualize_latent_space) + return trainer From 8df0c0f3a60ad6ef0d66e06f447c53b51c366ecf Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Fri, 27 Sep 2024 13:20:57 -0400 Subject: [PATCH 03/28] autopep8 pass --- dwi_ml/viz/latent_streamlines.py | 102 ++++++++++++--------- scripts_python/ae_train_model.py | 28 +++--- scripts_python/ae_visualize_bundles.py | 11 ++- scripts_python/ae_visualize_streamlines.py | 10 +- 4 files changed, 91 insertions(+), 60 deletions(-) diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index 7e260c3a..ddd95e5b 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -9,47 +9,52 @@ LOGGER = logging.getLogger(__name__) + class BundlesLatentSpaceVisualizer(object): """ - Utility class that wraps a t-SNE projection of the latent space for multiple bundles. - The usage of this class is intented as follows: + Utility class that wraps a t-SNE projection of the latent + space for multiple bundles. The usage of this class is + intented as follows: 1. Create an instance of this class, - 2. Add the latent space streamlines for each bundle using "add_data_to_plot" - with its corresponding label. + 2. Add the latent space streamlines for each bundle + using "add_data_to_plot" with its corresponding label. 3. Fit and plot the t-SNE projection using the "plot" method. - - t-SNE projection can only leverage the fit_transform() with all the data that needs to - be projected at the same time since it aims to preserve the local structure of the data. + + t-SNE projection can only leverage the fit_transform() with all + the data that needs to be projected at the same time since it aims + to preserve the local structure of the data. """ + def __init__(self, - save_dir: str = None, - fig_size: Union[List, Tuple] = None, - random_state: int = 42, - max_subset_size: int = None, - prefix_numbering: bool = False, - reset_warning: bool = True - ): + save_dir: str = None, + fig_size: Union[List, Tuple] = None, + random_state: int = 42, + max_subset_size: int = None, + prefix_numbering: bool = False, + reset_warning: bool = True + ): """ Parameters ---------- save_path: str - Path to save the figure. If not specified, the figure will be shown. + Path to save the figure. If not specified, the figure will show. fig_size: List[int] or Tuple[int] 2-valued figure size (x, y) random_state: List Random state for t-SNE. max_subset_size: int - In case of performance issues, you can limit the number of streamlines to plot - for each bundle. + In case of performance issues, you can limit the number of + streamlines to plot for each bundle. prefix_numbering: bool - If True, the saved figures will be numbered with the current plot number. - The plot number is incremented after each call to the "plot" method. + If True, the saved figures will be numbered with the current + plot number. The plot number is incremented after each call + to the "plot" method. reset_warning: bool - If True, a warning will be displayed when calling "plot" several times - without calling "reset_data" in between to clear the data. + If True, a warning will be displayed when calling "plot"several + times without calling "reset_data" in between to clear the data. """ self.save_dir = save_dir - + # Make sure that self.save_dir is a directory and exists. if self.save_dir is not None: if not os.path.isdir(self.save_dir): @@ -60,7 +65,7 @@ def __init__(self, self.max_subset_size = max_subset_size self.prefix_numbering = prefix_numbering self.reset_warning = reset_warning - + self.current_plot_number = 0 self.should_call_reset_before_plot = False @@ -92,33 +97,43 @@ def add_data_to_plot(self, data: np.ndarray, label: str = '_'): latent_space_streamlines = data.detach().numpy() else: latent_space_streamlines = data - + if self.max_subset_size is not None: if not (self.max_subset_size > 0): - raise ValueError("A max_subset_size of an integer value greater than 0 is required.") - + raise ValueError( + "A max_subset_size of an integer value greater" + "than 0 is required.") + # Only sample if we need to reduce the number of latent streamlines # to show on the plot. if (len(latent_space_streamlines) > self.max_subset_size): - sample_indices = np.random.choice(len(latent_space_streamlines), self.max_subset_size, replace=False) - latent_space_streamlines = latent_space_streamlines[sample_indices] # (max_subset_size, 2) - + sample_indices = np.random.choice( + len(latent_space_streamlines), + self.max_subset_size, replace=False) + + latent_space_streamlines = \ + latent_space_streamlines[sample_indices] + self.bundles[label] = latent_space_streamlines def plot(self, title: str = "", figure_name_prefix: str = 'lt_space'): """ Fit and plot the t-SNE projection of the latent space streamlines. - This should be called once after adding all the data to plot using "add_data_to_plot". - + This should be called once after adding all the data to plot using + "add_data_to_plot". + Parameters ---------- figure_name_prefix: str - Name of the figure to be saved. This is just the prefix of the full file - name as it will be suffixed with the current plot number if enabled. + Name of the figure to be saved. This is just the prefix of the + full file name as it will be suffixed with the current plot + number if enabled. """ if self.should_call_reset_before_plot and self.reset_warning: - LOGGER.warning("You plotted another time without resetting the data. " - "The data will be accumulated, which might lead to unexpected results.") + LOGGER.warning( + "You plotted another time without resetting the data. " + "The data will be accumulated, which might lead to " + "unexpected results.") self.should_call_reset_before_plot = False elif not self.current_plot_number > 0: # Only enable the flag for the first plot. @@ -126,12 +141,14 @@ def plot(self, title: str = "", figure_name_prefix: str = 'lt_space'): self.should_call_reset_before_plot = True nb_streamlines = sum(b.shape[0] for b in self.bundles.values()) - LOGGER.info("Plotting a total of {} streamlines".format(nb_streamlines)) + LOGGER.info( + "Plotting a total of {} streamlines".format(nb_streamlines)) bundles_indices = {} current_start = 0 for (bname, bdata) in self.bundles.items(): - bundles_indices[bname] = np.arange(current_start, current_start + bdata.shape[0]) + bundles_indices[bname] = np.arange( + current_start, current_start + bdata.shape[0]) current_start += bdata.shape[0] assert current_start == nb_streamlines @@ -159,17 +176,20 @@ def plot(self, title: str = "", figure_name_prefix: str = 'lt_space'): edgecolors='black', linewidths=0.5, ) - + if len(self.bundles) > 1: ax.legend() if self.save_dir is not None: - filename = '{}_{}.png'.format(figure_name_prefix, self.current_plot_number) \ - if self.prefix_numbering else '{}.png'.format(figure_name_prefix) + if self.prefix_numbering: + filename = '{}_{}.png'.format( + figure_name_prefix, self.current_plot_number) + else: + filename = '{}.png'.format(figure_name_prefix) + filename = os.path.join(self.save_dir, filename) fig.savefig(filename) else: plt.show() self.current_plot_number += 1 - diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 9ea9eadc..1252f27c 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -14,7 +14,8 @@ import comet_ml import torch -from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist, add_verbose_arg +from scilpy.io.utils import (assert_inputs_exist, assert_outputs_exist, + add_verbose_arg) from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.prints import format_dict_to_str @@ -40,7 +41,7 @@ def prepare_arg_parser(): add_mandatory_args_experiment_and_hdf5_path(p) add_args_batch_sampler(p) add_args_batch_loader(p) - #training_group = add_training_args(p) + # training_group = add_training_args(p) add_training_args(p) p.add_argument('streamline_group_name', help="Name of the group in hdf5") @@ -51,7 +52,7 @@ def prepare_arg_parser(): add_verbose_arg(p) # Additional arg for projects - #training_group.add_argument( + # training_group.add_argument( # '--clip_grad', type=float, default=None, # help="Value to which the gradient norms to avoid exploding gradients." # "\nDefault = None (not clipping).") @@ -122,13 +123,15 @@ def init_from_args(args, sub_loggers_level): if args.viz_latent_space_freq is not None: # Setup to visualize latent space - save_dir = os.path.join(args.experiments_path, args.experiment_name, 'latent_space_plots') + save_dir = os.path.join(args.experiments_path, + args.experiment_name, 'latent_space_plots') os.makedirs(save_dir, exist_ok=True) ls_viz = BundlesLatentSpaceVisualizer(save_dir, - prefix_numbering=True, - max_subset_size=1000) + prefix_numbering=True, + max_subset_size=1000) current_epoch = 0 + def visualize_latent_space(encoding): """ This is not a clean way to do it. This would require changes in the @@ -151,7 +154,8 @@ def visualize_latent_space(encoding): ls_viz.add_data_to_plot(encoding) elif changed_epoch and trainer.current_epoch % args.viz_latent_space_freq == 0: current_epoch = trainer.current_epoch - ls_viz.plot(title="Latent space at epoch {}".format(current_epoch)) + ls_viz.plot(title="Latent space at epoch {}".format( + current_epoch)) ls_viz.reset_data() ls_viz.add_data_to_plot(encoding) model.register_hook_post_encoding(visualize_latent_space) @@ -175,10 +179,12 @@ def main(): assert_outputs_exist(p, args, args.experiments_path) # Verify if a checkpoint has been saved. Else create an experiment. - if os.path.exists(os.path.join(args.experiments_path, args.experiment_name, - "checkpoint")): - raise FileExistsError("This experiment already exists. Delete or use " - "script l2t_resume_training_from_checkpoint.py.") + if os.path.exists(os.path.join( + args.experiments_path, args.experiment_name, + "checkpoint")): + raise FileExistsError( + "This experiment already exists. Delete or use " + "script l2t_resume_training_from_checkpoint.py.") trainer = init_from_args(args, sub_loggers_level) diff --git a/scripts_python/ae_visualize_bundles.py b/scripts_python/ae_visualize_bundles.py index c4b7c0a4..2e7aad1e 100644 --- a/scripts_python/ae_visualize_bundles.py +++ b/scripts_python/ae_visualize_bundles.py @@ -15,7 +15,7 @@ add_verbose_arg) from scilpy.io.streamlines import load_tractogram_with_reference from dwi_ml.io_utils import (add_arg_existing_experiment_path, - add_memory_args) + add_memory_args) from dwi_ml.models.projects.ae_models import ModelAE from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer @@ -42,6 +42,7 @@ def _build_arg_parser(): add_verbose_arg(p) return p + def load_bundles(p, args, files_list: list): bundles = [] for bundle_file in files_list: @@ -51,6 +52,7 @@ def load_bundles(p, args, files_list: list): bundles.append(bundle_sft) return bundles + def main(): p = _build_arg_parser() args = p.parse_args() @@ -91,12 +93,13 @@ def main(): with torch.no_grad(): for i, bundle_sft in enumerate(bundles_sft): - + # Resample streamlines = torch.as_tensor(np.asarray(set_number_of_points(bundle_sft.streamlines, 256)), dtype=torch.float32, device=device) - - latent_streamlines = model.encode(streamlines).cpu().numpy() # output of (N, 32) + + latent_streamlines = model.encode( + streamlines).cpu().numpy() # output of (N, 32) ls_viz.add_data_to_plot(latent_streamlines, label=bundles_label[i]) ls_viz.plot() diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index 79542aba..69a91fb3 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -12,7 +12,7 @@ from scilpy.io.streamlines import load_tractogram_with_reference from dipy.io.streamline import save_tractogram from dwi_ml.io_utils import (add_arg_existing_experiment_path, - add_memory_args) + add_memory_args) from dwi_ml.models.projects.ae_models import ModelAE from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer from dipy.tracking.streamline import set_number_of_points @@ -73,7 +73,8 @@ def main(): # Setup vizualisation ls_viz = BundlesLatentSpaceVisualizer(save_path=args.viz_save_path) - model.register_hook_post_encoding(lambda encoded_data: ls_viz.add_data_to_plot(encoded_data)) + model.register_hook_post_encoding( + lambda encoded_data: ls_viz.add_data_to_plot(encoded_data)) sft = load_tractogram_with_reference(p, args, args.in_tractogram) sft.to_vox() @@ -87,12 +88,13 @@ def main(): with torch.no_grad(): streamlines = torch.as_tensor(np.asarray(set_number_of_points(bundle, 256)), - dtype=torch.float32, device=device) + dtype=torch.float32, device=device) tmp_outputs = model(streamlines) ls_viz.plot() - streamlines_output = [tmp_outputs[i, :, :].transpose(0, 1).cpu().numpy() for i in range(len(bundle))] + streamlines_output = [tmp_outputs[i, :, :].transpose( + 0, 1).cpu().numpy() for i in range(len(bundle))] new_sft = sft.from_sft(streamlines_output, sft) save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False) From cb64063fbcddf2b4fb978c222c17e095df14b933 Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Wed, 2 Oct 2024 10:52:27 -0400 Subject: [PATCH 04/28] Fix to cpu --- dwi_ml/viz/latent_streamlines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index ddd95e5b..6565eae6 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -94,7 +94,7 @@ def add_data_to_plot(self, data: np.ndarray, label: str = '_'): Name of the bundle. Used for the legend. """ if isinstance(data, torch.Tensor): - latent_space_streamlines = data.detach().numpy() + latent_space_streamlines = data.detach().cpu().numpy() else: latent_space_streamlines = data From f0973ffdaf596b7fe28bc47bffd7b31395062abf Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Wed, 2 Oct 2024 18:09:13 -0400 Subject: [PATCH 05/28] Use bundle index within HDF5 for coloring latent space --- .../data/dataset/multi_subject_containers.py | 28 +++--- .../data/dataset/single_subject_containers.py | 30 ++++--- dwi_ml/data/dataset/streamline_containers.py | 45 ++++++++-- dwi_ml/data/dataset/utils.py | 5 +- dwi_ml/models/projects/ae_models.py | 3 +- dwi_ml/training/batch_loaders.py | 89 ++++++++++++++++++- dwi_ml/training/trainers.py | 58 ++++++++---- dwi_ml/viz/latent_streamlines.py | 84 ++++++++++++----- scripts_python/ae_train_model.py | 47 ++++++---- scripts_python/ae_visualize_bundles.py | 6 +- 10 files changed, 303 insertions(+), 92 deletions(-) diff --git a/dwi_ml/data/dataset/multi_subject_containers.py b/dwi_ml/data/dataset/multi_subject_containers.py index 0c6bded5..d0fd9a6c 100644 --- a/dwi_ml/data/dataset/multi_subject_containers.py +++ b/dwi_ml/data/dataset/multi_subject_containers.py @@ -37,8 +37,9 @@ class MultisubjectSubset(Dataset): Based on torch's dataset class. Provides functions for a DataLoader to iterate over data and process batches. """ + def __init__(self, set_name: str, hdf5_file: str, lazy: bool, - cache_size: int = 0): + cache_size: int = 0, related_data_key: str = None): self.set_name = set_name self.hdf5_file = hdf5_file @@ -79,6 +80,7 @@ def __init__(self, set_name: str, hdf5_file: str, lazy: bool, # This is only used in the lazy case. self.cache_size = cache_size self.volume_cache_manager = None + self.related_data_key = related_data_key def close_all_handles(self): if self.subjs_data_list.hdf_handle: @@ -278,7 +280,7 @@ def load(self, hdf_handle: h5py.File, subj_id=None): # calling this method. logger.debug(" Creating subject '{}'.".format(subj_id)) subj_data = self._init_subj_from_hdf( - hdf_handle, subj_id, ref_group_info) + hdf_handle, subj_id, ref_group_info, related_data_key=self.related_data_key) # Add subject to the list subj_idx = self.subjs_data_list.add_subject(subj_data) @@ -335,13 +337,13 @@ def _build_empty_data_list(self): else: return SubjectsDataList(self.hdf5_file, logger) - def _init_subj_from_hdf(self, hdf_handle, subject_id, ref_group_info): + def _init_subj_from_hdf(self, hdf_handle, subject_id, ref_group_info, related_data_key=None): if self.is_lazy: return LazySubjectData.init_single_subject_from_hdf( - subject_id, hdf_handle, ref_group_info) + subject_id, hdf_handle, ref_group_info, related_data_key=related_data_key) else: return SubjectData.init_single_subject_from_hdf( - subject_id, hdf_handle, ref_group_info) + subject_id, hdf_handle, ref_group_info, related_data_key=related_data_key) class MultiSubjectDataset: @@ -356,8 +358,10 @@ class MultiSubjectDataset: datasets 'streamlines/data', 'streamlines/offsets', 'streamlines/lengths', 'streamlines/euclidean_lengths'. """ + def __init__(self, hdf5_file: str, lazy: bool, - cache_size: int = 0, log_level=None): + cache_size: int = 0, log_level=None, + related_data_key=None): """ Params ------ @@ -393,11 +397,11 @@ def __init__(self, hdf5_file: str, lazy: bool, # Preparing the testing set and validation set # In non-lazy data, the cache_size is not used. self.training_set = MultisubjectSubset( - 'training', hdf5_file, self.is_lazy, cache_size) + 'training', hdf5_file, self.is_lazy, cache_size, related_data_key=related_data_key) self.validation_set = MultisubjectSubset( - 'validation', hdf5_file, self.is_lazy, cache_size) + 'validation', hdf5_file, self.is_lazy, cache_size, related_data_key=related_data_key) self.testing_set = MultisubjectSubset( - 'testing', hdf5_file, self.is_lazy, cache_size) + 'testing', hdf5_file, self.is_lazy, cache_size, related_data_key=related_data_key) @property def params_for_checkpoint(self) -> Dict[str, Any]: @@ -478,7 +482,8 @@ def load_data(self, load_training=True, load_validation=True, self.nb_features = nb_features if streamline_groups is not None: - missing_str = np.setdiff1d(streamline_groups, poss_strea_groups) + missing_str = np.setdiff1d( + streamline_groups, poss_strea_groups) if len(missing_str) > 0: raise ValueError("Streamlines {} were not found in the " "first subject of your hdf5 file." @@ -498,7 +503,8 @@ def load_data(self, load_training=True, load_validation=True, self.streamline_groups, self.streamlines_contain_connectivity) self.training_set.set_subset_info(*group_info, step_size, compress) - self.validation_set.set_subset_info(*group_info, step_size, compress) + self.validation_set.set_subset_info( + *group_info, step_size, compress) self.testing_set.set_subset_info(*group_info, step_size, compress) # LOADING diff --git a/dwi_ml/data/dataset/single_subject_containers.py b/dwi_ml/data/dataset/single_subject_containers.py index fbd9b6cc..fa3639c0 100644 --- a/dwi_ml/data/dataset/single_subject_containers.py +++ b/dwi_ml/data/dataset/single_subject_containers.py @@ -17,8 +17,9 @@ class SubjectDataAbstract(object): single MRI acquisition. It could contain data from many "real" MRI volumes concatenated together. """ + def __init__(self, volume_groups: List[str], nb_features: List[int], - streamline_groups: List[str], subject_id: str): + streamline_groups: List[str], subject_id: str, related_data_key: str = None): """ Parameters ---------- @@ -37,6 +38,7 @@ def __init__(self, volume_groups: List[str], nb_features: List[int], self.streamline_groups = streamline_groups self.subject_id = subject_id self.is_lazy = None + self.related_data_key = related_data_key @property def mri_data_list(self) -> List[MRIDataAbstract]: @@ -50,7 +52,7 @@ def sft_data_list(self): @classmethod def init_single_subject_from_hdf( - cls, subject_id: str, hdf_file, group_info=None): + cls, subject_id: str, hdf_file, group_info=None, related_data_key=None): """Returns an instance of this class, initiated by sending only the hdf handle. The child class's method will define how to load the data based on the child data management.""" @@ -64,10 +66,12 @@ def add_handle(self, hdf_handle): class SubjectData(SubjectDataAbstract): """Non-lazy version""" + def __init__(self, subject_id: str, volume_groups: List[str], nb_features: List[int], mri_data_list: List[MRIData] = None, streamline_groups: List[str] = None, - sft_data_list: List[SFTData] = None): + sft_data_list: List[SFTData] = None, + related_data_key: str = None): """ Additional params compared to super: ---- @@ -76,7 +80,7 @@ def __init__(self, subject_id: str, volume_groups: List[str], ._data, ._offsets, ._lengths, ._lengths_mm. """ super().__init__(volume_groups, nb_features, streamline_groups, - subject_id) + subject_id, related_data_key=related_data_key) self._mri_data_list = mri_data_list self._sft_data_list = sft_data_list self.is_lazy = False @@ -91,7 +95,7 @@ def sft_data_list(self): @classmethod def init_single_subject_from_hdf( - cls, subject_id: str, hdf_file, group_info=None): + cls, subject_id: str, hdf_file, group_info=None, related_data_key=None): """ Instantiating a single subject data: load info and use __init__ """ @@ -113,12 +117,13 @@ def init_single_subject_from_hdf( logger.debug(" Loading streamlines group '{}'" .format(group)) sft_data = SFTData.init_sft_data_from_hdf_info( - hdf_file[subject_id][group]) + hdf_file[subject_id][group], related_data_key=related_data_key) subject_sft_data_list.append(sft_data) subj_data = cls(subject_id, volume_groups, nb_features, subject_mri_data_list, - streamline_groups, subject_sft_data_list) + streamline_groups, subject_sft_data_list, + related_data_key=related_data_key) return subj_data @@ -130,9 +135,10 @@ class LazySubjectData(SubjectDataAbstract): """ Lazy version. """ + def __init__(self, volume_groups: List[str], nb_features: List[int], streamline_groups: List[str], subject_id: str, - hdf_handle=None): + hdf_handle=None, related_data_key: str = None): """ Additional params compared to super: ------ @@ -140,13 +146,13 @@ def __init__(self, volume_groups: List[str], nb_features: List[int], Opened hdf file, if any. If None, data loading is deactivated. """ super().__init__(volume_groups, nb_features, streamline_groups, - subject_id) + subject_id, related_data_key=related_data_key) self.hdf_handle = hdf_handle self.is_lazy = True @classmethod def init_single_subject_from_hdf( - cls, subject_id: str, hdf_file, group_info=None): + cls, subject_id: str, hdf_file, group_info=None, related_data_key=None): """ Instantiating a single subject data: NOT LOADING info and use __init__ (so in short: this does basically nothing, the lazy data is kept @@ -168,7 +174,7 @@ def init_single_subject_from_hdf( logger.debug(' Lazy: not loading data.') return cls(volume_groups, nb_features, streamline_groups, subject_id, - hdf_handle=None) + hdf_handle=None, related_data_key=related_data_key) @property def mri_data_list(self) -> Union[List[LazyMRIData], None]: @@ -202,7 +208,7 @@ def sft_data_list(self) -> Union[List[LazySFTData], None]: for group in self.streamline_groups: hdf_group = self.hdf_handle[self.subject_id][group] sft_data_list.append( - LazySFTData.init_sft_data_from_hdf_info(hdf_group)) + LazySFTData.init_sft_data_from_hdf_info(hdf_group, related_data_key=self.related_data_key)) return sft_data_list else: diff --git a/dwi_ml/data/dataset/streamline_containers.py b/dwi_ml/data/dataset/streamline_containers.py index d8c370ac..599f8d34 100644 --- a/dwi_ml/data/dataset/streamline_containers.py +++ b/dwi_ml/data/dataset/streamline_containers.py @@ -67,6 +67,21 @@ def _load_connectivity_info(hdf_group: h5py.Group): return contains_connectivity, connectivity_nb_blocs, connectivity_labels +def _load_related_data(hdf_group, related_data_key: str) -> Union[np.ndarray, None]: + related_data = None + # Load related data key if specified + if related_data_key is not None: + # Make sure the related data key is in the hdf5 group + if not (related_data_key in hdf_group.keys()): + raise KeyError("The key '{}' is not in the hdf5 group." + .format(related_data_key)) + + # Load the related data + related_data = np.array(hdf_group[related_data_key]).squeeze(1) + + return related_data + + class _LazyStreamlinesGetter(object): def __init__(self, hdf_group): self.hdf_group = hdf_group @@ -160,10 +175,12 @@ class SFTDataAbstract(object): all information necessary to treat with streamlines: the data itself and _offset, _lengths, space attributes, etc. """ + def __init__(self, space_attributes: Tuple, space: Space, origin: Origin, contains_connectivity: bool, connectivity_nb_blocs: List = None, - connectivity_labels: np.ndarray = None): + connectivity_labels: np.ndarray = None, + related_data: np.ndarray = None): """ The lazy/non-lazy versions will have more parameters, such as the streamlines, the connectivity_matrix. In the case of the lazy version, @@ -194,6 +211,7 @@ def __init__(self, space_attributes: Tuple, space: Space, origin: Origin, self.contains_connectivity = contains_connectivity self.connectivity_nb_blocs = connectivity_nb_blocs self.connectivity_labels = connectivity_labels + self.related_data = related_data def __len__(self): raise NotImplementedError @@ -245,6 +263,14 @@ def _get_streamlines_as_list(self, streamline_ids) -> List[ArraySequence]: the hdf5.""" raise NotImplementedError + def get_related_data(self, + streamline_ids: Union[List[int], int, slice, None] = None): + """Returns the data related to the streamlines.""" + if self.related_data is None: + return None + else: + return self.related_data[streamline_ids] + def as_sft(self, streamline_ids: Union[List[int], int, slice, None] = None) \ -> StatefulTractogram: @@ -301,7 +327,7 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None): return self._connectivity_matrix @classmethod - def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): + def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group, related_data_key: str = None): """ Creating class instance from the hdf in cases where data is not loaded yet. Non-lazy = loading the data here. @@ -318,7 +344,10 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): else: connectivity_matrix = None - space_attributes, space, origin = _load_space_attributes_from_hdf(hdf_group) + related_data = _load_related_data(hdf_group, related_data_key) + + space_attributes, space, origin = _load_space_attributes_from_hdf( + hdf_group) # Return an instance of SubjectMRIData instantiated through __init__ # with this loaded data: @@ -328,7 +357,8 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): space=space, origin=origin, contains_connectivity=contains_connectivity, connectivity_nb_blocs=connectivity_nb_blocs, - connectivity_labels=connectivity_labels) + connectivity_labels=connectivity_labels, + related_data=related_data) def _get_streamlines_as_list(self, streamline_ids): if streamline_ids is not None: @@ -367,7 +397,7 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None): return self.streamlines_getter.connectivity_matrix(indxyz) @classmethod - def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): + def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group, related_data_key: str = None): space_attributes, space, origin = _load_space_attributes_from_hdf( hdf_group) @@ -376,12 +406,15 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): streamlines = _LazyStreamlinesGetter(hdf_group) + related_data = _load_related_data(hdf_group, related_data_key) + return cls(streamlines_getter=streamlines, space_attributes=space_attributes, space=space, origin=origin, contains_connectivity=contains_connectivity, connectivity_nb_blocs=connectivity_nb_blocs, - connectivity_labels=connectivity_labels) + connectivity_labels=connectivity_labels, + related_data=related_data) def _get_streamlines_as_list(self, streamline_ids): streamlines = self.streamlines_getter.get_array_sequence( diff --git a/dwi_ml/data/dataset/utils.py b/dwi_ml/data/dataset/utils.py index d8a7b0e9..fdea92bb 100644 --- a/dwi_ml/data/dataset/utils.py +++ b/dwi_ml/data/dataset/utils.py @@ -7,7 +7,8 @@ def prepare_multisubjectdataset(args, load_training=True, load_validation=True, load_testing=True, - log_level=logging.root.level): + log_level=logging.root.level, + related_data_key=None): """ Instantiates a MultiSubjectDataset AND loads data. @@ -19,7 +20,7 @@ def prepare_multisubjectdataset(args, load_training=True, load_validation=True, with Timer("\nPreparing datasets", newline=True, color='blue'): dataset = MultiSubjectDataset( args.hdf5_file, lazy=args.lazy, cache_size=args.cache_size, - log_level=log_level) + log_level=log_level, related_data_key=related_data_key) dataset.load_data(load_training, load_validation, load_testing) logging.info("Number of subjects loaded: \n" diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index acea66d6..8608f339 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -106,6 +106,7 @@ def pre_pad(m): def forward(self, input_streamlines: List[torch.tensor], + related_data=None ): """Run the model on a batch of sequences. @@ -126,7 +127,7 @@ def forward(self, encoded = self.encode(input_streamlines) for hook in self.post_encoding_hooks: - hook(encoded) + hook(encoded, related_data) x = self.decode(encoded) return x diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 623371a9..275e5273 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -157,7 +157,7 @@ def params_for_checkpoint(self): 'noise_gaussian_size_forward': self.noise_gaussian_size_forward, 'noise_gaussian_size_loss': self.noise_gaussian_size_loss, 'reverse_ratio': self.reverse_ratio, - 'split_ratio': self.split_ratio, + 'split_ratio': self.split_ratio } return params @@ -326,6 +326,89 @@ def load_batch_streamlines( return batch_streamlines, final_s_ids_per_subj + def load_batch_streamlines_and_related( + self, streamline_ids_per_subj: List[Tuple[int, list]]): + """ + Fetches the chosen streamlines for all subjects in batch. + Pocesses data augmentation. + + Torch uses this function to process the data with the dataloader + parallel workers (on cpu). To be used as collate_fn. + + Parameters + ---------- + streamline_ids_per_subj: List[Tuple[int, list]] + The list of streamline ids for each subject (relative ids inside + each subject's tractogram) for this batch. + + Returns + ------- + (batch_streamlines, final_s_ids_per_subj) + Where + - batch_streamlines: list[torch.tensor] + The new streamlines after data augmentation, IN VOXEL SPACE, + CORNER. + - final_s_ids_per_subj: Dict[int, slice] + The new streamline ids per subj in this augmented batch. + """ + if self.context is None: + raise ValueError("Context must be set prior to using the batch " + "loader.") + + # The batch's streamline ids will change throughout processing because + # of data augmentation, so we need to do it subject by subject to + # keep track of the streamline ids. These final ids will correspond to + # the loaded, processed streamlines, not to the ids in the hdf5 file. + final_s_ids_per_subj = defaultdict(slice) + batch_streamlines = [] + streamlines_related_data = [] + for subj, s_ids in streamline_ids_per_subj: + logger.debug( + " Data loader: Processing data preparation for " + "subj {} (preparing {} streamlines)".format(subj, len(s_ids))) + + # No cache for the sft data. Accessing it directly. + # Note: If this is used through the dataloader, multiprocessing + # is used. Each process will open a handle. + subj_data = \ + self.context_subset.subjs_data_list.get_subj_with_handle(subj) + subj_sft_data = subj_data.sft_data_list[self.streamline_group_idx] + + # Get streamlines as sft + logger.debug(" Loading sampled streamlines...") + sft = subj_sft_data.as_sft(s_ids) + + # TODO: modify this list consequently to the data augmentations. + # Currently, if the data augmentation adds/removes streamlines, + # the related data won't match the streamlines list anymore. + related_data = subj_sft_data.get_related_data( + s_ids) # Can return None + sft = self._data_augmentation_sft(sft) + + # Remember the indices of this subject's (augmented) streamlines + ids_start = len(batch_streamlines) + ids_end = ids_start + len(sft) + final_s_ids_per_subj[subj] = slice(ids_start, ids_end) + + # Add all (augmented) streamlines to the batch + # What we want is the streamline coordinates, to eventually get + # the underlying input(s). Sending to vox and to corner to + # be able to use our trilinear interpolation + sft.to_vox() + sft.to_corner() + batch_streamlines.extend(sft.streamlines) + + if related_data is not None: + streamlines_related_data.extend(related_data) + + batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines] + + if len(streamlines_related_data) > 0: + assert len(streamlines_related_data) == len(batch_streamlines), \ + "Related data should have the same length as the streamlines." + + return batch_streamlines, final_s_ids_per_subj, streamlines_related_data + def load_batch_connectivity_matrices( self, streamline_ids_per_subj: Dict[int, slice]): if not self.data_contains_connectivity: @@ -447,8 +530,8 @@ def load_batch_inputs(self, batch_streamlines: List[torch.tensor], # because in load_batch, we use sft.to_vox and sft.to_corner # before adding streamline to batch. subbatch_x_data = self.model.prepare_batch_one_input( - streamlines, self.context_subset, subj, - self.input_group_idx) + streamlines, self.context_subset, subj, + self.input_group_idx) batch_x_data.extend(subbatch_x_data) diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index 15621360..f89aeb20 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -50,6 +50,7 @@ class DWIMLAbstractTrainer: NOTE: TRAINER USES STREAMLINES COORDINATES IN VOXEL SPACE, CORNER ORIGIN. """ + def __init__(self, model: MainModelAbstract, experiments_path: str, experiment_name: str, batch_sampler: DWIMLBatchIDSampler, @@ -63,7 +64,8 @@ def __init__(self, nb_cpu_processes: int = 0, use_gpu: bool = False, clip_grad: float = None, comet_workspace: str = None, comet_project: str = None, - from_checkpoint: bool = False, log_level=logging.root.level): + from_checkpoint: bool = False, log_level=logging.root.level, + related_data_retrieval: bool = False): """ Parameters ---------- @@ -228,12 +230,15 @@ def __init__(self, # dataloader output is on GPU, ready to be fed to the model. # Otherwise, dataloader output is kept on CPU, and the main thread # sends volumes and coords on GPU for interpolation. + self.related_data_retrieval = related_data_retrieval + self.collate_fn = self.batch_loader.load_batch_streamlines_and_related \ + if self.related_data_retrieval else self.batch_loader.load_batch_streamlines logger.debug("- Instantiating dataloaders...") self.train_dataloader = DataLoader( dataset=self.batch_sampler.dataset.training_set, batch_sampler=self.batch_sampler, num_workers=self.nb_cpu_processes, - collate_fn=self.batch_loader.load_batch_streamlines, + collate_fn=self.collate_fn, pin_memory=self.use_gpu) self.valid_dataloader = None if self.use_validation: @@ -241,7 +246,7 @@ def __init__(self, dataset=self.batch_sampler.dataset.validation_set, batch_sampler=self.batch_sampler, num_workers=self.nb_cpu_processes, - collate_fn=self.batch_loader.load_batch_streamlines, + collate_fn=self.collate_fn, pin_memory=self.use_gpu) # ---------------------- @@ -517,8 +522,10 @@ def _update_states_from_checkpoint(self, current_states): # A. Rng value. # RNG: # - numpy - self.batch_sampler.np_rng.set_state(current_states['sampler_np_rng_state']) - self.batch_loader.np_rng.set_state(current_states['loader_np_rng_state']) + self.batch_sampler.np_rng.set_state( + current_states['sampler_np_rng_state']) + self.batch_loader.np_rng.set_state( + current_states['loader_np_rng_state']) # - torch torch.set_rng_state(current_states['torch_rng_state']) if self.use_gpu: @@ -568,10 +575,13 @@ def _init_comet(self): display_summary_level=False) self.comet_exp.set_name(self.experiment_name) self.comet_exp.log_parameters(self.params_for_checkpoint) - self.comet_exp.log_parameters(self.batch_sampler.params_for_checkpoint) - self.comet_exp.log_parameters(self.batch_loader.params_for_checkpoint) + self.comet_exp.log_parameters( + self.batch_sampler.params_for_checkpoint) + self.comet_exp.log_parameters( + self.batch_loader.params_for_checkpoint) self.comet_exp.log_parameters(self.model.params_for_checkpoint) - self.comet_exp.log_parameters(self.model.computed_params_for_display) + self.comet_exp.log_parameters( + self.model.computed_params_for_display) self.comet_key = self.comet_exp.get_key() # Couldn't find how to set log level. Getting it directly. comet_log = logging.getLogger("comet_ml") @@ -814,10 +824,17 @@ def train_one_epoch(self, epoch): train_iterator = enumerate(pbar) for batch_id, data in train_iterator: + if self.related_data_retrieval: + related_data = data[2] + data = data[:2] + else: + related_data = None + # Enable gradients for backpropagation. Uses torch's module # train(), which "turns on" the training mode. with grad_context(): - mean_loss = self.train_one_batch(data) + mean_loss = self.train_one_batch( + data, related_data=related_data) unclipped_grad_norm, grad_norm = self.back_propagation( mean_loss) @@ -880,9 +897,16 @@ def validate_one_epoch(self, epoch): valid_iterator = enumerate(pbar) for batch_id, data in valid_iterator: + if self.related_data_retrieval: + related_data = data[2] + data = data[:2] + else: + related_data = None + # Validate this batch: forward propagation + loss with torch.no_grad(): - self.validate_one_batch(data, epoch) + self.validate_one_batch( + data, epoch, related_data=related_data) # Break if maximum number of epochs has been reached if batch_id == self.nb_batches_valid - 1: @@ -908,23 +932,25 @@ def validate_one_epoch(self, epoch): monitor.end_epoch() self._update_comet_after_epoch('validation', epoch) - def train_one_batch(self, data): + def train_one_batch(self, data, **model_kwargs): """ Computes the loss for the current batch and updates monitors. Returns the loss to be used for backpropagation. """ # Encapsulated for easier management of child classes. - mean_local_loss, n = self.run_one_batch(data) + mean_local_loss, n = self.run_one_batch( + data, **model_kwargs) # mean loss is a Tensor of a single value. item() converts to float self.train_loss_monitor.update(mean_local_loss.cpu().item(), weight=n) return mean_local_loss - def validate_one_batch(self, data, epoch): + def validate_one_batch(self, data, epoch, **model_kwargs): """ Computes the loss(es) for the current batch and updates monitors. """ - mean_local_loss, n = self.run_one_batch(data) + mean_local_loss, n = self.run_one_batch( + data, **model_kwargs) self.valid_local_loss_monitor.update(mean_local_loss.cpu().item(), weight=n) @@ -993,7 +1019,7 @@ def _save_best_model(self): json_file.write(json.dumps(best_losses, indent=4, separators=(',', ': '))) - def run_one_batch(self, data): + def run_one_batch(self, data, **model_kwargs): """ Runs a batch of data through the model (calling its forward method) and returns the mean loss. @@ -1037,7 +1063,7 @@ def run_one_batch(self, data): # but ok, shouldn't be too heavy. Easier to deal with multiple # projects' requirements by sending whole streamlines rather # than only directions. - model_outputs = self.model(streamlines_f) + model_outputs = self.model(streamlines_f, **model_kwargs) del streamlines_f logger.debug('*** Computing loss') diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index 6565eae6..30732605 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -63,6 +63,11 @@ def __init__(self, self.fig_size = fig_size self.random_state = random_state self.max_subset_size = max_subset_size + if not (self.max_subset_size > 0): + raise ValueError( + "A max_subset_size of an integer value greater" + "than 0 is required.") + self.prefix_numbering = prefix_numbering self.reset_warning = reset_warning @@ -81,7 +86,7 @@ def reset_data(self): self.tsne = TSNE(n_components=2, random_state=self.random_state) self.bundles = {} - def add_data_to_plot(self, data: np.ndarray, label: str = '_'): + def add_data_to_plot(self, data: np.ndarray, labels: List[str]): """ Add unprojected data (no t-SNE, no PCA, etc.). This should be directly the output of the model as a numpy array. @@ -90,29 +95,33 @@ def add_data_to_plot(self, data: np.ndarray, label: str = '_'): ---------- data: str Unprojected latent space streamlines (N, latent_space_dim). - label: str - Name of the bundle. Used for the legend. + label: np.ndarray + Labels for each streamline. """ - if isinstance(data, torch.Tensor): - latent_space_streamlines = data.detach().cpu().numpy() - else: - latent_space_streamlines = data + latent_space_streamlines = self._to_numpy(data) - if self.max_subset_size is not None: - if not (self.max_subset_size > 0): - raise ValueError( - "A max_subset_size of an integer value greater" - "than 0 is required.") + all_labels = np.unique(labels) + for label in all_labels: + label_indices = labels == label + label_data = latent_space_streamlines[label_indices] + label_data = self._resample_max_subset_size(label_data) + self.bundles[label] = label_data - # Only sample if we need to reduce the number of latent streamlines - # to show on the plot. - if (len(latent_space_streamlines) > self.max_subset_size): - sample_indices = np.random.choice( - len(latent_space_streamlines), - self.max_subset_size, replace=False) + def add_bundle_to_plot(self, data: np.ndarray, label: str = '_'): + """ + Add unprojected data (no t-SNE, no PCA, etc.). + This should be directly the output of the model as a numpy array. - latent_space_streamlines = \ - latent_space_streamlines[sample_indices] + Parameters + ---------- + data: str + Unprojected latent space streamlines (N, latent_space_dim). + label: str + Name of the bundle. Used for the legend. + """ + latent_space_streamlines = self._to_numpy(data) + latent_space_streamlines = self._resample_max_subset_size( + latent_space_streamlines) self.bundles[label] = latent_space_streamlines @@ -140,10 +149,17 @@ def plot(self, title: str = "", figure_name_prefix: str = 'lt_space'): # So that the warning above is only displayed once. self.should_call_reset_before_plot = True + # Start by making sure the number of streamlines doesn't exceed the threshold. + for (bname, bdata) in self.bundles.items(): + if bdata.shape[0] > self.max_subset_size: + self.bundles[bname] = self._resample_max_subset_size(bdata) + nb_streamlines = sum(b.shape[0] for b in self.bundles.values()) LOGGER.info( "Plotting a total of {} streamlines".format(nb_streamlines)) + # Build the indices for each bundle to recover the streamlines after + # the t-SNE projection. bundles_indices = {} current_start = 0 for (bname, bdata) in self.bundles.items(): @@ -193,3 +209,31 @@ def plot(self, title: str = "", figure_name_prefix: str = 'lt_space'): plt.show() self.current_plot_number += 1 + + def _to_numpy(self, data): + if isinstance(data, torch.Tensor): + return data.detach().cpu().numpy() + else: + return data + + def _resample_max_subset_size(self, data: np.ndarray): + """ + Resample the data to the max_subset_size. + """ + _resampled = data + if self.max_subset_size is not None: + if not (self.max_subset_size > 0): + raise ValueError( + "A max_subset_size of an integer value greater" + "than 0 is required.") + + # Only sample if we need to reduce the number of latent streamlines + # to show on the plot. + if (len(data) > self.max_subset_size): + sample_indices = np.random.choice( + len(data), + self.max_subset_size, replace=False) + + _resampled = data[sample_indices] + + return _resampled diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index cf16515f..88241ba7 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -46,6 +46,8 @@ def prepare_arg_parser(): p.add_argument('--viz_latent_space_freq', type=int, default=None, help="Frequency at which to visualize latent space.\n" "This is expressed in number of epochs.") + p.add_argument('--color_by', type=str, default=None, choices=['dps_bundle_index'], + help="Name of the group in hdf5 to color by.") add_memory_args(p, add_lazy_options=True, add_rng=True) add_verbose_arg(p) @@ -55,9 +57,13 @@ def prepare_arg_parser(): def init_from_args(args, sub_loggers_level): torch.manual_seed(args.rng) # Set torch seed + viz_latent_space_freq = args.viz_latent_space_freq + color_by = args.color_by + # Prepare the dataset dataset = prepare_multisubjectdataset(args, load_testing=False, - log_level=sub_loggers_level) + log_level=sub_loggers_level, + related_data_key=color_by) # Preparing the model # (Direction getter) @@ -107,11 +113,12 @@ def init_from_args(args, sub_loggers_level): from_checkpoint=False, clip_grad=args.clip_grad, # MEMORY nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, - log_level=sub_loggers_level) + log_level=sub_loggers_level, + related_data_retrieval=color_by is not None) logging.info("Trainer params : " + format_dict_to_str(trainer.params_for_checkpoint)) - if args.viz_latent_space_freq is not None: + if viz_latent_space_freq is not None: # Setup to visualize latent space save_dir = os.path.join(args.experiments_path, args.experiment_name, 'latent_space_plots') @@ -122,16 +129,17 @@ def init_from_args(args, sub_loggers_level): max_subset_size=1000) current_epoch = 0 - def visualize_latent_space(encoding): + def visualize_latent_space(encoding, related_data): """ - This is not a clean way to do it. This would require changes in the - trainer to allow for a callback system where we could register a - function to be called at the end of each epoch to plot the latent space - of the data accumulated during the epoch (at each batch). - - Also, using this method, the latent space of the last epoch will not be - plotted. We would need to calculate which batch step would be the last in - the epoch and then plot accordingly. + This is not a clean way to do it. This would require changes in + the trainer to allow for a callback system where we could + register a function to be called at the end of each epoch to + plot the latent space of the data accumulated during the epoch + (at each batch). + + Also, using this method, the latent space of the last epoch will + not be plotted. We would need to calculate which batch step would + be the last in the epoch and then plot accordingly. """ nonlocal current_epoch, trainer, ls_viz @@ -141,13 +149,14 @@ def visualize_latent_space(encoding): changed_epoch = current_epoch != trainer.current_epoch if not changed_epoch: - ls_viz.add_data_to_plot(encoding) - elif changed_epoch and trainer.current_epoch % args.viz_latent_space_freq == 0: + ls_viz.add_data_to_plot(encoding, labels=related_data) + elif changed_epoch \ + and trainer.current_epoch % viz_latent_space_freq == 0: current_epoch = trainer.current_epoch ls_viz.plot(title="Latent space at epoch {}".format( current_epoch)) ls_viz.reset_data() - ls_viz.add_data_to_plot(encoding) + ls_viz.add_data_to_plot(encoding, labels=related_data) model.register_hook_post_encoding(visualize_latent_space) return trainer @@ -169,10 +178,10 @@ def main(): assert_outputs_exist(p, args, args.experiments_path) # Verify if a checkpoint has been saved. Else create an experiment. - if os.path.exists(os.path.join(args.experiments_path, args.experiment_name, - "checkpoint")): - raise FileExistsError("This experiment already exists. Delete or use " - "script ae_resume_training_from_checkpoint.py.") + # if os.path.exists(os.path.join(args.experiments_path, args.experiment_name, + # "checkpoint")): + # raise FileExistsError("This experiment already exists. Delete or use " + # "script ae_resume_training_from_checkpoint.py.") trainer = init_from_args(args, sub_loggers_level) diff --git a/scripts_python/ae_visualize_bundles.py b/scripts_python/ae_visualize_bundles.py index 2e7aad1e..68863d01 100644 --- a/scripts_python/ae_visualize_bundles.py +++ b/scripts_python/ae_visualize_bundles.py @@ -29,7 +29,8 @@ def _build_arg_parser(): # Add_args_testing_subj_hdf5(p) p.add_argument('in_bundles', - help="The 'glob' path to several bundles identified by their file name." + help="The 'glob' path to several bundles identified " + "by their file name." "e.g. FiberCupGroundTruth_filtered_bundle_0.tck") # Options @@ -75,7 +76,8 @@ def main(): # 1. Load model logging.debug("Loading model.") model = ModelAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) + args.experiment_path + '/best_model', log_level=sub_loggers_level) + model = model.to(device) expanded = expanduser(args.in_bundles) bundles_files = glob(expanded) From 7f41de671c26669f289e4603496664eb896ea82c Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Wed, 2 Oct 2024 18:09:44 -0400 Subject: [PATCH 06/28] Subplots with best epoch: part I --- dwi_ml/viz/latent_streamlines.py | 115 +++++++++++++++++++++---------- scripts_python/ae_train_model.py | 31 +++++++-- 2 files changed, 104 insertions(+), 42 deletions(-) diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index 30732605..9fef1e63 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -26,12 +26,13 @@ class BundlesLatentSpaceVisualizer(object): """ def __init__(self, - save_dir: str = None, - fig_size: Union[List, Tuple] = None, + save_dir: str, + fig_size: Union[List, Tuple] = (16, 8), random_state: int = 42, max_subset_size: int = None, prefix_numbering: bool = False, - reset_warning: bool = True + reset_warning: bool = True, + bundle_mapping: dict = None ): """ Parameters @@ -56,9 +57,8 @@ def __init__(self, self.save_dir = save_dir # Make sure that self.save_dir is a directory and exists. - if self.save_dir is not None: - if not os.path.isdir(self.save_dir): - raise ValueError("The save_dir should be a directory.") + if not os.path.isdir(self.save_dir): + raise ValueError("The save_dir should be a directory.") self.fig_size = fig_size self.random_state = random_state @@ -70,6 +70,7 @@ def __init__(self, self.prefix_numbering = prefix_numbering self.reset_warning = reset_warning + self.bundle_mapping = bundle_mapping self.current_plot_number = 0 self.should_call_reset_before_plot = False @@ -77,6 +78,9 @@ def __init__(self, self.tsne = TSNE(n_components=2, random_state=self.random_state) self.bundles = {} + self.fig, self.axes = None, None + self.best_epoch = -1 + def reset_data(self): """ Reset the data to plot. If you call plot several times without @@ -125,7 +129,7 @@ def add_bundle_to_plot(self, data: np.ndarray, label: str = '_'): self.bundles[label] = latent_space_streamlines - def plot(self, title: str = "", figure_name_prefix: str = 'lt_space'): + def plot(self, epoch: int, figure_name_prefix: str = 'lt_space', best_epoch: int = -1): """ Fit and plot the t-SNE projection of the latent space streamlines. This should be called once after adding all the data to plot using @@ -174,42 +178,83 @@ def plot(self, title: str = "", figure_name_prefix: str = 'lt_space'): LOGGER.info("Fitting TSNE projection.") all_projected_streamlines = self.tsne.fit_transform(all_streamlines) - LOGGER.info("New figure for t-SNE visualisation.") - fig, ax = plt.subplots() - ax.set_title(title) - if self.fig_size is not None: - fig.set_figheight(self.fig_size[0]) - fig.set_figwidth(self.fig_size[1]) + if self.fig is None or self.axes is None: + self.fig, self.axes = self._init_figure() + + # Check if we have a new best epoch. + # If so, that means we have to update the plot on the left. + is_new_best = best_epoch > self.best_epoch + if is_new_best: + self.best_epoch = best_epoch + + self._clear_figures(is_new_best) for (bname, bdata) in self.bundles.items(): bindices = bundles_indices[bname] proj_data = all_projected_streamlines[bindices] - ax.scatter( - proj_data[:, 0], - proj_data[:, 1], - label=bname, - alpha=0.9, - edgecolors='black', - linewidths=0.5, - ) - - if len(self.bundles) > 1: - ax.legend() - - if self.save_dir is not None: - if self.prefix_numbering: - filename = '{}_{}.png'.format( - figure_name_prefix, self.current_plot_number) - else: - filename = '{}.png'.format(figure_name_prefix) - - filename = os.path.join(self.save_dir, filename) - fig.savefig(filename) + blabel = self.bundle_mapping.get( + bname, bname) if self.bundle_mapping else bname + + self._plot_bundle( + self.axes[1], proj_data[:, 0], proj_data[:, 1], blabel) + if is_new_best: + self._plot_bundle( + self.axes[0], proj_data[:, 0], proj_data[:, 1], blabel) + + self.axes[1].set_title("Epoch {}".format(epoch)) + self._set_legend(self.axes[1], len(self.bundles)) + if is_new_best: + self.axes[0].set_title("Best epoch ({})".format(self.best_epoch)) + self._set_legend(self.axes[0], len(self.bundles)) + + if self.prefix_numbering: + filename = '{}_{}.png'.format( + figure_name_prefix, self.current_plot_number) else: - plt.show() + filename = '{}.png'.format(figure_name_prefix) + + filename = os.path.join(self.save_dir, filename) + self.fig.savefig(filename) self.current_plot_number += 1 + def _set_legend(self, ax, nb_bundles): + if nb_bundles > 1: + ax.legend(fontsize=6, loc='center left', bbox_to_anchor=(1, 0.5)) + + def _plot_bundle(self, ax, dim1, dim2, blabel): + ax.scatter( + dim1, + dim2, + label=blabel, + alpha=0.9, + edgecolors='black', + linewidths=0.5, + ) + + def _clear_figures(self, clear_best: bool): + if clear_best: + self.axes[0].clear() + self.axes[1].clear() + + def _init_figure(self): + LOGGER.info("Init new figure for BundlesLatentSpaceVisualizer.") + fig, axes = plt.subplots(1, 2) + axes[0].set_title("Best epoch (?)") + axes[1].set_title("Last epoch (?)") + if self.fig_size is not None: + fig.set_figwidth(self.fig_size[0]) + fig.set_figheight(self.fig_size[1]) + + box_0 = axes[0].get_position() + axes[0].set_position( + [box_0.x0, box_0.y0, box_0.width * 0.8, box_0.height]) + box_1 = axes[1].get_position() + axes[1].set_position( + [box_1.x0, box_1.y0, box_1.width * 0.8, box_1.height]) + + return fig, axes + def _to_numpy(self, data): if isinstance(data, torch.Tensor): return data.detach().cpu().numpy() diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 88241ba7..e778724b 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -48,12 +48,27 @@ def prepare_arg_parser(): "This is expressed in number of epochs.") p.add_argument('--color_by', type=str, default=None, choices=['dps_bundle_index'], help="Name of the group in hdf5 to color by.") + p.add_argument('--bundles_mapping', type=str, default=None, + help="Path to a txt file mapping bundles to a new name.\n" + "Each line of that file should be: ") add_memory_args(p, add_lazy_options=True, add_rng=True) add_verbose_arg(p) return p +def parse_bundle_mapping(bundles_mapping_file: str = None): + if bundles_mapping_file is None: + return None + + with open(bundles_mapping_file, 'r') as f: + bundle_mapping = {} + for line in f: + bundle_name, bundle_number = line.strip().split() + bundle_mapping[int(bundle_number)] = bundle_name + return bundle_mapping + + def init_from_args(args, sub_loggers_level): torch.manual_seed(args.rng) # Set torch seed @@ -124,10 +139,12 @@ def init_from_args(args, sub_loggers_level): args.experiment_name, 'latent_space_plots') os.makedirs(save_dir, exist_ok=True) + bundle_mapping = parse_bundle_mapping(args.bundles_mapping) ls_viz = BundlesLatentSpaceVisualizer(save_dir, prefix_numbering=True, - max_subset_size=1000) - current_epoch = 0 + max_subset_size=1000, + bundle_mapping=bundle_mapping) + current_epoch = -1 def visualize_latent_space(encoding, related_data): """ @@ -147,14 +164,14 @@ def visualize_latent_space(encoding, related_data): if not trainer.model.context == 'training': return - changed_epoch = current_epoch != trainer.current_epoch + changed_epoch = current_epoch != trainer.current_epoch - 1 if not changed_epoch: ls_viz.add_data_to_plot(encoding, labels=related_data) elif changed_epoch \ - and trainer.current_epoch % viz_latent_space_freq == 0: - current_epoch = trainer.current_epoch - ls_viz.plot(title="Latent space at epoch {}".format( - current_epoch)) + and (trainer.current_epoch - 1) % viz_latent_space_freq == 0: + current_epoch = trainer.current_epoch - 1 + ls_viz.plot(current_epoch, + best_epoch=trainer.best_epoch_monitor.best_epoch) ls_viz.reset_data() ls_viz.add_data_to_plot(encoding, labels=related_data) model.register_hook_post_encoding(visualize_latent_space) From 327ca86e7c55b049736fcf55541eeef084b125ab Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Thu, 3 Oct 2024 13:10:14 -0400 Subject: [PATCH 07/28] Color matching between epochs and save the plot of the best epoch --- dwi_ml/viz/latent_streamlines.py | 149 ++++++++++++++++++++++++++++++- scripts_python/ae_train_model.py | 14 +-- 2 files changed, 156 insertions(+), 7 deletions(-) diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index 9fef1e63..e8001d21 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -5,11 +5,87 @@ import numpy as np import torch +import math import matplotlib.pyplot as plt +from matplotlib.colors import ListedColormap +from matplotlib.cm import hsv LOGGER = logging.getLogger(__name__) +class ColorManager(object): + def __init__(self, max_num_bundles: int = 40): + self.bundle_color_map = {} + self.color_map = self._init_colormap(max_num_bundles) + + def _init_colormap(self, number_of_distinct_colors): + """ + Create a colormap with a number of distinct colors. + Needed to have bigger color maps for more bundles. + + Code directly copied from: + https://stackoverflow.com/questions/42697933/colormap-with-maximum-distinguishable-colours + + """ + if number_of_distinct_colors == 0: + number_of_distinct_colors = 80 + + number_of_shades = 7 + number_of_distinct_colors_with_multiply_of_shades = int( + math.ceil(number_of_distinct_colors / number_of_shades) * number_of_shades) + + # Create an array with uniformly drawn floats taken from <0, 1) partition + linearly_distributed_nums = np.arange( + number_of_distinct_colors_with_multiply_of_shades) / number_of_distinct_colors_with_multiply_of_shades + + # We are going to reorganise monotonically growing numbers in such way that there will be single array with saw-like pattern + # but each saw tooth is slightly higher than the one before + # First divide linearly_distributed_nums into number_of_shades sub-arrays containing linearly distributed numbers + arr_by_shade_rows = linearly_distributed_nums.reshape( + number_of_shades, number_of_distinct_colors_with_multiply_of_shades // number_of_shades) + + # Transpose the above matrix (columns become rows) - as a result each row contains saw tooth with values slightly higher than row above + arr_by_shade_columns = arr_by_shade_rows.T + + # Keep number of saw teeth for later + number_of_partitions = arr_by_shade_columns.shape[0] + + # Flatten the above matrix - join each row into single array + nums_distributed_like_rising_saw = arr_by_shade_columns.reshape(-1) + + # HSV colour map is cyclic (https://matplotlib.org/tutorials/colors/colormaps.html#cyclic), we'll use this property + initial_cm = hsv(nums_distributed_like_rising_saw) + + lower_partitions_half = number_of_partitions // 2 + upper_partitions_half = number_of_partitions - lower_partitions_half + + # Modify lower half in such way that colours towards beginning of partition are darker + # First colours are affected more, colours closer to the middle are affected less + lower_half = lower_partitions_half * number_of_shades + for i in range(3): + initial_cm[0:lower_half, i] *= np.arange(0.2, 1, 0.8/lower_half) + + # Modify second half in such way that colours towards end of partition are less intense and brighter + # Colours closer to the middle are affected less, colours closer to the end are affected more + for i in range(3): + for j in range(upper_partitions_half): + modifier = np.ones( + number_of_shades) - initial_cm[lower_half + j * number_of_shades: lower_half + (j + 1) * number_of_shades, i] + modifier = j * modifier / upper_partitions_half + initial_cm[lower_half + j * number_of_shades: lower_half + + (j + 1) * number_of_shades, i] += modifier + + return ListedColormap(initial_cm) + + def get_color(self, label: str): + if label not in self.bundle_color_map: + self.bundle_color_map[label] = \ + self.color_map( + len(self.bundle_color_map)) + + return self.bundle_color_map[label] + + class BundlesLatentSpaceVisualizer(object): """ Utility class that wraps a t-SNE projection of the latent @@ -77,6 +153,7 @@ def __init__(self, self.tsne = TSNE(n_components=2, random_state=self.random_state) self.bundles = {} + self.bundle_color_manager = ColorManager() self.fig, self.axes = None, None self.best_epoch = -1 @@ -129,6 +206,68 @@ def add_bundle_to_plot(self, data: np.ndarray, label: str = '_'): self.bundles[label] = latent_space_streamlines + def check_and_register_best_epoch(self, epoch: int, best_epoch: int = -1): + """ + Finalize the epoch by plotting the t-SNE projection of the latent space streamlines. + This should be called once after adding all the data to plot using + "add_data_to_plot". + + Parameters + ---------- + epoch: int + Current epoch. + best_epoch: int + Best epoch. + """ + is_new_best = best_epoch > self.best_epoch + + if not is_new_best: + return + + assert best_epoch == epoch, "The best epoch should be the current epoch since it just changed." + + # If we have a new best epoch, we need to update the plot on the left. + self.best_epoch = best_epoch + + for (bname, bdata) in self.bundles.items(): + if bdata.shape[0] > self.max_subset_size: + self.bundles[bname] = self._resample_max_subset_size(bdata) + + nb_streamlines = sum(b.shape[0] for b in self.bundles.values()) + LOGGER.info( + "New best epoch with a total of {} streamlines".format(nb_streamlines)) + + # Build the indices for each bundle to recover the streamlines after + # the t-SNE projection. + bundles_indices = {} + current_start = 0 + for (bname, bdata) in self.bundles.items(): + bundles_indices[bname] = np.arange( + current_start, current_start + bdata.shape[0]) + current_start += bdata.shape[0] + + assert current_start == nb_streamlines + + all_streamlines = np.concatenate(list(self.bundles.values()), axis=0) + + LOGGER.info("Fitting TSNE projection.") + all_projected_streamlines = self.tsne.fit_transform(all_streamlines) + + if self.fig is None or self.axes is None: + self.fig, self.axes = self._init_figure() + + self.axes[0].clear() + self._plot_bundle( + self.axes[0], + all_projected_streamlines[:, 0], + all_projected_streamlines[:, 1], + 'Best epoch ({})'.format(self.best_epoch)) + + self._set_legend(self.axes[0], len(self.bundles)) + + # Clear data + self.reset_data() + def plot(self, epoch: int, figure_name_prefix: str = 'lt_space', best_epoch: int = -1): """ Fit and plot the t-SNE projection of the latent space streamlines. @@ -218,9 +357,14 @@ def plot(self, epoch: int, figure_name_prefix: str = 'lt_space', best_epoch: int self.current_plot_number += 1 - def _set_legend(self, ax, nb_bundles): + def _set_legend(self, ax, nb_bundles, order=True): if nb_bundles > 1: - ax.legend(fontsize=6, loc='center left', bbox_to_anchor=(1, 0.5)) + handles, labels = ax.get_legend_handles_labels() + if order: + labels, handles = zip( + *sorted(zip(labels, handles), key=lambda t: t[0])) + ax.legend(handles, labels, fontsize=6, + loc='center left', bbox_to_anchor=(1, 0.5)) def _plot_bundle(self, ax, dim1, dim2, blabel): ax.scatter( @@ -230,6 +374,7 @@ def _plot_bundle(self, ax, dim1, dim2, blabel): alpha=0.9, edgecolors='black', linewidths=0.5, + color=self.bundle_color_manager.get_color(blabel) ) def _clear_figures(self, clear_best: bool): diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index e778724b..0e66c2e1 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -167,12 +167,16 @@ def visualize_latent_space(encoding, related_data): changed_epoch = current_epoch != trainer.current_epoch - 1 if not changed_epoch: ls_viz.add_data_to_plot(encoding, labels=related_data) - elif changed_epoch \ - and (trainer.current_epoch - 1) % viz_latent_space_freq == 0: + elif changed_epoch: current_epoch = trainer.current_epoch - 1 - ls_viz.plot(current_epoch, - best_epoch=trainer.best_epoch_monitor.best_epoch) - ls_viz.reset_data() + if (trainer.current_epoch - 1) % viz_latent_space_freq == 0: + ls_viz.plot(current_epoch, + best_epoch=trainer.best_epoch_monitor.best_epoch) + ls_viz.reset_data() + else: + ls_viz.check_and_register_best_epoch( + current_epoch, trainer.best_epoch_monitor.best_epoch) + ls_viz.add_data_to_plot(encoding, labels=related_data) model.register_hook_post_encoding(visualize_latent_space) From 701737a9bd45142485cc7e121c61679e2627a717 Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Thu, 3 Oct 2024 14:39:27 -0400 Subject: [PATCH 08/28] Fix best epoch legend and colors --- dwi_ml/viz/latent_streamlines.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index e8001d21..e62ad67f 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -257,12 +257,16 @@ def check_and_register_best_epoch(self, epoch: int, best_epoch: int = -1): self.fig, self.axes = self._init_figure() self.axes[0].clear() - self._plot_bundle( - self.axes[0], - all_projected_streamlines[:, 0], - all_projected_streamlines[:, 1], - 'Best epoch ({})'.format(self.best_epoch)) + for (bname, bdata) in self.bundles.items(): + bindices = bundles_indices[bname] + proj_data = all_projected_streamlines[bindices] + blabel = self.bundle_mapping.get( + bname, bname) if self.bundle_mapping else bname + + self._plot_bundle( + self.axes[0], proj_data[:, 0], proj_data[:, 1], blabel) + self.axes[0].set_title("Best epoch ({})".format(self.best_epoch)) self._set_legend(self.axes[0], len(self.bundles)) # Clear data From 28f3f3d196f538b076443c1fcbe1cbafe9079609 Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Fri, 4 Oct 2024 11:56:28 -0400 Subject: [PATCH 09/28] Cleaup data_per_streamline retrieval in the HDF5 --- .../data/dataset/multi_subject_containers.py | 21 ++-- .../data/dataset/single_subject_containers.py | 28 ++--- dwi_ml/data/dataset/streamline_containers.py | 96 ++++++++------- dwi_ml/data/dataset/utils.py | 5 +- dwi_ml/data/hdf5/hdf5_creation.py | 7 +- dwi_ml/models/projects/ae_models.py | 4 +- dwi_ml/training/batch_loaders.py | 96 ++------------- dwi_ml/training/projects/ae_trainer.py | 0 dwi_ml/training/trainers.py | 43 ++----- scripts_python/ae_train_model.py | 22 ++-- scripts_python/ae_visualize_bundles.py | 111 ------------------ 11 files changed, 119 insertions(+), 314 deletions(-) create mode 100644 dwi_ml/training/projects/ae_trainer.py delete mode 100644 scripts_python/ae_visualize_bundles.py diff --git a/dwi_ml/data/dataset/multi_subject_containers.py b/dwi_ml/data/dataset/multi_subject_containers.py index d0fd9a6c..b2b5dd79 100644 --- a/dwi_ml/data/dataset/multi_subject_containers.py +++ b/dwi_ml/data/dataset/multi_subject_containers.py @@ -38,8 +38,7 @@ class MultisubjectSubset(Dataset): iterate over data and process batches. """ - def __init__(self, set_name: str, hdf5_file: str, lazy: bool, - cache_size: int = 0, related_data_key: str = None): + def __init__(self, set_name: str, hdf5_file: str, lazy: bool, cache_size: int = 0): self.set_name = set_name self.hdf5_file = hdf5_file @@ -80,7 +79,6 @@ def __init__(self, set_name: str, hdf5_file: str, lazy: bool, # This is only used in the lazy case. self.cache_size = cache_size self.volume_cache_manager = None - self.related_data_key = related_data_key def close_all_handles(self): if self.subjs_data_list.hdf_handle: @@ -280,7 +278,7 @@ def load(self, hdf_handle: h5py.File, subj_id=None): # calling this method. logger.debug(" Creating subject '{}'.".format(subj_id)) subj_data = self._init_subj_from_hdf( - hdf_handle, subj_id, ref_group_info, related_data_key=self.related_data_key) + hdf_handle, subj_id, ref_group_info) # Add subject to the list subj_idx = self.subjs_data_list.add_subject(subj_data) @@ -337,13 +335,13 @@ def _build_empty_data_list(self): else: return SubjectsDataList(self.hdf5_file, logger) - def _init_subj_from_hdf(self, hdf_handle, subject_id, ref_group_info, related_data_key=None): + def _init_subj_from_hdf(self, hdf_handle, subject_id, ref_group_info): if self.is_lazy: return LazySubjectData.init_single_subject_from_hdf( - subject_id, hdf_handle, ref_group_info, related_data_key=related_data_key) + subject_id, hdf_handle, ref_group_info) else: return SubjectData.init_single_subject_from_hdf( - subject_id, hdf_handle, ref_group_info, related_data_key=related_data_key) + subject_id, hdf_handle, ref_group_info) class MultiSubjectDataset: @@ -360,8 +358,7 @@ class MultiSubjectDataset: """ def __init__(self, hdf5_file: str, lazy: bool, - cache_size: int = 0, log_level=None, - related_data_key=None): + cache_size: int = 0, log_level=None): """ Params ------ @@ -397,11 +394,11 @@ def __init__(self, hdf5_file: str, lazy: bool, # Preparing the testing set and validation set # In non-lazy data, the cache_size is not used. self.training_set = MultisubjectSubset( - 'training', hdf5_file, self.is_lazy, cache_size, related_data_key=related_data_key) + 'training', hdf5_file, self.is_lazy, cache_size) self.validation_set = MultisubjectSubset( - 'validation', hdf5_file, self.is_lazy, cache_size, related_data_key=related_data_key) + 'validation', hdf5_file, self.is_lazy, cache_size) self.testing_set = MultisubjectSubset( - 'testing', hdf5_file, self.is_lazy, cache_size, related_data_key=related_data_key) + 'testing', hdf5_file, self.is_lazy, cache_size) @property def params_for_checkpoint(self) -> Dict[str, Any]: diff --git a/dwi_ml/data/dataset/single_subject_containers.py b/dwi_ml/data/dataset/single_subject_containers.py index fa3639c0..78fb500b 100644 --- a/dwi_ml/data/dataset/single_subject_containers.py +++ b/dwi_ml/data/dataset/single_subject_containers.py @@ -19,7 +19,7 @@ class SubjectDataAbstract(object): """ def __init__(self, volume_groups: List[str], nb_features: List[int], - streamline_groups: List[str], subject_id: str, related_data_key: str = None): + streamline_groups: List[str], subject_id: str): """ Parameters ---------- @@ -38,7 +38,6 @@ def __init__(self, volume_groups: List[str], nb_features: List[int], self.streamline_groups = streamline_groups self.subject_id = subject_id self.is_lazy = None - self.related_data_key = related_data_key @property def mri_data_list(self) -> List[MRIDataAbstract]: @@ -52,7 +51,7 @@ def sft_data_list(self): @classmethod def init_single_subject_from_hdf( - cls, subject_id: str, hdf_file, group_info=None, related_data_key=None): + cls, subject_id: str, hdf_file, group_info=None): """Returns an instance of this class, initiated by sending only the hdf handle. The child class's method will define how to load the data based on the child data management.""" @@ -70,8 +69,7 @@ class SubjectData(SubjectDataAbstract): def __init__(self, subject_id: str, volume_groups: List[str], nb_features: List[int], mri_data_list: List[MRIData] = None, streamline_groups: List[str] = None, - sft_data_list: List[SFTData] = None, - related_data_key: str = None): + sft_data_list: List[SFTData] = None): """ Additional params compared to super: ---- @@ -79,8 +77,7 @@ def __init__(self, subject_id: str, volume_groups: List[str], The loaded streamlines in a format copying the SFT. They contain ._data, ._offsets, ._lengths, ._lengths_mm. """ - super().__init__(volume_groups, nb_features, streamline_groups, - subject_id, related_data_key=related_data_key) + super().__init__(volume_groups, nb_features, streamline_groups, subject_id) self._mri_data_list = mri_data_list self._sft_data_list = sft_data_list self.is_lazy = False @@ -95,7 +92,7 @@ def sft_data_list(self): @classmethod def init_single_subject_from_hdf( - cls, subject_id: str, hdf_file, group_info=None, related_data_key=None): + cls, subject_id: str, hdf_file, group_info=None): """ Instantiating a single subject data: load info and use __init__ """ @@ -117,13 +114,12 @@ def init_single_subject_from_hdf( logger.debug(" Loading streamlines group '{}'" .format(group)) sft_data = SFTData.init_sft_data_from_hdf_info( - hdf_file[subject_id][group], related_data_key=related_data_key) + hdf_file[subject_id][group]) subject_sft_data_list.append(sft_data) subj_data = cls(subject_id, volume_groups, nb_features, subject_mri_data_list, - streamline_groups, subject_sft_data_list, - related_data_key=related_data_key) + streamline_groups, subject_sft_data_list) return subj_data @@ -138,7 +134,7 @@ class LazySubjectData(SubjectDataAbstract): def __init__(self, volume_groups: List[str], nb_features: List[int], streamline_groups: List[str], subject_id: str, - hdf_handle=None, related_data_key: str = None): + hdf_handle=None): """ Additional params compared to super: ------ @@ -146,13 +142,13 @@ def __init__(self, volume_groups: List[str], nb_features: List[int], Opened hdf file, if any. If None, data loading is deactivated. """ super().__init__(volume_groups, nb_features, streamline_groups, - subject_id, related_data_key=related_data_key) + subject_id) self.hdf_handle = hdf_handle self.is_lazy = True @classmethod def init_single_subject_from_hdf( - cls, subject_id: str, hdf_file, group_info=None, related_data_key=None): + cls, subject_id: str, hdf_file, group_info=None): """ Instantiating a single subject data: NOT LOADING info and use __init__ (so in short: this does basically nothing, the lazy data is kept @@ -174,7 +170,7 @@ def init_single_subject_from_hdf( logger.debug(' Lazy: not loading data.') return cls(volume_groups, nb_features, streamline_groups, subject_id, - hdf_handle=None, related_data_key=related_data_key) + hdf_handle=None) @property def mri_data_list(self) -> Union[List[LazyMRIData], None]: @@ -208,7 +204,7 @@ def sft_data_list(self) -> Union[List[LazySFTData], None]: for group in self.streamline_groups: hdf_group = self.hdf_handle[self.subject_id][group] sft_data_list.append( - LazySFTData.init_sft_data_from_hdf_info(hdf_group, related_data_key=self.related_data_key)) + LazySFTData.init_sft_data_from_hdf_info(hdf_group)) return sft_data_list else: diff --git a/dwi_ml/data/dataset/streamline_containers.py b/dwi_ml/data/dataset/streamline_containers.py index 599f8d34..c8598640 100644 --- a/dwi_ml/data/dataset/streamline_containers.py +++ b/dwi_ml/data/dataset/streamline_containers.py @@ -8,6 +8,7 @@ import h5py from nibabel.streamlines import ArraySequence import numpy as np +from collections import defaultdict def _load_space_attributes_from_hdf(hdf_group: h5py.Group): @@ -42,8 +43,9 @@ def _load_all_streamlines_from_hdf(hdf_group: h5py.Group): streamlines._data = np.array(hdf_group['data']) streamlines._offsets = np.array(hdf_group['offsets']) streamlines._lengths = np.array(hdf_group['lengths']) + dps_dict = _load_data_per_streamline(hdf_group) - return streamlines + return streamlines, dps_dict def _load_connectivity_info(hdf_group: h5py.Group): @@ -67,19 +69,24 @@ def _load_connectivity_info(hdf_group: h5py.Group): return contains_connectivity, connectivity_nb_blocs, connectivity_labels -def _load_related_data(hdf_group, related_data_key: str) -> Union[np.ndarray, None]: - related_data = None - # Load related data key if specified - if related_data_key is not None: +def _load_data_per_streamline(hdf_group, dps_key: str = None) -> Union[np.ndarray, None]: + dps_dict = defaultdict(list) + # Load only related data key if specified + dps_group = hdf_group['data_per_streamline'] + if dps_key is not None: # Make sure the related data key is in the hdf5 group - if not (related_data_key in hdf_group.keys()): - raise KeyError("The key '{}' is not in the hdf5 group." - .format(related_data_key)) + if not (dps_key in dps_group.keys()): + raise KeyError("The key '{}' is not in the hdf5 group. Keys found: {}" + .format(dps_key, dps_group.keys())) # Load the related data - related_data = np.array(hdf_group[related_data_key]).squeeze(1) + dps_dict[dps_key] = dps_group['data_per_streamline'][dps_key][:] + # Otherwise, load every dps. + else: + for dps_key in dps_group.keys(): + dps_dict[dps_key] = dps_group[dps_key][:] - return related_data + return dps_dict class _LazyStreamlinesGetter(object): @@ -96,12 +103,17 @@ def _get_one_streamline(self, idx: int): def get_array_sequence(self, item=None): if item is None: - streamlines = _load_all_streamlines_from_hdf(self.hdf_group) + streamlines, dps = _load_all_streamlines_from_hdf(self.hdf_group) else: streamlines = ArraySequence() + dps_dict = defaultdict(list) if isinstance(item, int): - streamlines.append(self._get_one_streamline(item)) + data = self._get_one_streamline(item) + streamlines.append(data) + + for dps_key in self.hdf_group['dps_keys']: + dps_dict[dps_key].append(self.hdf_group[dps_key][item]) elif isinstance(item, list) or isinstance(item, np.ndarray): # Getting a list of value from a hdf5: slow. Uses fancy indexing. @@ -111,8 +123,12 @@ def get_array_sequence(self, item=None): # Good also load the whole data and access the indexes after. # toDo Test speed for the three options. for i in item: - streamlines.append(self._get_one_streamline(i), - cache_build=True) + data = self._get_one_streamline(i) + streamlines.append(data, cache_build=True) + + for dps_key in self.hdf_group['dps_keys']: + dps_dict[dps_key].append(self.hdf_group[dps_key][item]) + streamlines.finalize_append() elif isinstance(item, slice): @@ -121,13 +137,17 @@ def get_array_sequence(self, item=None): for offset, length in zip(offsets, lengths): streamline = self.hdf_group['data'][offset:offset + length] streamlines.append(streamline, cache_build=True) + + for dps_key in self.hdf_group['dps_keys']: + dps_dict[dps_key].append( + self.hdf_group[dps_key][offset:offset + length]) streamlines.finalize_append() else: raise ValueError('Item should be either a int, list, ' 'np.ndarray or slice but we received {}' .format(type(item))) - return streamlines + return streamlines, dps @property def lengths(self): @@ -179,8 +199,7 @@ class SFTDataAbstract(object): def __init__(self, space_attributes: Tuple, space: Space, origin: Origin, contains_connectivity: bool, connectivity_nb_blocs: List = None, - connectivity_labels: np.ndarray = None, - related_data: np.ndarray = None): + connectivity_labels: np.ndarray = None): """ The lazy/non-lazy versions will have more parameters, such as the streamlines, the connectivity_matrix. In the case of the lazy version, @@ -211,7 +230,6 @@ def __init__(self, space_attributes: Tuple, space: Space, origin: Origin, self.contains_connectivity = contains_connectivity self.connectivity_nb_blocs = connectivity_nb_blocs self.connectivity_labels = connectivity_labels - self.related_data = related_data def __len__(self): raise NotImplementedError @@ -263,14 +281,6 @@ def _get_streamlines_as_list(self, streamline_ids) -> List[ArraySequence]: the hdf5.""" raise NotImplementedError - def get_related_data(self, - streamline_ids: Union[List[int], int, slice, None] = None): - """Returns the data related to the streamlines.""" - if self.related_data is None: - return None - else: - return self.related_data[streamline_ids] - def as_sft(self, streamline_ids: Union[List[int], int, slice, None] = None) \ -> StatefulTractogram: @@ -282,10 +292,11 @@ def as_sft(self, streamline_ids: Union[List[int], int, slice, None] List of chosen ids. If None, use all streamlines. """ - streamlines = self._get_streamlines_as_list(streamline_ids) + streamlines, dps = self._get_streamlines_as_list(streamline_ids) sft = StatefulTractogram(streamlines, self.space_attributes, - self.space, self.origin) + self.space, self.origin, + data_per_streamline=dps) return sft @@ -293,6 +304,7 @@ def as_sft(self, class SFTData(SFTDataAbstract): def __init__(self, streamlines: ArraySequence, lengths_mm: List, connectivity_matrix: np.ndarray, + data_per_streamline: np.ndarray = None, **kwargs): """ streamlines: ArraySequence or LazyStreamlinesGetter @@ -305,6 +317,7 @@ def __init__(self, streamlines: ArraySequence, self._lengths_mm = lengths_mm self._connectivity_matrix = connectivity_matrix self.is_lazy = False + self.data_per_streamline = data_per_streamline def __len__(self): return len(self.streamlines) @@ -327,12 +340,12 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None): return self._connectivity_matrix @classmethod - def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group, related_data_key: str = None): + def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group, dps_key: str = None): """ Creating class instance from the hdf in cases where data is not loaded yet. Non-lazy = loading the data here. """ - streamlines = _load_all_streamlines_from_hdf(hdf_group) + streamlines, dps_dict = _load_all_streamlines_from_hdf(hdf_group) # Adding non-hidden parameters for nicer later access lengths_mm = hdf_group['euclidean_lengths'] @@ -344,8 +357,6 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group, related_data_key: st else: connectivity_matrix = None - related_data = _load_related_data(hdf_group, related_data_key) - space_attributes, space, origin = _load_space_attributes_from_hdf( hdf_group) @@ -358,13 +369,17 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group, related_data_key: st contains_connectivity=contains_connectivity, connectivity_nb_blocs=connectivity_nb_blocs, connectivity_labels=connectivity_labels, - related_data=related_data) + data_per_streamline=dps_dict) def _get_streamlines_as_list(self, streamline_ids): if streamline_ids is not None: - return self.streamlines.__getitem__(streamline_ids) + dps_indexed = {} + for key, value in self.data_per_streamline.items(): + dps_indexed[key] = value[streamline_ids] + + return self.streamlines.__getitem__(streamline_ids), dps_indexed else: - return self.streamlines + return self.streamlines, self.data_per_streamline class LazySFTData(SFTDataAbstract): @@ -397,7 +412,7 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None): return self.streamlines_getter.connectivity_matrix(indxyz) @classmethod - def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group, related_data_key: str = None): + def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): space_attributes, space, origin = _load_space_attributes_from_hdf( hdf_group) @@ -406,17 +421,14 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group, related_data_key: st streamlines = _LazyStreamlinesGetter(hdf_group) - related_data = _load_related_data(hdf_group, related_data_key) - return cls(streamlines_getter=streamlines, space_attributes=space_attributes, space=space, origin=origin, contains_connectivity=contains_connectivity, connectivity_nb_blocs=connectivity_nb_blocs, - connectivity_labels=connectivity_labels, - related_data=related_data) + connectivity_labels=connectivity_labels) def _get_streamlines_as_list(self, streamline_ids): - streamlines = self.streamlines_getter.get_array_sequence( + streamlines, dps = self.streamlines_getter.get_array_sequence( streamline_ids) - return streamlines + return streamlines, dps diff --git a/dwi_ml/data/dataset/utils.py b/dwi_ml/data/dataset/utils.py index fdea92bb..d8a7b0e9 100644 --- a/dwi_ml/data/dataset/utils.py +++ b/dwi_ml/data/dataset/utils.py @@ -7,8 +7,7 @@ def prepare_multisubjectdataset(args, load_training=True, load_validation=True, load_testing=True, - log_level=logging.root.level, - related_data_key=None): + log_level=logging.root.level): """ Instantiates a MultiSubjectDataset AND loads data. @@ -20,7 +19,7 @@ def prepare_multisubjectdataset(args, load_training=True, load_validation=True, with Timer("\nPreparing datasets", newline=True, color='blue'): dataset = MultiSubjectDataset( args.hdf5_file, lazy=args.lazy, cache_size=args.cache_size, - log_level=log_level, related_data_key=related_data_key) + log_level=log_level) dataset.load_data(load_training, load_validation, load_testing) logging.info("Number of subjects loaded: \n" diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index 06fb25a0..e02a099e 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -630,6 +630,8 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, if len(sft.data_per_point) > 0: logging.debug('sft contained data_per_point. Data not kept.') + dps_group = streamlines_group.create_group('data_per_streamline') + for dps_key in self.dps_keys: if dps_key not in sft.data_per_streamline: raise ValueError( @@ -638,8 +640,9 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, logging.debug( " Include dps \"{}\" in the HDF5.".format(dps_key)) - streamlines_group.create_dataset('dps_' + dps_key, - data=sft.data_per_streamline[dps_key]) + + dps_group.create_dataset( + dps_key, data=sft.data_per_streamline[dps_key]) # Accessing private Dipy values, but necessary. # We need to deconstruct the streamlines into arrays with diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 8608f339..ecbce674 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -106,7 +106,7 @@ def pre_pad(m): def forward(self, input_streamlines: List[torch.tensor], - related_data=None + data_per_streamline: dict = None ): """Run the model on a batch of sequences. @@ -127,7 +127,7 @@ def forward(self, encoded = self.encode(input_streamlines) for hook in self.post_encoding_hooks: - hook(encoded, related_data) + hook(encoded, data_per_streamline) x = self.decode(encoded) return x diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 275e5273..f7adc598 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -58,6 +58,16 @@ logger = logging.getLogger('batch_loader_logger') +def _dps_to_tensors(dps: dict, device='cpu'): + """ + Convert a list of DPS to a list of tensors. + """ + dps_tensors = {} + for key, value in dps.items(): + dps_tensors[key] = torch.tensor(value, device=device) + return dps_tensors + + class DWIMLStreamlinesBatchLoader: def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract, streamline_group_name: str, rng: int, @@ -323,91 +333,9 @@ def load_batch_streamlines( batch_streamlines.extend(sft.streamlines) batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines] + data_per_streamline = _dps_to_tensors(sft.data_per_streamline) - return batch_streamlines, final_s_ids_per_subj - - def load_batch_streamlines_and_related( - self, streamline_ids_per_subj: List[Tuple[int, list]]): - """ - Fetches the chosen streamlines for all subjects in batch. - Pocesses data augmentation. - - Torch uses this function to process the data with the dataloader - parallel workers (on cpu). To be used as collate_fn. - - Parameters - ---------- - streamline_ids_per_subj: List[Tuple[int, list]] - The list of streamline ids for each subject (relative ids inside - each subject's tractogram) for this batch. - - Returns - ------- - (batch_streamlines, final_s_ids_per_subj) - Where - - batch_streamlines: list[torch.tensor] - The new streamlines after data augmentation, IN VOXEL SPACE, - CORNER. - - final_s_ids_per_subj: Dict[int, slice] - The new streamline ids per subj in this augmented batch. - """ - if self.context is None: - raise ValueError("Context must be set prior to using the batch " - "loader.") - - # The batch's streamline ids will change throughout processing because - # of data augmentation, so we need to do it subject by subject to - # keep track of the streamline ids. These final ids will correspond to - # the loaded, processed streamlines, not to the ids in the hdf5 file. - final_s_ids_per_subj = defaultdict(slice) - batch_streamlines = [] - streamlines_related_data = [] - for subj, s_ids in streamline_ids_per_subj: - logger.debug( - " Data loader: Processing data preparation for " - "subj {} (preparing {} streamlines)".format(subj, len(s_ids))) - - # No cache for the sft data. Accessing it directly. - # Note: If this is used through the dataloader, multiprocessing - # is used. Each process will open a handle. - subj_data = \ - self.context_subset.subjs_data_list.get_subj_with_handle(subj) - subj_sft_data = subj_data.sft_data_list[self.streamline_group_idx] - - # Get streamlines as sft - logger.debug(" Loading sampled streamlines...") - sft = subj_sft_data.as_sft(s_ids) - - # TODO: modify this list consequently to the data augmentations. - # Currently, if the data augmentation adds/removes streamlines, - # the related data won't match the streamlines list anymore. - related_data = subj_sft_data.get_related_data( - s_ids) # Can return None - sft = self._data_augmentation_sft(sft) - - # Remember the indices of this subject's (augmented) streamlines - ids_start = len(batch_streamlines) - ids_end = ids_start + len(sft) - final_s_ids_per_subj[subj] = slice(ids_start, ids_end) - - # Add all (augmented) streamlines to the batch - # What we want is the streamline coordinates, to eventually get - # the underlying input(s). Sending to vox and to corner to - # be able to use our trilinear interpolation - sft.to_vox() - sft.to_corner() - batch_streamlines.extend(sft.streamlines) - - if related_data is not None: - streamlines_related_data.extend(related_data) - - batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines] - - if len(streamlines_related_data) > 0: - assert len(streamlines_related_data) == len(batch_streamlines), \ - "Related data should have the same length as the streamlines." - - return batch_streamlines, final_s_ids_per_subj, streamlines_related_data + return batch_streamlines, final_s_ids_per_subj, data_per_streamline def load_batch_connectivity_matrices( self, streamline_ids_per_subj: Dict[int, slice]): diff --git a/dwi_ml/training/projects/ae_trainer.py b/dwi_ml/training/projects/ae_trainer.py new file mode 100644 index 00000000..e69de29b diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index f89aeb20..0f357234 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -64,8 +64,7 @@ def __init__(self, nb_cpu_processes: int = 0, use_gpu: bool = False, clip_grad: float = None, comet_workspace: str = None, comet_project: str = None, - from_checkpoint: bool = False, log_level=logging.root.level, - related_data_retrieval: bool = False): + from_checkpoint: bool = False, log_level=logging.root.level): """ Parameters ---------- @@ -230,15 +229,12 @@ def __init__(self, # dataloader output is on GPU, ready to be fed to the model. # Otherwise, dataloader output is kept on CPU, and the main thread # sends volumes and coords on GPU for interpolation. - self.related_data_retrieval = related_data_retrieval - self.collate_fn = self.batch_loader.load_batch_streamlines_and_related \ - if self.related_data_retrieval else self.batch_loader.load_batch_streamlines logger.debug("- Instantiating dataloaders...") self.train_dataloader = DataLoader( dataset=self.batch_sampler.dataset.training_set, batch_sampler=self.batch_sampler, num_workers=self.nb_cpu_processes, - collate_fn=self.collate_fn, + collate_fn=self.batch_loader.load_batch_streamlines, pin_memory=self.use_gpu) self.valid_dataloader = None if self.use_validation: @@ -246,7 +242,7 @@ def __init__(self, dataset=self.batch_sampler.dataset.validation_set, batch_sampler=self.batch_sampler, num_workers=self.nb_cpu_processes, - collate_fn=self.collate_fn, + collate_fn=self.batch_loader.load_batch_streamlines, pin_memory=self.use_gpu) # ---------------------- @@ -824,17 +820,10 @@ def train_one_epoch(self, epoch): train_iterator = enumerate(pbar) for batch_id, data in train_iterator: - if self.related_data_retrieval: - related_data = data[2] - data = data[:2] - else: - related_data = None - # Enable gradients for backpropagation. Uses torch's module # train(), which "turns on" the training mode. with grad_context(): - mean_loss = self.train_one_batch( - data, related_data=related_data) + mean_loss = self.train_one_batch(data) unclipped_grad_norm, grad_norm = self.back_propagation( mean_loss) @@ -897,16 +886,10 @@ def validate_one_epoch(self, epoch): valid_iterator = enumerate(pbar) for batch_id, data in valid_iterator: - if self.related_data_retrieval: - related_data = data[2] - data = data[:2] - else: - related_data = None - # Validate this batch: forward propagation + loss with torch.no_grad(): self.validate_one_batch( - data, epoch, related_data=related_data) + data, epoch) # Break if maximum number of epochs has been reached if batch_id == self.nb_batches_valid - 1: @@ -932,25 +915,23 @@ def validate_one_epoch(self, epoch): monitor.end_epoch() self._update_comet_after_epoch('validation', epoch) - def train_one_batch(self, data, **model_kwargs): + def train_one_batch(self, data): """ Computes the loss for the current batch and updates monitors. Returns the loss to be used for backpropagation. """ # Encapsulated for easier management of child classes. - mean_local_loss, n = self.run_one_batch( - data, **model_kwargs) + mean_local_loss, n = self.run_one_batch(data) # mean loss is a Tensor of a single value. item() converts to float self.train_loss_monitor.update(mean_local_loss.cpu().item(), weight=n) return mean_local_loss - def validate_one_batch(self, data, epoch, **model_kwargs): + def validate_one_batch(self, data, epoch): """ Computes the loss(es) for the current batch and updates monitors. """ - mean_local_loss, n = self.run_one_batch( - data, **model_kwargs) + mean_local_loss, n = self.run_one_batch(data) self.valid_local_loss_monitor.update(mean_local_loss.cpu().item(), weight=n) @@ -1019,7 +1000,7 @@ def _save_best_model(self): json_file.write(json.dumps(best_losses, indent=4, separators=(',', ': '))) - def run_one_batch(self, data, **model_kwargs): + def run_one_batch(self, data): """ Runs a batch of data through the model (calling its forward method) and returns the mean loss. @@ -1041,7 +1022,7 @@ def run_one_batch(self, data, **model_kwargs): """ # Data interpolation has not been done yet. GPU computations are done # here in the main thread. - targets, ids_per_subj = data + targets, ids_per_subj, data_per_streamline = data # Dataloader always works on CPU. Sending to right device. # (model is already moved). @@ -1063,7 +1044,7 @@ def run_one_batch(self, data, **model_kwargs): # but ok, shouldn't be too heavy. Easier to deal with multiple # projects' requirements by sending whole streamlines rather # than only directions. - model_outputs = self.model(streamlines_f, **model_kwargs) + model_outputs = self.model(streamlines_f, data_per_streamline) del streamlines_f logger.debug('*** Computing loss') diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 0e66c2e1..f0ceb499 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -77,8 +77,7 @@ def init_from_args(args, sub_loggers_level): # Prepare the dataset dataset = prepare_multisubjectdataset(args, load_testing=False, - log_level=sub_loggers_level, - related_data_key=color_by) + log_level=sub_loggers_level) # Preparing the model # (Direction getter) @@ -128,8 +127,7 @@ def init_from_args(args, sub_loggers_level): from_checkpoint=False, clip_grad=args.clip_grad, # MEMORY nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, - log_level=sub_loggers_level, - related_data_retrieval=color_by is not None) + log_level=sub_loggers_level) logging.info("Trainer params : " + format_dict_to_str(trainer.params_for_checkpoint)) @@ -146,7 +144,7 @@ def init_from_args(args, sub_loggers_level): bundle_mapping=bundle_mapping) current_epoch = -1 - def visualize_latent_space(encoding, related_data): + def visualize_latent_space(encoding, data_per_streamline): """ This is not a clean way to do it. This would require changes in the trainer to allow for a callback system where we could @@ -164,9 +162,11 @@ def visualize_latent_space(encoding, related_data): if not trainer.model.context == 'training': return + bundle_index = data_per_streamline['bundle_index'].squeeze(1) + changed_epoch = current_epoch != trainer.current_epoch - 1 if not changed_epoch: - ls_viz.add_data_to_plot(encoding, labels=related_data) + ls_viz.add_data_to_plot(encoding, labels=bundle_index) elif changed_epoch: current_epoch = trainer.current_epoch - 1 if (trainer.current_epoch - 1) % viz_latent_space_freq == 0: @@ -177,7 +177,7 @@ def visualize_latent_space(encoding, related_data): ls_viz.check_and_register_best_epoch( current_epoch, trainer.best_epoch_monitor.best_epoch) - ls_viz.add_data_to_plot(encoding, labels=related_data) + ls_viz.add_data_to_plot(encoding, labels=bundle_index) model.register_hook_post_encoding(visualize_latent_space) return trainer @@ -199,10 +199,10 @@ def main(): assert_outputs_exist(p, args, args.experiments_path) # Verify if a checkpoint has been saved. Else create an experiment. - # if os.path.exists(os.path.join(args.experiments_path, args.experiment_name, - # "checkpoint")): - # raise FileExistsError("This experiment already exists. Delete or use " - # "script ae_resume_training_from_checkpoint.py.") + if os.path.exists(os.path.join(args.experiments_path, args.experiment_name, + "checkpoint")): + raise FileExistsError("This experiment already exists. Delete or use " + "script ae_resume_training_from_checkpoint.py.") trainer = init_from_args(args, sub_loggers_level) diff --git a/scripts_python/ae_visualize_bundles.py b/scripts_python/ae_visualize_bundles.py deleted file mode 100644 index 68863d01..00000000 --- a/scripts_python/ae_visualize_bundles.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -import argparse -import logging -import pathlib -import torch -import numpy as np -from glob import glob -from os.path import expanduser -from dipy.tracking.streamline import set_number_of_points - -from scilpy.io.utils import (add_overwrite_arg, - assert_outputs_exist, - add_reference_arg, - add_verbose_arg) -from scilpy.io.streamlines import load_tractogram_with_reference -from dwi_ml.io_utils import (add_arg_existing_experiment_path, - add_memory_args) -from dwi_ml.models.projects.ae_models import ModelAE -from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer - - -def _build_arg_parser(): - p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, - description=__doc__) - # Mandatory - # Should only be False for debugging tests. - add_arg_existing_experiment_path(p) - # Add_args_testing_subj_hdf5(p) - - p.add_argument('in_bundles', - help="The 'glob' path to several bundles identified " - "by their file name." - "e.g. FiberCupGroundTruth_filtered_bundle_0.tck") - - # Options - p.add_argument('--batch_size', type=int) - add_memory_args(p) - - p.add_argument('--pick_at_random', action='store_true') - add_reference_arg(p) - add_overwrite_arg(p) - add_verbose_arg(p) - return p - - -def load_bundles(p, args, files_list: list): - bundles = [] - for bundle_file in files_list: - bundle_sft = load_tractogram_with_reference(p, args, bundle_file) - bundle_sft.to_vox() - bundle_sft.to_corner() - bundles.append(bundle_sft) - return bundles - - -def main(): - p = _build_arg_parser() - args = p.parse_args() - - # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, - # but we will set trainer to user-defined level. - sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' - - # General logging (ex, scilpy: Warning) - logging.getLogger().setLevel(level=logging.WARNING) - - # Verify output names - # Check experiment_path exists and best_model folder exists - # Assert_inputs_exist(p, args.hdf5_file) - assert_outputs_exist(p, args, []) - - # Device - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - # 1. Load model - logging.debug("Loading model.") - model = ModelAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_loggers_level) - model = model.to(device) - - expanded = expanduser(args.in_bundles) - bundles_files = glob(expanded) - if isinstance(bundles_files, str): - bundles_files = [bundles_files] - - bundles_label = [pathlib.Path(l).stem for l in bundles_files] - bundles_sft = load_bundles(p, args, bundles_files) - - logging.info("Running model to compute loss") - - ls_viz = BundlesLatentSpaceVisualizer( - save_path="/home/local/USHERBROOKE/levj1404/Documents/dwi_ml/data/out.png" - ) - - with torch.no_grad(): - for i, bundle_sft in enumerate(bundles_sft): - - # Resample - streamlines = torch.as_tensor(np.asarray(set_number_of_points(bundle_sft.streamlines, 256)), - dtype=torch.float32, device=device) - - latent_streamlines = model.encode( - streamlines).cpu().numpy() # output of (N, 32) - ls_viz.add_data_to_plot(latent_streamlines, label=bundles_label[i]) - - ls_viz.plot() - - -if __name__ == '__main__': - main() From a228c867058579e393ee855494bd39c0a9208ca3 Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Fri, 4 Oct 2024 14:35:19 -0400 Subject: [PATCH 10/28] Cleanup: part 1 --- dwi_ml/training/projects/ae_trainer.py | 96 ++++++++++++++++++++++++++ dwi_ml/viz/latent_streamlines.py | 28 ++++++-- scripts_python/ae_train_model.py | 76 +++----------------- 3 files changed, 127 insertions(+), 73 deletions(-) diff --git a/dwi_ml/training/projects/ae_trainer.py b/dwi_ml/training/projects/ae_trainer.py index e69de29b..abbc1ee1 100644 --- a/dwi_ml/training/projects/ae_trainer.py +++ b/dwi_ml/training/projects/ae_trainer.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +import logging +import os +from typing import Union, List + +from dwi_ml.models.main_models import MainModelAbstract +from dwi_ml.training.batch_loaders import DWIMLStreamlinesBatchLoader +from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler +from dwi_ml.training.trainers import DWIMLAbstractTrainer +from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer + + +def parse_bundle_mapping(bundles_mapping_file: str = None): + if bundles_mapping_file is None: + return None + + with open(bundles_mapping_file, 'r') as f: + bundle_mapping = {} + for line in f: + bundle_name, bundle_number = line.strip().split() + bundle_mapping[int(bundle_number)] = bundle_name + return bundle_mapping + + +class TrainerWithBundleDPS(DWIMLAbstractTrainer): + + def __init__(self, + model: MainModelAbstract, experiments_path: str, + experiment_name: str, batch_sampler: DWIMLBatchIDSampler, + batch_loader: DWIMLStreamlinesBatchLoader, + learning_rates: Union[List, float] = None, + weight_decay: float = 0.01, + optimizer: str = 'Adam', max_epochs: int = 10, + max_batches_per_epoch_training: int = 1000, + max_batches_per_epoch_validation: Union[int, None] = 1000, + patience: int = None, patience_delta: float = 1e-6, + nb_cpu_processes: int = 0, use_gpu: bool = False, + clip_grad: float = None, + comet_workspace: str = None, comet_project: str = None, + from_checkpoint: bool = False, log_level=logging.root.level, + ls_viz_freq: int = None, color_by: str = None, + bundles_mapping_file: str = None, max_viz_subset_size: int = 1000): + + super().__init__(model, experiments_path, experiment_name, batch_sampler, batch_loader, + learning_rates, weight_decay, optimizer, max_epochs, max_batches_per_epoch_training, + max_batches_per_epoch_validation, patience, patience_delta, nb_cpu_processes, use_gpu, + clip_grad, comet_workspace, comet_project, from_checkpoint, log_level) + + self.ls_viz_freq = ls_viz_freq + self.color_by = color_by + self.visualize = ls_viz_freq is not None + if self.visualize: + # Setup to visualize latent space + save_dir = os.path.join( + experiments_path, experiment_name, 'latent_space_plots') + os.makedirs(save_dir, exist_ok=True) + + bundle_mapping = parse_bundle_mapping(bundles_mapping_file) + self.ls_viz = BundlesLatentSpaceVisualizer(save_dir, + prefix_numbering=True, + max_subset_size=max_viz_subset_size, + bundle_mapping=bundle_mapping) + + # Register what to do post encoding. + def visualize_latent_space(encoding, data_per_streamline): + # Only execute the following if we are in training + if not self.model.context == 'training': + return + + if self.color_by is None: + bundle_index = None + else: + bundle_index = \ + data_per_streamline[self.color_by].squeeze(1) + + self.ls_viz.add_data_to_plot(encoding, labels=bundle_index) + model.register_hook_post_encoding(visualize_latent_space) + + def train_one_epoch(self, epoch): + super().train_one_epoch(epoch) + + # Do things post epoch + if self.visualize: + current_epoch = self.current_epoch + best_epoch = self.best_epoch_monitor.best_epoch \ + if self.best_epoch_monitor.best_epoch is not None else current_epoch + if current_epoch % self.ls_viz_freq == 0: + self.ls_viz.plot(current_epoch, best_epoch=best_epoch) + self.ls_viz.reset_data() + else: + + # TODO: This is problematic, the best epoch is updated actually after + # validation (which is after doing the training epoch) + + self.ls_viz.check_and_register_best_epoch( + current_epoch, best_epoch=best_epoch) diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index e62ad67f..df597a5b 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -12,6 +12,8 @@ LOGGER = logging.getLogger(__name__) +DEFAULT_BUNDLE_NAME = 'UNK' + class ColorManager(object): def __init__(self, max_num_bundles: int = 40): @@ -180,13 +182,25 @@ def add_data_to_plot(self, data: np.ndarray, labels: List[str]): Labels for each streamline. """ latent_space_streamlines = self._to_numpy(data) - - all_labels = np.unique(labels) - for label in all_labels: - label_indices = labels == label - label_data = latent_space_streamlines[label_indices] - label_data = self._resample_max_subset_size(label_data) - self.bundles[label] = label_data + if labels is None: + self.bundles[DEFAULT_BUNDLE_NAME] = latent_space_streamlines + else: + all_labels = np.unique(labels) + _remaining_indices = np.arange(len(labels)) + for label in all_labels: + label_indices = labels[_remaining_indices] == label + label_data = latent_space_streamlines[_remaining_indices][label_indices] + label_data = self._resample_max_subset_size(label_data) + self.bundles[label] = label_data + + _remaining_indices = _remaining_indices[~label_indices] + + if len(_remaining_indices) > 0: + LOGGER.warning( + "Some streamlines were not considered in the bundles," + "some labels are missing.\n" + "Added them to the {} bundle.".format(DEFAULT_BUNDLE_NAME)) + self.bundles[DEFAULT_BUNDLE_NAME] = latent_space_streamlines[all_indices] def add_bundle_to_plot(self, data: np.ndarray, label: str = '_'): """ diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index f0ceb499..202ffe3f 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -22,7 +22,7 @@ from dwi_ml.experiment_utils.timer import Timer from dwi_ml.io_utils import add_memory_args from dwi_ml.models.projects.ae_models import ModelAE -from dwi_ml.training.trainers import DWIMLAbstractTrainer +from dwi_ml.training.projects.ae_trainer import TrainerWithBundleDPS from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, prepare_batch_sampler) from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader) @@ -46,8 +46,10 @@ def prepare_arg_parser(): p.add_argument('--viz_latent_space_freq', type=int, default=None, help="Frequency at which to visualize latent space.\n" "This is expressed in number of epochs.") - p.add_argument('--color_by', type=str, default=None, choices=['dps_bundle_index'], - help="Name of the group in hdf5 to color by.") + p.add_argument('--color_by', type=str, default=None, + help="Name of the group in hdf5 to color by." + "In the HDF5, the coloring group should be under" + "data_per_streamline/") p.add_argument('--bundles_mapping', type=str, default=None, help="Path to a txt file mapping bundles to a new name.\n" "Each line of that file should be: ") @@ -57,18 +59,6 @@ def prepare_arg_parser(): return p -def parse_bundle_mapping(bundles_mapping_file: str = None): - if bundles_mapping_file is None: - return None - - with open(bundles_mapping_file, 'r') as f: - bundle_mapping = {} - for line in f: - bundle_name, bundle_number = line.strip().split() - bundle_mapping[int(bundle_number)] = bundle_name - return bundle_mapping - - def init_from_args(args, sub_loggers_level): torch.manual_seed(args.rng) # Set torch seed @@ -111,7 +101,7 @@ def init_from_args(args, sub_loggers_level): # Instantiate trainer with Timer("\n\nPreparing trainer", newline=True, color='red'): lr = format_lr(args.learning_rate) - trainer = DWIMLAbstractTrainer( + trainer = TrainerWithBundleDPS( model=model, experiments_path=args.experiments_path, experiment_name=args.experiment_name, batch_sampler=batch_sampler, batch_loader=batch_loader, @@ -127,59 +117,13 @@ def init_from_args(args, sub_loggers_level): from_checkpoint=False, clip_grad=args.clip_grad, # MEMORY nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, - log_level=sub_loggers_level) + log_level=sub_loggers_level, + ls_viz_freq=viz_latent_space_freq, color_by=color_by, + bundles_mapping_file=args.bundles_mapping, + max_viz_subset_size=1000) logging.info("Trainer params : " + format_dict_to_str(trainer.params_for_checkpoint)) - if viz_latent_space_freq is not None: - # Setup to visualize latent space - save_dir = os.path.join(args.experiments_path, - args.experiment_name, 'latent_space_plots') - os.makedirs(save_dir, exist_ok=True) - - bundle_mapping = parse_bundle_mapping(args.bundles_mapping) - ls_viz = BundlesLatentSpaceVisualizer(save_dir, - prefix_numbering=True, - max_subset_size=1000, - bundle_mapping=bundle_mapping) - current_epoch = -1 - - def visualize_latent_space(encoding, data_per_streamline): - """ - This is not a clean way to do it. This would require changes in - the trainer to allow for a callback system where we could - register a function to be called at the end of each epoch to - plot the latent space of the data accumulated during the epoch - (at each batch). - - Also, using this method, the latent space of the last epoch will - not be plotted. We would need to calculate which batch step would - be the last in the epoch and then plot accordingly. - """ - nonlocal current_epoch, trainer, ls_viz - - # Only execute the following if we are in training - if not trainer.model.context == 'training': - return - - bundle_index = data_per_streamline['bundle_index'].squeeze(1) - - changed_epoch = current_epoch != trainer.current_epoch - 1 - if not changed_epoch: - ls_viz.add_data_to_plot(encoding, labels=bundle_index) - elif changed_epoch: - current_epoch = trainer.current_epoch - 1 - if (trainer.current_epoch - 1) % viz_latent_space_freq == 0: - ls_viz.plot(current_epoch, - best_epoch=trainer.best_epoch_monitor.best_epoch) - ls_viz.reset_data() - else: - ls_viz.check_and_register_best_epoch( - current_epoch, trainer.best_epoch_monitor.best_epoch) - - ls_viz.add_data_to_plot(encoding, labels=bundle_index) - model.register_hook_post_encoding(visualize_latent_space) - return trainer From 1d823567f7bff7e1d566f4bb971c4f73cd39dd5c Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Sat, 5 Oct 2024 17:16:22 -0400 Subject: [PATCH 11/28] Cleanup: part 2 --- dwi_ml/training/projects/ae_trainer.py | 37 ++++--- dwi_ml/training/utils/monitoring.py | 11 +++ dwi_ml/viz/latent_streamlines.py | 132 ++++++++++++------------- 3 files changed, 95 insertions(+), 85 deletions(-) diff --git a/dwi_ml/training/projects/ae_trainer.py b/dwi_ml/training/projects/ae_trainer.py index abbc1ee1..657b84d7 100644 --- a/dwi_ml/training/projects/ae_trainer.py +++ b/dwi_ml/training/projects/ae_trainer.py @@ -62,8 +62,8 @@ def __init__(self, bundle_mapping=bundle_mapping) # Register what to do post encoding. - def visualize_latent_space(encoding, data_per_streamline): - # Only execute the following if we are in training + def handle_latent_encodings(encoding, data_per_streamline): + # Only accumulate data during training if not self.model.context == 'training': return @@ -74,23 +74,22 @@ def visualize_latent_space(encoding, data_per_streamline): data_per_streamline[self.color_by].squeeze(1) self.ls_viz.add_data_to_plot(encoding, labels=bundle_index) - model.register_hook_post_encoding(visualize_latent_space) + # Execute the above function within the model's forward(). + model.register_hook_post_encoding(handle_latent_encodings) - def train_one_epoch(self, epoch): - super().train_one_epoch(epoch) + # Plot the latent space after each best epoch. + # Called after running training & validation epochs. + self.best_epoch_monitor.register_on_best_epoch_hook( + self.ls_viz.plot) - # Do things post epoch + def train_one_epoch(self, epoch): if self.visualize: - current_epoch = self.current_epoch - best_epoch = self.best_epoch_monitor.best_epoch \ - if self.best_epoch_monitor.best_epoch is not None else current_epoch - if current_epoch % self.ls_viz_freq == 0: - self.ls_viz.plot(current_epoch, best_epoch=best_epoch) - self.ls_viz.reset_data() - else: - - # TODO: This is problematic, the best epoch is updated actually after - # validation (which is after doing the training epoch) - - self.ls_viz.check_and_register_best_epoch( - current_epoch, best_epoch=best_epoch) + # Before starting another training epoch, make sure the data + # is cleared. This is important to avoid accumulating data. + # We have to do it here. Since the on_new_best_epoch is called + # after the validation epoch, we can't do it there. + # Also, we won't always have the best epoch, if not, we still need + # to clear the data. + self.ls_viz.reset_data() + + super().train_one_epoch(epoch) diff --git a/dwi_ml/training/utils/monitoring.py b/dwi_ml/training/utils/monitoring.py index 79086528..b4ac5991 100644 --- a/dwi_ml/training/utils/monitoring.py +++ b/dwi_ml/training/utils/monitoring.py @@ -158,6 +158,14 @@ def __init__(self, name, patience: int, patience_delta: float = 1e-6): self.best_value = None self.best_epoch = None self.n_bad_epochs = None + self.on_best_epoch_hooks = [] + + def register_on_best_epoch_hook(self, hook): + self.on_best_epoch_hooks.append(hook) + + def _call_on_best_epoch_hooks(self, new_best_epoch): + for hook in self.on_best_epoch_hooks: + hook(new_best_epoch) def update(self, loss, epoch): """ @@ -178,12 +186,14 @@ def update(self, loss, epoch): self.best_value = loss self.best_epoch = epoch self.n_bad_epochs = 0 + self._call_on_best_epoch_hooks(epoch) return False elif loss < self.best_value - self.min_eps: # Improving from at least eps. self.best_value = loss self.best_epoch = epoch self.n_bad_epochs = 0 + self._call_on_best_epoch_hooks(epoch) return False else: # Not improving enough @@ -242,6 +252,7 @@ class IterTimer(object): # next iter could be twice as long as usual: time.time() + iter_timer.mean * 2.0 + 30 > max_time """ + def __init__(self, history_len=5): self.history = deque(maxlen=history_len) self.iterable = None diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index df597a5b..ef265da8 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -168,6 +168,7 @@ def reset_data(self): # Not sure if resetting the TSNE object is necessary. self.tsne = TSNE(n_components=2, random_state=self.random_state) self.bundles = {} + self.should_call_reset_before_plot = False def add_data_to_plot(self, data: np.ndarray, labels: List[str]): """ @@ -200,7 +201,7 @@ def add_data_to_plot(self, data: np.ndarray, labels: List[str]): "Some streamlines were not considered in the bundles," "some labels are missing.\n" "Added them to the {} bundle.".format(DEFAULT_BUNDLE_NAME)) - self.bundles[DEFAULT_BUNDLE_NAME] = latent_space_streamlines[all_indices] + self.bundles[DEFAULT_BUNDLE_NAME] = latent_space_streamlines[_remaining_indices] def add_bundle_to_plot(self, data: np.ndarray, label: str = '_'): """ @@ -220,71 +221,70 @@ def add_bundle_to_plot(self, data: np.ndarray, label: str = '_'): self.bundles[label] = latent_space_streamlines - def check_and_register_best_epoch(self, epoch: int, best_epoch: int = -1): - """ - Finalize the epoch by plotting the t-SNE projection of the latent space streamlines. - This should be called once after adding all the data to plot using - "add_data_to_plot". - - Parameters - ---------- - epoch: int - Current epoch. - best_epoch: int - Best epoch. - """ - is_new_best = best_epoch > self.best_epoch - - if not is_new_best: - return - - assert best_epoch == epoch, "The best epoch should be the current epoch since it just changed." - - # If we have a new best epoch, we need to update the plot on the left. - self.best_epoch = best_epoch - - for (bname, bdata) in self.bundles.items(): - if bdata.shape[0] > self.max_subset_size: - self.bundles[bname] = self._resample_max_subset_size(bdata) - - nb_streamlines = sum(b.shape[0] for b in self.bundles.values()) - LOGGER.info( - "New best epoch with a total of {} streamlines".format(nb_streamlines)) - - # Build the indices for each bundle to recover the streamlines after - # the t-SNE projection. - bundles_indices = {} - current_start = 0 - for (bname, bdata) in self.bundles.items(): - bundles_indices[bname] = np.arange( - current_start, current_start + bdata.shape[0]) - current_start += bdata.shape[0] - - assert current_start == nb_streamlines - - all_streamlines = np.concatenate(list(self.bundles.values()), axis=0) - - LOGGER.info("Fitting TSNE projection.") - all_projected_streamlines = self.tsne.fit_transform(all_streamlines) - - if self.fig is None or self.axes is None: - self.fig, self.axes = self._init_figure() - - self.axes[0].clear() - for (bname, bdata) in self.bundles.items(): - bindices = bundles_indices[bname] - proj_data = all_projected_streamlines[bindices] - blabel = self.bundle_mapping.get( - bname, bname) if self.bundle_mapping else bname - - self._plot_bundle( - self.axes[0], proj_data[:, 0], proj_data[:, 1], blabel) - - self.axes[0].set_title("Best epoch ({})".format(self.best_epoch)) - self._set_legend(self.axes[0], len(self.bundles)) - - # Clear data - self.reset_data() + # def update_best_epoch(self, epoch: int): + # """ + # Finalize the epoch by plotting the t-SNE projection of the latent space streamlines. + # This should be called once after adding all the data to plot using + # "add_data_to_plot". + + # Parameters + # ---------- + # epoch: int + # Current epoch. + # best_epoch: int + # Best epoch. + # """ + # if epoch == self.best_epoch: + # LOGGER.warning( + # "The current epoch is the same as the best epoch. " + # "Skipping plot update.") + # return + + # # If we have a new best epoch, we need to update the plot on the left. + # self.best_epoch = epoch + + # for (bname, bdata) in self.bundles.items(): + # if bdata.shape[0] > self.max_subset_size: + # self.bundles[bname] = self._resample_max_subset_size(bdata) + + # nb_streamlines = sum(b.shape[0] for b in self.bundles.values()) + # LOGGER.info( + # "New best epoch with a total of {} streamlines".format(nb_streamlines)) + + # # Build the indices for each bundle to recover the streamlines after + # # the t-SNE projection. + # bundles_indices = {} + # current_start = 0 + # for (bname, bdata) in self.bundles.items(): + # bundles_indices[bname] = np.arange( + # current_start, current_start + bdata.shape[0]) + # current_start += bdata.shape[0] + + # assert current_start == nb_streamlines + + # all_streamlines = np.concatenate(list(self.bundles.values()), axis=0) + + # LOGGER.info("Fitting TSNE projection.") + # all_projected_streamlines = self.tsne.fit_transform(all_streamlines) + + # if self.fig is None or self.axes is None: + # self.fig, self.axes = self._init_figure() + + # self.axes[0].clear() + # for (bname, bdata) in self.bundles.items(): + # bindices = bundles_indices[bname] + # proj_data = all_projected_streamlines[bindices] + # blabel = self.bundle_mapping.get( + # bname, bname) if self.bundle_mapping else bname + + # self._plot_bundle( + # self.axes[0], proj_data[:, 0], proj_data[:, 1], blabel) + + # self.axes[0].set_title("Best epoch ({})".format(self.best_epoch)) + # self._set_legend(self.axes[0], len(self.bundles)) + + # # Clear data + # self.reset_data() def plot(self, epoch: int, figure_name_prefix: str = 'lt_space', best_epoch: int = -1): """ From 115a7dc1220e324dcdaf9ed86360aca28cee63a7 Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Sat, 5 Oct 2024 17:27:43 -0400 Subject: [PATCH 12/28] Cleanup: part 3 --- dwi_ml/viz/latent_streamlines.py | 121 +++++-------------------------- 1 file changed, 18 insertions(+), 103 deletions(-) diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index ef265da8..60b624bb 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -105,7 +105,7 @@ class BundlesLatentSpaceVisualizer(object): def __init__(self, save_dir: str, - fig_size: Union[List, Tuple] = (16, 8), + fig_size: Union[List, Tuple] = (10, 8), random_state: int = 42, max_subset_size: int = None, prefix_numbering: bool = False, @@ -157,8 +157,7 @@ def __init__(self, self.bundles = {} self.bundle_color_manager = ColorManager() - self.fig, self.axes = None, None - self.best_epoch = -1 + self.fig, self.ax = None, None def reset_data(self): """ @@ -221,72 +220,7 @@ def add_bundle_to_plot(self, data: np.ndarray, label: str = '_'): self.bundles[label] = latent_space_streamlines - # def update_best_epoch(self, epoch: int): - # """ - # Finalize the epoch by plotting the t-SNE projection of the latent space streamlines. - # This should be called once after adding all the data to plot using - # "add_data_to_plot". - - # Parameters - # ---------- - # epoch: int - # Current epoch. - # best_epoch: int - # Best epoch. - # """ - # if epoch == self.best_epoch: - # LOGGER.warning( - # "The current epoch is the same as the best epoch. " - # "Skipping plot update.") - # return - - # # If we have a new best epoch, we need to update the plot on the left. - # self.best_epoch = epoch - - # for (bname, bdata) in self.bundles.items(): - # if bdata.shape[0] > self.max_subset_size: - # self.bundles[bname] = self._resample_max_subset_size(bdata) - - # nb_streamlines = sum(b.shape[0] for b in self.bundles.values()) - # LOGGER.info( - # "New best epoch with a total of {} streamlines".format(nb_streamlines)) - - # # Build the indices for each bundle to recover the streamlines after - # # the t-SNE projection. - # bundles_indices = {} - # current_start = 0 - # for (bname, bdata) in self.bundles.items(): - # bundles_indices[bname] = np.arange( - # current_start, current_start + bdata.shape[0]) - # current_start += bdata.shape[0] - - # assert current_start == nb_streamlines - - # all_streamlines = np.concatenate(list(self.bundles.values()), axis=0) - - # LOGGER.info("Fitting TSNE projection.") - # all_projected_streamlines = self.tsne.fit_transform(all_streamlines) - - # if self.fig is None or self.axes is None: - # self.fig, self.axes = self._init_figure() - - # self.axes[0].clear() - # for (bname, bdata) in self.bundles.items(): - # bindices = bundles_indices[bname] - # proj_data = all_projected_streamlines[bindices] - # blabel = self.bundle_mapping.get( - # bname, bname) if self.bundle_mapping else bname - - # self._plot_bundle( - # self.axes[0], proj_data[:, 0], proj_data[:, 1], blabel) - - # self.axes[0].set_title("Best epoch ({})".format(self.best_epoch)) - # self._set_legend(self.axes[0], len(self.bundles)) - - # # Clear data - # self.reset_data() - - def plot(self, epoch: int, figure_name_prefix: str = 'lt_space', best_epoch: int = -1): + def plot(self, epoch: int, figure_name_prefix: str = 'lt_space'): """ Fit and plot the t-SNE projection of the latent space streamlines. This should be called once after adding all the data to plot using @@ -330,21 +264,14 @@ def plot(self, epoch: int, figure_name_prefix: str = 'lt_space', best_epoch: int assert current_start == nb_streamlines - all_streamlines = np.concatenate(list(self.bundles.values()), axis=0) - LOGGER.info("Fitting TSNE projection.") + all_streamlines = np.concatenate(list(self.bundles.values()), axis=0) all_projected_streamlines = self.tsne.fit_transform(all_streamlines) - if self.fig is None or self.axes is None: - self.fig, self.axes = self._init_figure() - - # Check if we have a new best epoch. - # If so, that means we have to update the plot on the left. - is_new_best = best_epoch > self.best_epoch - if is_new_best: - self.best_epoch = best_epoch + if self.fig is None or self.ax is None: + self.fig, self.ax = self._init_figure() - self._clear_figures(is_new_best) + self._clear_figures() for (bname, bdata) in self.bundles.items(): bindices = bundles_indices[bname] @@ -353,20 +280,14 @@ def plot(self, epoch: int, figure_name_prefix: str = 'lt_space', best_epoch: int bname, bname) if self.bundle_mapping else bname self._plot_bundle( - self.axes[1], proj_data[:, 0], proj_data[:, 1], blabel) - if is_new_best: - self._plot_bundle( - self.axes[0], proj_data[:, 0], proj_data[:, 1], blabel) + self.ax, proj_data[:, 0], proj_data[:, 1], blabel) - self.axes[1].set_title("Epoch {}".format(epoch)) - self._set_legend(self.axes[1], len(self.bundles)) - if is_new_best: - self.axes[0].set_title("Best epoch ({})".format(self.best_epoch)) - self._set_legend(self.axes[0], len(self.bundles)) + self.ax.set_title("Best epoch {}".format(epoch)) + self._set_legend(self.ax, len(self.bundles)) if self.prefix_numbering: filename = '{}_{}.png'.format( - figure_name_prefix, self.current_plot_number) + figure_name_prefix, epoch) else: filename = '{}.png'.format(figure_name_prefix) @@ -395,28 +316,22 @@ def _plot_bundle(self, ax, dim1, dim2, blabel): color=self.bundle_color_manager.get_color(blabel) ) - def _clear_figures(self, clear_best: bool): - if clear_best: - self.axes[0].clear() - self.axes[1].clear() + def _clear_figures(self): + self.ax.clear() def _init_figure(self): LOGGER.info("Init new figure for BundlesLatentSpaceVisualizer.") - fig, axes = plt.subplots(1, 2) - axes[0].set_title("Best epoch (?)") - axes[1].set_title("Last epoch (?)") + fig, ax = plt.subplots(1, 1) + ax.set_title("Best epoch (?)") if self.fig_size is not None: fig.set_figwidth(self.fig_size[0]) fig.set_figheight(self.fig_size[1]) - box_0 = axes[0].get_position() - axes[0].set_position( + box_0 = ax.get_position() + ax.set_position( [box_0.x0, box_0.y0, box_0.width * 0.8, box_0.height]) - box_1 = axes[1].get_position() - axes[1].set_position( - [box_1.x0, box_1.y0, box_1.width * 0.8, box_1.height]) - return fig, axes + return fig, ax def _to_numpy(self, data): if isinstance(data, torch.Tensor): From 1b1499c02190a56c53fff37f7bae748d769258db Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Sat, 5 Oct 2024 21:26:22 -0400 Subject: [PATCH 13/28] Cleanup: part 4 --- dwi_ml/data/dataset/multi_subject_containers.py | 11 ++++------- dwi_ml/data/dataset/single_subject_containers.py | 6 ++---- dwi_ml/viz/latent_streamlines.py | 5 +---- 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/dwi_ml/data/dataset/multi_subject_containers.py b/dwi_ml/data/dataset/multi_subject_containers.py index b2b5dd79..0c6bded5 100644 --- a/dwi_ml/data/dataset/multi_subject_containers.py +++ b/dwi_ml/data/dataset/multi_subject_containers.py @@ -37,8 +37,8 @@ class MultisubjectSubset(Dataset): Based on torch's dataset class. Provides functions for a DataLoader to iterate over data and process batches. """ - - def __init__(self, set_name: str, hdf5_file: str, lazy: bool, cache_size: int = 0): + def __init__(self, set_name: str, hdf5_file: str, lazy: bool, + cache_size: int = 0): self.set_name = set_name self.hdf5_file = hdf5_file @@ -356,7 +356,6 @@ class MultiSubjectDataset: datasets 'streamlines/data', 'streamlines/offsets', 'streamlines/lengths', 'streamlines/euclidean_lengths'. """ - def __init__(self, hdf5_file: str, lazy: bool, cache_size: int = 0, log_level=None): """ @@ -479,8 +478,7 @@ def load_data(self, load_training=True, load_validation=True, self.nb_features = nb_features if streamline_groups is not None: - missing_str = np.setdiff1d( - streamline_groups, poss_strea_groups) + missing_str = np.setdiff1d(streamline_groups, poss_strea_groups) if len(missing_str) > 0: raise ValueError("Streamlines {} were not found in the " "first subject of your hdf5 file." @@ -500,8 +498,7 @@ def load_data(self, load_training=True, load_validation=True, self.streamline_groups, self.streamlines_contain_connectivity) self.training_set.set_subset_info(*group_info, step_size, compress) - self.validation_set.set_subset_info( - *group_info, step_size, compress) + self.validation_set.set_subset_info(*group_info, step_size, compress) self.testing_set.set_subset_info(*group_info, step_size, compress) # LOADING diff --git a/dwi_ml/data/dataset/single_subject_containers.py b/dwi_ml/data/dataset/single_subject_containers.py index 78fb500b..fbd9b6cc 100644 --- a/dwi_ml/data/dataset/single_subject_containers.py +++ b/dwi_ml/data/dataset/single_subject_containers.py @@ -17,7 +17,6 @@ class SubjectDataAbstract(object): single MRI acquisition. It could contain data from many "real" MRI volumes concatenated together. """ - def __init__(self, volume_groups: List[str], nb_features: List[int], streamline_groups: List[str], subject_id: str): """ @@ -65,7 +64,6 @@ def add_handle(self, hdf_handle): class SubjectData(SubjectDataAbstract): """Non-lazy version""" - def __init__(self, subject_id: str, volume_groups: List[str], nb_features: List[int], mri_data_list: List[MRIData] = None, streamline_groups: List[str] = None, @@ -77,7 +75,8 @@ def __init__(self, subject_id: str, volume_groups: List[str], The loaded streamlines in a format copying the SFT. They contain ._data, ._offsets, ._lengths, ._lengths_mm. """ - super().__init__(volume_groups, nb_features, streamline_groups, subject_id) + super().__init__(volume_groups, nb_features, streamline_groups, + subject_id) self._mri_data_list = mri_data_list self._sft_data_list = sft_data_list self.is_lazy = False @@ -131,7 +130,6 @@ class LazySubjectData(SubjectDataAbstract): """ Lazy version. """ - def __init__(self, volume_groups: List[str], nb_features: List[int], streamline_groups: List[str], subject_id: str, hdf_handle=None): diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index 60b624bb..66925cd7 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -271,7 +271,7 @@ def plot(self, epoch: int, figure_name_prefix: str = 'lt_space'): if self.fig is None or self.ax is None: self.fig, self.ax = self._init_figure() - self._clear_figures() + self.ax.clear() for (bname, bdata) in self.bundles.items(): bindices = bundles_indices[bname] @@ -316,9 +316,6 @@ def _plot_bundle(self, ax, dim1, dim2, blabel): color=self.bundle_color_manager.get_color(blabel) ) - def _clear_figures(self): - self.ax.clear() - def _init_figure(self): LOGGER.info("Init new figure for BundlesLatentSpaceVisualizer.") fig, ax = plt.subplots(1, 1) From 39ecd7b15ad800cc6799dbb01eb0a0e5aa863385 Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Sat, 5 Oct 2024 21:50:01 -0400 Subject: [PATCH 14/28] Cleanup: part 5 --- dwi_ml/data/dataset/streamline_containers.py | 7 +++++-- dwi_ml/training/projects/ae_trainer.py | 18 +++++++++++++----- scripts_python/ae_train_model.py | 7 +++---- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/dwi_ml/data/dataset/streamline_containers.py b/dwi_ml/data/dataset/streamline_containers.py index c8598640..034eccec 100644 --- a/dwi_ml/data/dataset/streamline_containers.py +++ b/dwi_ml/data/dataset/streamline_containers.py @@ -72,6 +72,9 @@ def _load_connectivity_info(hdf_group: h5py.Group): def _load_data_per_streamline(hdf_group, dps_key: str = None) -> Union[np.ndarray, None]: dps_dict = defaultdict(list) # Load only related data key if specified + if not 'data_per_streamline' in hdf_group.keys(): + return dps_dict + dps_group = hdf_group['data_per_streamline'] if dps_key is not None: # Make sure the related data key is in the hdf5 group @@ -79,8 +82,8 @@ def _load_data_per_streamline(hdf_group, dps_key: str = None) -> Union[np.ndarra raise KeyError("The key '{}' is not in the hdf5 group. Keys found: {}" .format(dps_key, dps_group.keys())) - # Load the related data - dps_dict[dps_key] = dps_group['data_per_streamline'][dps_key][:] + # Load the related data per streamline + dps_dict[dps_key] = dps_group[dps_key][:] # Otherwise, load every dps. else: for dps_key in dps_group.keys(): diff --git a/dwi_ml/training/projects/ae_trainer.py b/dwi_ml/training/projects/ae_trainer.py index 657b84d7..cc873176 100644 --- a/dwi_ml/training/projects/ae_trainer.py +++ b/dwi_ml/training/projects/ae_trainer.py @@ -9,6 +9,8 @@ from dwi_ml.training.trainers import DWIMLAbstractTrainer from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer +LOGGER = logging.getLogger(__name__) + def parse_bundle_mapping(bundles_mapping_file: str = None): if bundles_mapping_file is None: @@ -38,7 +40,7 @@ def __init__(self, clip_grad: float = None, comet_workspace: str = None, comet_project: str = None, from_checkpoint: bool = False, log_level=logging.root.level, - ls_viz_freq: int = None, color_by: str = None, + viz_latent_space: bool = False, color_by: str = None, bundles_mapping_file: str = None, max_viz_subset_size: int = 1000): super().__init__(model, experiments_path, experiment_name, batch_sampler, batch_loader, @@ -46,10 +48,9 @@ def __init__(self, max_batches_per_epoch_validation, patience, patience_delta, nb_cpu_processes, use_gpu, clip_grad, comet_workspace, comet_project, from_checkpoint, log_level) - self.ls_viz_freq = ls_viz_freq self.color_by = color_by - self.visualize = ls_viz_freq is not None - if self.visualize: + self.viz_latent_space = viz_latent_space + if self.viz_latent_space: # Setup to visualize latent space save_dir = os.path.join( experiments_path, experiment_name, 'latent_space_plots') @@ -60,6 +61,7 @@ def __init__(self, prefix_numbering=True, max_subset_size=max_viz_subset_size, bundle_mapping=bundle_mapping) + self.warning_printed = False # Register what to do post encoding. def handle_latent_encodings(encoding, data_per_streamline): @@ -69,6 +71,12 @@ def handle_latent_encodings(encoding, data_per_streamline): if self.color_by is None: bundle_index = None + elif not self.color_by in data_per_streamline.keys(): + if not self.warning_printed: + LOGGER.warning( + f"Coloring by {self.color_by} not found in data_per_streamline.") + self.warning_printed = True + bundle_index = None else: bundle_index = \ data_per_streamline[self.color_by].squeeze(1) @@ -83,7 +91,7 @@ def handle_latent_encodings(encoding, data_per_streamline): self.ls_viz.plot) def train_one_epoch(self, epoch): - if self.visualize: + if self.viz_latent_space: # Before starting another training epoch, make sure the data # is cleared. This is important to avoid accumulating data. # We have to do it here. Since the on_new_best_epoch is called diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 202ffe3f..184ef40f 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -29,7 +29,6 @@ from dwi_ml.training.utils.trainer import (add_training_args, run_experiment, format_lr) from dwi_ml.training.batch_loaders import DWIMLStreamlinesBatchLoader -from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer from dwi_ml.training.utils.experiment import ( add_mandatory_args_experiment_and_hdf5_path) @@ -43,7 +42,7 @@ def prepare_arg_parser(): add_training_args(p) p.add_argument('streamline_group_name', help="Name of the group in hdf5") - p.add_argument('--viz_latent_space_freq', type=int, default=None, + p.add_argument('--viz_latent_space', action='store_true', default=False, help="Frequency at which to visualize latent space.\n" "This is expressed in number of epochs.") p.add_argument('--color_by', type=str, default=None, @@ -62,7 +61,7 @@ def prepare_arg_parser(): def init_from_args(args, sub_loggers_level): torch.manual_seed(args.rng) # Set torch seed - viz_latent_space_freq = args.viz_latent_space_freq + viz_latent_space = args.viz_latent_space color_by = args.color_by # Prepare the dataset @@ -118,7 +117,7 @@ def init_from_args(args, sub_loggers_level): # MEMORY nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, log_level=sub_loggers_level, - ls_viz_freq=viz_latent_space_freq, color_by=color_by, + viz_latent_space=viz_latent_space, color_by=color_by, bundles_mapping_file=args.bundles_mapping, max_viz_subset_size=1000) logging.info("Trainer params : " + From 2e4a19168252df34dc24e7b6ba82b23666efcbfd Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Mon, 7 Oct 2024 09:59:58 -0400 Subject: [PATCH 15/28] Fix dps unpacking --- dwi_ml/models/projects/transformer_models.py | 7 ++++++- dwi_ml/training/trainers.py | 5 +++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/dwi_ml/models/projects/transformer_models.py b/dwi_ml/models/projects/transformer_models.py index a93a5c8e..7f883778 100644 --- a/dwi_ml/models/projects/transformer_models.py +++ b/dwi_ml/models/projects/transformer_models.py @@ -117,6 +117,7 @@ class AbstractTransformerModel(ModelWithNeighborhood, ModelWithDirectionGetter, https://kazemnejad.com/blog/transformer_architecture_positional_encoding/ the embedding probably adapts to leave place for the positional encoding. """ + def __init__(self, experiment_name: str, # Target preprocessing params for the batch loader + tracker @@ -358,7 +359,9 @@ def _prepare_masks(self, unpadded_lengths, use_padding, batch_max_len): return mask_future, mask_padding def forward(self, inputs: List[torch.tensor], - input_streamlines: List[torch.tensor] = None): + input_streamlines: List[torch.tensor] = None, + data_per_streamline: Union[List[torch.tensor], + List[np.ndarray]] = None): """ Params ------ @@ -823,6 +826,7 @@ class OriginalTransformerModel(AbstractTransformerModelWithTarget): emb_choice_x """ + def __init__(self, input_embedded_size, n_layers_d: int, **kw): """ d_model = input_embedded_size = target_embedded_size. @@ -964,6 +968,7 @@ class TransformerSrcAndTgtModel(AbstractTransformerModelWithTarget): [ emb_choice_x ; emb_choice_y ] """ + def __init__(self, **kw): """ No additional params. d_model = input size + target size. diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index 0f357234..9b19a8d6 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -1157,7 +1157,7 @@ def run_one_batch(self, data): """ # Data interpolation has not been done yet. GPU computations are done # here in the main thread. - targets, ids_per_subj = data + targets, ids_per_subj, data_per_streamline = data # Dataloader always works on CPU. Sending to right device. # (model is already moved). @@ -1185,7 +1185,8 @@ def run_one_batch(self, data): # (batch loader will do it depending on training / valid) streamlines_f = self.batch_loader.add_noise_streamlines_forward( streamlines_f, self.device) - model_outputs = self.model(batch_inputs, streamlines_f) + model_outputs = self.model( + batch_inputs, streamlines_f, data_per_streamline) del streamlines_f logger.debug('*** Computing loss') From 2de0a43c5f26f1279b3346d96e0cd5e2c8395889 Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Mon, 7 Oct 2024 11:18:50 -0400 Subject: [PATCH 16/28] Fix tests --- dwi_ml/data/dataset/streamline_containers.py | 29 ++++++++++++------- dwi_ml/models/projects/learn2track_model.py | 4 ++- dwi_ml/training/trainers_withGV.py | 2 +- .../utils/data_and_models_for_tests.py | 5 ++-- 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/dwi_ml/data/dataset/streamline_containers.py b/dwi_ml/data/dataset/streamline_containers.py index 034eccec..d78f01f4 100644 --- a/dwi_ml/data/dataset/streamline_containers.py +++ b/dwi_ml/data/dataset/streamline_containers.py @@ -106,17 +106,25 @@ def _get_one_streamline(self, idx: int): def get_array_sequence(self, item=None): if item is None: - streamlines, dps = _load_all_streamlines_from_hdf(self.hdf_group) + streamlines, data_per_streamline = _load_all_streamlines_from_hdf( + self.hdf_group) else: streamlines = ArraySequence() - dps_dict = defaultdict(list) + data_per_streamline = defaultdict(list) + + # If data_per_streamline is not in the hdf5, use an empty dict + # so that we don't add anything to the data_per_streamline in the + # following steps. + hdf_dps_group = self.hdf_group['data_per_streamline'] if \ + 'data_per_streamline' in self.hdf_group.keys() else {} if isinstance(item, int): data = self._get_one_streamline(item) streamlines.append(data) - for dps_key in self.hdf_group['dps_keys']: - dps_dict[dps_key].append(self.hdf_group[dps_key][item]) + for dps_key in hdf_dps_group.keys(): + data_per_streamline[dps_key].append( + hdf_dps_group[dps_key][item]) elif isinstance(item, list) or isinstance(item, np.ndarray): # Getting a list of value from a hdf5: slow. Uses fancy indexing. @@ -129,8 +137,9 @@ def get_array_sequence(self, item=None): data = self._get_one_streamline(i) streamlines.append(data, cache_build=True) - for dps_key in self.hdf_group['dps_keys']: - dps_dict[dps_key].append(self.hdf_group[dps_key][item]) + for dps_key in hdf_dps_group.keys(): + data_per_streamline[dps_key].append( + hdf_dps_group[dps_key][item]) streamlines.finalize_append() @@ -141,16 +150,16 @@ def get_array_sequence(self, item=None): streamline = self.hdf_group['data'][offset:offset + length] streamlines.append(streamline, cache_build=True) - for dps_key in self.hdf_group['dps_keys']: - dps_dict[dps_key].append( - self.hdf_group[dps_key][offset:offset + length]) + for dps_key in hdf_dps_group.keys(): + data_per_streamline[dps_key].append( + hdf_dps_group[dps_key][offset:offset + length]) streamlines.finalize_append() else: raise ValueError('Item should be either a int, list, ' 'np.ndarray or slice but we received {}' .format(type(item))) - return streamlines, dps + return streamlines, data_per_streamline @property def lengths(self): diff --git a/dwi_ml/models/projects/learn2track_model.py b/dwi_ml/models/projects/learn2track_model.py index 9ba8074c..d3b11237 100644 --- a/dwi_ml/models/projects/learn2track_model.py +++ b/dwi_ml/models/projects/learn2track_model.py @@ -227,6 +227,7 @@ def computed_params_for_display(self): def forward(self, x: List[torch.tensor], input_streamlines: List[torch.tensor] = None, + data_per_streamline: List[torch.tensor] = {}, hidden_recurrent_states: List = None, return_hidden=False, point_idx: int = None): """Run the model on a batch of sequences. @@ -284,7 +285,8 @@ def forward(self, x: List[torch.tensor], unsorted_indices = invert_permutation(sorted_indices) x = [x[i] for i in sorted_indices] if input_streamlines is not None: - input_streamlines = [input_streamlines[i] for i in sorted_indices] + input_streamlines = [input_streamlines[i] + for i in sorted_indices] # ==== 0. Previous dirs. n_prev_dirs = None diff --git a/dwi_ml/training/trainers_withGV.py b/dwi_ml/training/trainers_withGV.py index a0aebfcb..c42ee214 100644 --- a/dwi_ml/training/trainers_withGV.py +++ b/dwi_ml/training/trainers_withGV.py @@ -242,7 +242,7 @@ def gv_phase_one_batch(self, data, compute_all_scores=False): seeds and first few segments. Expected results are the batch's validation streamlines. """ - real_lines, ids_per_subj = data + real_lines, ids_per_subj, data_per_streamline = data # Possibly sending again to GPU even if done in the local loss # computation, but easier with current implementation. diff --git a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py index f1bcf6c0..aaa830db 100644 --- a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py +++ b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py @@ -84,7 +84,7 @@ def compute_loss(self, model_outputs, target_streamlines=None, else: return torch.zeros(n, device=self.device), 1 - def forward(self, inputs: list, streamlines): + def forward(self, inputs: list, streamlines, data_per_streamline): # Not using streamlines. Pretending to use inputs. _ = self.fake_parameter regressed_dir = torch.as_tensor([1., 1., 1.]) @@ -143,7 +143,8 @@ def get_tracking_directions(self, regressed_dirs, algo, raise ValueError("'algo' should be 'det' or 'prob'.") def forward(self, inputs: List[torch.tensor], - target_streamlines: List[torch.tensor]): + target_streamlines: List[torch.tensor], + data_per_streamline: List[torch.tensor]): # Previous dirs if self.nb_previous_dirs > 0: target_dirs = compute_directions(target_streamlines) From 7f3d931938989ab226067c7cda3dcaa149c32f2c Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Mon, 7 Oct 2024 11:45:10 -0400 Subject: [PATCH 17/28] Fix pep8 --- dwi_ml/data/dataset/streamline_containers.py | 9 +- dwi_ml/training/projects/ae_trainer.py | 31 ++++--- dwi_ml/viz/latent_streamlines.py | 87 ++++++++++++-------- 3 files changed, 75 insertions(+), 52 deletions(-) diff --git a/dwi_ml/data/dataset/streamline_containers.py b/dwi_ml/data/dataset/streamline_containers.py index d78f01f4..79ab283a 100644 --- a/dwi_ml/data/dataset/streamline_containers.py +++ b/dwi_ml/data/dataset/streamline_containers.py @@ -69,17 +69,18 @@ def _load_connectivity_info(hdf_group: h5py.Group): return contains_connectivity, connectivity_nb_blocs, connectivity_labels -def _load_data_per_streamline(hdf_group, dps_key: str = None) -> Union[np.ndarray, None]: +def _load_data_per_streamline(hdf_group, + dps_key: str = None) -> Union[np.ndarray, None]: dps_dict = defaultdict(list) # Load only related data key if specified - if not 'data_per_streamline' in hdf_group.keys(): + if 'data_per_streamline' not in hdf_group.keys(): return dps_dict dps_group = hdf_group['data_per_streamline'] if dps_key is not None: # Make sure the related data key is in the hdf5 group if not (dps_key in dps_group.keys()): - raise KeyError("The key '{}' is not in the hdf5 group. Keys found: {}" + raise KeyError("The key '{}' is not in the hdf5 group. Found: {}" .format(dps_key, dps_group.keys())) # Load the related data per streamline @@ -352,7 +353,7 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None): return self._connectivity_matrix @classmethod - def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group, dps_key: str = None): + def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): """ Creating class instance from the hdf in cases where data is not loaded yet. Non-lazy = loading the data here. diff --git a/dwi_ml/training/projects/ae_trainer.py b/dwi_ml/training/projects/ae_trainer.py index cc873176..e74abba5 100644 --- a/dwi_ml/training/projects/ae_trainer.py +++ b/dwi_ml/training/projects/ae_trainer.py @@ -41,12 +41,17 @@ def __init__(self, comet_workspace: str = None, comet_project: str = None, from_checkpoint: bool = False, log_level=logging.root.level, viz_latent_space: bool = False, color_by: str = None, - bundles_mapping_file: str = None, max_viz_subset_size: int = 1000): - - super().__init__(model, experiments_path, experiment_name, batch_sampler, batch_loader, - learning_rates, weight_decay, optimizer, max_epochs, max_batches_per_epoch_training, - max_batches_per_epoch_validation, patience, patience_delta, nb_cpu_processes, use_gpu, - clip_grad, comet_workspace, comet_project, from_checkpoint, log_level) + bundles_mapping_file: str = None, + max_viz_subset_size: int = 1000): + + super().__init__(model, experiments_path, experiment_name, + batch_sampler, batch_loader, learning_rates, + weight_decay, optimizer, max_epochs, + max_batches_per_epoch_training, + max_batches_per_epoch_validation, patience, + patience_delta, nb_cpu_processes, use_gpu, + clip_grad, comet_workspace, comet_project, + from_checkpoint, log_level) self.color_by = color_by self.viz_latent_space = viz_latent_space @@ -57,10 +62,11 @@ def __init__(self, os.makedirs(save_dir, exist_ok=True) bundle_mapping = parse_bundle_mapping(bundles_mapping_file) - self.ls_viz = BundlesLatentSpaceVisualizer(save_dir, - prefix_numbering=True, - max_subset_size=max_viz_subset_size, - bundle_mapping=bundle_mapping) + self.ls_viz = BundlesLatentSpaceVisualizer( + save_dir, + prefix_numbering=True, + max_subset_size=max_viz_subset_size, + bundle_mapping=bundle_mapping) self.warning_printed = False # Register what to do post encoding. @@ -71,10 +77,11 @@ def handle_latent_encodings(encoding, data_per_streamline): if self.color_by is None: bundle_index = None - elif not self.color_by in data_per_streamline.keys(): + elif self.color_by not in data_per_streamline.keys(): if not self.warning_printed: LOGGER.warning( - f"Coloring by {self.color_by} not found in data_per_streamline.") + f"Coloring by {self.color_by} not found in " + "data_per_streamline.") self.warning_printed = True bundle_index = None else: diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index 66925cd7..5977923b 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -20,33 +20,40 @@ def __init__(self, max_num_bundles: int = 40): self.bundle_color_map = {} self.color_map = self._init_colormap(max_num_bundles) - def _init_colormap(self, number_of_distinct_colors): + def _init_colormap(self, nb_distinct_colors: int): """ Create a colormap with a number of distinct colors. Needed to have bigger color maps for more bundles. - Code directly copied from: - https://stackoverflow.com/questions/42697933/colormap-with-maximum-distinguishable-colours + Code directly copied from: + https://stackoverflow.com/questions/42697933 """ - if number_of_distinct_colors == 0: - number_of_distinct_colors = 80 + if nb_distinct_colors == 0: + nb_distinct_colors = 80 - number_of_shades = 7 - number_of_distinct_colors_with_multiply_of_shades = int( - math.ceil(number_of_distinct_colors / number_of_shades) * number_of_shades) + nb_of_shades = 7 + nb_of_distinct_colors_with_mult_of_shades = int( + math.ceil(nb_distinct_colors / nb_of_shades) + * nb_of_shades) - # Create an array with uniformly drawn floats taken from <0, 1) partition + # Create an array with uniformly drawn floats taken from <0, 1) + # partition linearly_distributed_nums = np.arange( - number_of_distinct_colors_with_multiply_of_shades) / number_of_distinct_colors_with_multiply_of_shades - - # We are going to reorganise monotonically growing numbers in such way that there will be single array with saw-like pattern - # but each saw tooth is slightly higher than the one before - # First divide linearly_distributed_nums into number_of_shades sub-arrays containing linearly distributed numbers + nb_of_distinct_colors_with_mult_of_shades) / \ + nb_of_distinct_colors_with_mult_of_shades + + # We are going to reorganise monotonically growing numbers in such way + # that there will be single array with saw-like pattern but each saw + # tooth is slightly higher than the one before. First divide + # linearly_distributed_nums into nb_of_shades sub-arrays containing + # linearly distributed numbers. arr_by_shade_rows = linearly_distributed_nums.reshape( - number_of_shades, number_of_distinct_colors_with_multiply_of_shades // number_of_shades) + nb_of_shades, nb_of_distinct_colors_with_mult_of_shades // + nb_of_shades) - # Transpose the above matrix (columns become rows) - as a result each row contains saw tooth with values slightly higher than row above + # Transpose the above matrix (columns become rows) - as a result each + # row contains saw tooth with values slightly higher than row above arr_by_shade_columns = arr_by_shade_rows.T # Keep number of saw teeth for later @@ -55,27 +62,31 @@ def _init_colormap(self, number_of_distinct_colors): # Flatten the above matrix - join each row into single array nums_distributed_like_rising_saw = arr_by_shade_columns.reshape(-1) - # HSV colour map is cyclic (https://matplotlib.org/tutorials/colors/colormaps.html#cyclic), we'll use this property + # HSV colour map is cyclic we'll use this property + # (https://matplotlib.org/tutorials/colors/colormaps.html#cyclic) initial_cm = hsv(nums_distributed_like_rising_saw) lower_partitions_half = number_of_partitions // 2 upper_partitions_half = number_of_partitions - lower_partitions_half - # Modify lower half in such way that colours towards beginning of partition are darker - # First colours are affected more, colours closer to the middle are affected less - lower_half = lower_partitions_half * number_of_shades + # Modify lower half in such way that colours towards beginning of + # partition are darker .First colours are affected more, colours + # closer to the middle are affected less + lower_half = lower_partitions_half * nb_of_shades for i in range(3): initial_cm[0:lower_half, i] *= np.arange(0.2, 1, 0.8/lower_half) - # Modify second half in such way that colours towards end of partition are less intense and brighter - # Colours closer to the middle are affected less, colours closer to the end are affected more + # Modify second half in such way that colours towards end of partition + # are less intense and brighter. Colours closer to the middle are + # affected less, colours closer to the end are affected more for i in range(3): for j in range(upper_partitions_half): - modifier = np.ones( - number_of_shades) - initial_cm[lower_half + j * number_of_shades: lower_half + (j + 1) * number_of_shades, i] + modifier = np.ones(nb_of_shades) \ + - initial_cm[lower_half + j * nb_of_shades: + lower_half + (j + 1) * nb_of_shades, i] modifier = j * modifier / upper_partitions_half - initial_cm[lower_half + j * number_of_shades: lower_half + - (j + 1) * number_of_shades, i] += modifier + initial_cm[lower_half + j * nb_of_shades: lower_half + + (j + 1) * nb_of_shades, i] += modifier return ListedColormap(initial_cm) @@ -186,21 +197,24 @@ def add_data_to_plot(self, data: np.ndarray, labels: List[str]): self.bundles[DEFAULT_BUNDLE_NAME] = latent_space_streamlines else: all_labels = np.unique(labels) - _remaining_indices = np.arange(len(labels)) + remaining_indices = np.arange(len(labels)) for label in all_labels: - label_indices = labels[_remaining_indices] == label - label_data = latent_space_streamlines[_remaining_indices][label_indices] + label_indices = labels[remaining_indices] == label + label_data = \ + latent_space_streamlines[remaining_indices][label_indices] label_data = self._resample_max_subset_size(label_data) self.bundles[label] = label_data - _remaining_indices = _remaining_indices[~label_indices] + remaining_indices = remaining_indices[~label_indices] - if len(_remaining_indices) > 0: + if len(remaining_indices) > 0: LOGGER.warning( "Some streamlines were not considered in the bundles," "some labels are missing.\n" - "Added them to the {} bundle.".format(DEFAULT_BUNDLE_NAME)) - self.bundles[DEFAULT_BUNDLE_NAME] = latent_space_streamlines[_remaining_indices] + "Added them to the {} bundle." + .format(DEFAULT_BUNDLE_NAME)) + self.bundles[DEFAULT_BUNDLE_NAME] = \ + latent_space_streamlines[remaining_indices] def add_bundle_to_plot(self, data: np.ndarray, label: str = '_'): """ @@ -244,7 +258,8 @@ def plot(self, epoch: int, figure_name_prefix: str = 'lt_space'): # So that the warning above is only displayed once. self.should_call_reset_before_plot = True - # Start by making sure the number of streamlines doesn't exceed the threshold. + # Start by making sure the number of streamlines doesn't + # exceed the threshold. for (bname, bdata) in self.bundles.items(): if bdata.shape[0] > self.max_subset_size: self.bundles[bname] = self._resample_max_subset_size(bdata) @@ -347,8 +362,8 @@ def _resample_max_subset_size(self, data: np.ndarray): "A max_subset_size of an integer value greater" "than 0 is required.") - # Only sample if we need to reduce the number of latent streamlines - # to show on the plot. + # Only sample if we need to reduce the number of latent + # streamlines to show on the plot. if (len(data) > self.max_subset_size): sample_indices = np.random.choice( len(data), From e051f191da2d1c48aca3b5fd69ce555b50649771 Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Mon, 7 Oct 2024 15:25:57 -0400 Subject: [PATCH 18/28] Move color generation func into separate file --- dwi_ml/viz/latent_streamlines.py | 88 +++++--------------------------- dwi_ml/viz/utils.py | 74 +++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 76 deletions(-) create mode 100644 dwi_ml/viz/utils.py diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index 5977923b..6f5f7114 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -4,11 +4,8 @@ from sklearn.manifold import TSNE import numpy as np import torch - -import math import matplotlib.pyplot as plt -from matplotlib.colors import ListedColormap -from matplotlib.cm import hsv +from dwi_ml.viz.utils import generate_dissimilar_color_map LOGGER = logging.getLogger(__name__) @@ -16,79 +13,15 @@ class ColorManager(object): + """ + Utility class to manage the color of the bundles in the latent space. + This way, we can have a consistent color for each bundle across + different plots. + """ + def __init__(self, max_num_bundles: int = 40): self.bundle_color_map = {} - self.color_map = self._init_colormap(max_num_bundles) - - def _init_colormap(self, nb_distinct_colors: int): - """ - Create a colormap with a number of distinct colors. - Needed to have bigger color maps for more bundles. - - Code directly copied from: - https://stackoverflow.com/questions/42697933 - - """ - if nb_distinct_colors == 0: - nb_distinct_colors = 80 - - nb_of_shades = 7 - nb_of_distinct_colors_with_mult_of_shades = int( - math.ceil(nb_distinct_colors / nb_of_shades) - * nb_of_shades) - - # Create an array with uniformly drawn floats taken from <0, 1) - # partition - linearly_distributed_nums = np.arange( - nb_of_distinct_colors_with_mult_of_shades) / \ - nb_of_distinct_colors_with_mult_of_shades - - # We are going to reorganise monotonically growing numbers in such way - # that there will be single array with saw-like pattern but each saw - # tooth is slightly higher than the one before. First divide - # linearly_distributed_nums into nb_of_shades sub-arrays containing - # linearly distributed numbers. - arr_by_shade_rows = linearly_distributed_nums.reshape( - nb_of_shades, nb_of_distinct_colors_with_mult_of_shades // - nb_of_shades) - - # Transpose the above matrix (columns become rows) - as a result each - # row contains saw tooth with values slightly higher than row above - arr_by_shade_columns = arr_by_shade_rows.T - - # Keep number of saw teeth for later - number_of_partitions = arr_by_shade_columns.shape[0] - - # Flatten the above matrix - join each row into single array - nums_distributed_like_rising_saw = arr_by_shade_columns.reshape(-1) - - # HSV colour map is cyclic we'll use this property - # (https://matplotlib.org/tutorials/colors/colormaps.html#cyclic) - initial_cm = hsv(nums_distributed_like_rising_saw) - - lower_partitions_half = number_of_partitions // 2 - upper_partitions_half = number_of_partitions - lower_partitions_half - - # Modify lower half in such way that colours towards beginning of - # partition are darker .First colours are affected more, colours - # closer to the middle are affected less - lower_half = lower_partitions_half * nb_of_shades - for i in range(3): - initial_cm[0:lower_half, i] *= np.arange(0.2, 1, 0.8/lower_half) - - # Modify second half in such way that colours towards end of partition - # are less intense and brighter. Colours closer to the middle are - # affected less, colours closer to the end are affected more - for i in range(3): - for j in range(upper_partitions_half): - modifier = np.ones(nb_of_shades) \ - - initial_cm[lower_half + j * nb_of_shades: - lower_half + (j + 1) * nb_of_shades, i] - modifier = j * modifier / upper_partitions_half - initial_cm[lower_half + j * nb_of_shades: lower_half + - (j + 1) * nb_of_shades, i] += modifier - - return ListedColormap(initial_cm) + self.color_map = generate_dissimilar_color_map(max_num_bundles) def get_color(self, label: str): if label not in self.bundle_color_map: @@ -142,6 +75,9 @@ def __init__(self, reset_warning: bool If True, a warning will be displayed when calling "plot"several times without calling "reset_data" in between to clear the data. + bundle_mapping: dict + Mapping of the bundle names to the labels to display on the plot. + (e.g. key: bundle index, value: bundle name) """ self.save_dir = save_dir @@ -189,7 +125,7 @@ def add_data_to_plot(self, data: np.ndarray, labels: List[str]): ---------- data: str Unprojected latent space streamlines (N, latent_space_dim). - label: np.ndarray + labels: np.ndarray Labels for each streamline. """ latent_space_streamlines = self._to_numpy(data) diff --git a/dwi_ml/viz/utils.py b/dwi_ml/viz/utils.py new file mode 100644 index 00000000..ec02b71d --- /dev/null +++ b/dwi_ml/viz/utils.py @@ -0,0 +1,74 @@ +import math +from matplotlib.colors import ListedColormap +from matplotlib.cm import hsv + + +def generate_dissimilar_color_map(nb_distinct_colors: int): + """ + Create a colormap with a number of distinct colors. + Needed to have bigger color maps for more bundles. + + Code directly copied from: + https://stackoverflow.com/questions/42697933 + + """ + if nb_distinct_colors == 0: + nb_distinct_colors = 80 + + nb_of_shades = 7 + nb_of_distinct_colors_with_mult_of_shades = int( + math.ceil(nb_distinct_colors / nb_of_shades) + * nb_of_shades) + + # Create an array with uniformly drawn floats taken from <0, 1) + # partition + linearly_distributed_nums = np.arange( + nb_of_distinct_colors_with_mult_of_shades) / \ + nb_of_distinct_colors_with_mult_of_shades + + # We are going to reorganise monotonically growing numbers in such way + # that there will be single array with saw-like pattern but each saw + # tooth is slightly higher than the one before. First divide + # linearly_distributed_nums into nb_of_shades sub-arrays containing + # linearly distributed numbers. + arr_by_shade_rows = linearly_distributed_nums.reshape( + nb_of_shades, nb_of_distinct_colors_with_mult_of_shades // + nb_of_shades) + + # Transpose the above matrix (columns become rows) - as a result each + # row contains saw tooth with values slightly higher than row above + arr_by_shade_columns = arr_by_shade_rows.T + + # Keep number of saw teeth for later + number_of_partitions = arr_by_shade_columns.shape[0] + + # Flatten the above matrix - join each row into single array + nums_distributed_like_rising_saw = arr_by_shade_columns.reshape(-1) + + # HSV colour map is cyclic we'll use this property + # (https://matplotlib.org/tutorials/colors/colormaps.html#cyclic) + initial_cm = hsv(nums_distributed_like_rising_saw) + + lower_partitions_half = number_of_partitions // 2 + upper_partitions_half = number_of_partitions - lower_partitions_half + + # Modify lower half in such way that colours towards beginning of + # partition are darker .First colours are affected more, colours + # closer to the middle are affected less + lower_half = lower_partitions_half * nb_of_shades + for i in range(3): + initial_cm[0:lower_half, i] *= np.arange(0.2, 1, 0.8/lower_half) + + # Modify second half in such way that colours towards end of partition + # are less intense and brighter. Colours closer to the middle are + # affected less, colours closer to the end are affected more + for i in range(3): + for j in range(upper_partitions_half): + modifier = np.ones(nb_of_shades) \ + - initial_cm[lower_half + j * nb_of_shades: + lower_half + (j + 1) * nb_of_shades, i] + modifier = j * modifier / upper_partitions_half + initial_cm[lower_half + j * nb_of_shades: lower_half + + (j + 1) * nb_of_shades, i] += modifier + + return ListedColormap(initial_cm) From a6a9a67e1d929823508c6b7f66a36d1bd751b904 Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Mon, 7 Oct 2024 15:35:29 -0400 Subject: [PATCH 19/28] Doc update --- dwi_ml/models/projects/ae_models.py | 4 ++++ dwi_ml/models/projects/learn2track_model.py | 4 ++++ dwi_ml/models/projects/transformer_models.py | 8 +++++--- scripts_python/ae_train_model.py | 21 ++++++++++++-------- 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index ecbce674..a9c8d9ad 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -116,6 +116,10 @@ def forward(self, Batch of streamlines. Only used if previous directions are added to the model. Used to compute directions; its last point will not be used. + data_per_streamline: dict of lists, optional + Dictionary containing additional data for each streamline. Each + key is the name of a data type, and each value is a list of length + `len(input_streamlines)` containing the data for each streamline. Returns ------- diff --git a/dwi_ml/models/projects/learn2track_model.py b/dwi_ml/models/projects/learn2track_model.py index d3b11237..f4576cdc 100644 --- a/dwi_ml/models/projects/learn2track_model.py +++ b/dwi_ml/models/projects/learn2track_model.py @@ -244,6 +244,10 @@ def forward(self, x: List[torch.tensor], Batch of streamlines. Only used if previous directions are added to the model. Used to compute directions; its last point will not be used. + data_per_streamline: dict of lists, optional + Dictionary containing additional data for each streamline. Each + key is the name of a data type, and each value is a list of length + `len(input_streamlines)` containing the data for each streamline. hidden_recurrent_states : list[states] The current hidden states of the (stacked) RNN model. return_hidden: bool diff --git a/dwi_ml/models/projects/transformer_models.py b/dwi_ml/models/projects/transformer_models.py index 7f883778..aa7b2847 100644 --- a/dwi_ml/models/projects/transformer_models.py +++ b/dwi_ml/models/projects/transformer_models.py @@ -360,8 +360,7 @@ def _prepare_masks(self, unpadded_lengths, use_padding, batch_max_len): def forward(self, inputs: List[torch.tensor], input_streamlines: List[torch.tensor] = None, - data_per_streamline: Union[List[torch.tensor], - List[np.ndarray]] = None): + data_per_streamline: dict = None): """ Params ------ @@ -379,7 +378,10 @@ def forward(self, inputs: List[torch.tensor], adequately masked to hide future positions. The last direction is not used. - As target during training. The whole sequence is used. - + data_per_streamline: dict of lists, optional + Dictionary containing additional data for each streamline. Each + key is the name of a data type, and each value is a list of length + `len(input_streamlines)` containing the data for each streamline. Returns ------- output: Tensor, diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 184ef40f..7a068712 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -43,13 +43,18 @@ def prepare_arg_parser(): p.add_argument('streamline_group_name', help="Name of the group in hdf5") p.add_argument('--viz_latent_space', action='store_true', default=False, - help="Frequency at which to visualize latent space.\n" - "This is expressed in number of epochs.") - p.add_argument('--color_by', type=str, default=None, + help="If specified, enables latent space visualization.\n") + p.add_argument('--viz_color_by', type=str, default=None, help="Name of the group in hdf5 to color by." "In the HDF5, the coloring group should be under" "data_per_streamline/") - p.add_argument('--bundles_mapping', type=str, default=None, + p.add_argument('--viz_max_bundle_size', type=int, default=1000, + help="Maximum number of streamlines per bundle to " + "visualize in latent space. Will perform a random\n" + "selection if the number of streamlines is higher than " + "this value." + ) + p.add_argument('--viz_bundles_mapping', type=str, default=None, help="Path to a txt file mapping bundles to a new name.\n" "Each line of that file should be: ") add_memory_args(p, add_lazy_options=True, add_rng=True) @@ -62,7 +67,7 @@ def init_from_args(args, sub_loggers_level): torch.manual_seed(args.rng) # Set torch seed viz_latent_space = args.viz_latent_space - color_by = args.color_by + viz_color_by = args.viz_color_by # Prepare the dataset dataset = prepare_multisubjectdataset(args, load_testing=False, @@ -117,9 +122,9 @@ def init_from_args(args, sub_loggers_level): # MEMORY nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, log_level=sub_loggers_level, - viz_latent_space=viz_latent_space, color_by=color_by, - bundles_mapping_file=args.bundles_mapping, - max_viz_subset_size=1000) + viz_latent_space=viz_latent_space, color_by=viz_color_by, + bundles_mapping_file=args.viz_bundles_mapping, + max_viz_subset_size=args.viz_max_bundle_size) logging.info("Trainer params : " + format_dict_to_str(trainer.params_for_checkpoint)) From 26ded25e72a19a1e7a403521f665f6488f897b3a Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Mon, 7 Oct 2024 15:39:12 -0400 Subject: [PATCH 20/28] Fix missing import --- dwi_ml/viz/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dwi_ml/viz/utils.py b/dwi_ml/viz/utils.py index ec02b71d..895c4b2c 100644 --- a/dwi_ml/viz/utils.py +++ b/dwi_ml/viz/utils.py @@ -1,6 +1,7 @@ import math from matplotlib.colors import ListedColormap from matplotlib.cm import hsv +import numpy as np def generate_dissimilar_color_map(nb_distinct_colors: int): From c51c7a36ee9b9de68fcff4a61fcf1b8582767a49 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Mon, 7 Oct 2024 17:49:36 -0400 Subject: [PATCH 21/28] input francois code --- dwi_ml/viz/utils.py | 138 +++++++++++++++++++++++--------------------- 1 file changed, 72 insertions(+), 66 deletions(-) diff --git a/dwi_ml/viz/utils.py b/dwi_ml/viz/utils.py index 895c4b2c..ac3a5fbd 100644 --- a/dwi_ml/viz/utils.py +++ b/dwi_ml/viz/utils.py @@ -1,75 +1,81 @@ -import math from matplotlib.colors import ListedColormap -from matplotlib.cm import hsv +from skimage.color import hsv2rgb, rgb2lab import numpy as np def generate_dissimilar_color_map(nb_distinct_colors: int): """ - Create a colormap with a number of distinct colors. - Needed to have bigger color maps for more bundles. + Select `nb_distinct_colors` dissimilar colors by sampling HSV values and computing distances in the CIELAB space. - Code directly copied from: - https://stackoverflow.com/questions/42697933 + Args: + nb_distinct_colors (int): nb_distinct of colors to select. + h_range (tuple): Range for the hue component (default is full range 0 to 1). + s_range (tuple): Range for the saturation component. + v_range (tuple): Range for the value component. + Returns: + np.ndarray: Array of selected RGB colors. """ - if nb_distinct_colors == 0: - nb_distinct_colors = 80 - - nb_of_shades = 7 - nb_of_distinct_colors_with_mult_of_shades = int( - math.ceil(nb_distinct_colors / nb_of_shades) - * nb_of_shades) - - # Create an array with uniformly drawn floats taken from <0, 1) - # partition - linearly_distributed_nums = np.arange( - nb_of_distinct_colors_with_mult_of_shades) / \ - nb_of_distinct_colors_with_mult_of_shades - - # We are going to reorganise monotonically growing numbers in such way - # that there will be single array with saw-like pattern but each saw - # tooth is slightly higher than the one before. First divide - # linearly_distributed_nums into nb_of_shades sub-arrays containing - # linearly distributed numbers. - arr_by_shade_rows = linearly_distributed_nums.reshape( - nb_of_shades, nb_of_distinct_colors_with_mult_of_shades // - nb_of_shades) - - # Transpose the above matrix (columns become rows) - as a result each - # row contains saw tooth with values slightly higher than row above - arr_by_shade_columns = arr_by_shade_rows.T - - # Keep number of saw teeth for later - number_of_partitions = arr_by_shade_columns.shape[0] - - # Flatten the above matrix - join each row into single array - nums_distributed_like_rising_saw = arr_by_shade_columns.reshape(-1) - - # HSV colour map is cyclic we'll use this property - # (https://matplotlib.org/tutorials/colors/colormaps.html#cyclic) - initial_cm = hsv(nums_distributed_like_rising_saw) - - lower_partitions_half = number_of_partitions // 2 - upper_partitions_half = number_of_partitions - lower_partitions_half - - # Modify lower half in such way that colours towards beginning of - # partition are darker .First colours are affected more, colours - # closer to the middle are affected less - lower_half = lower_partitions_half * nb_of_shades - for i in range(3): - initial_cm[0:lower_half, i] *= np.arange(0.2, 1, 0.8/lower_half) - - # Modify second half in such way that colours towards end of partition - # are less intense and brighter. Colours closer to the middle are - # affected less, colours closer to the end are affected more - for i in range(3): - for j in range(upper_partitions_half): - modifier = np.ones(nb_of_shades) \ - - initial_cm[lower_half + j * nb_of_shades: - lower_half + (j + 1) * nb_of_shades, i] - modifier = j * modifier / upper_partitions_half - initial_cm[lower_half + j * nb_of_shades: lower_half + - (j + 1) * nb_of_shades, i] += modifier - - return ListedColormap(initial_cm) + h_range=(0, 1) + s_range=(0.25, 1) + v_range=(0.25, 1) + + # Start with a random initial color + rgb_colors = [[1, 0, 0]] + while len(rgb_colors) < nb_distinct_colors: + max_distance = -1 + best_color = None + + # Randomly generate a candidate color in HSV + #for _ in range(100): # Generate 100 candidates and pick the best one + hue = np.random.uniform(h_range[0], h_range[1], 100) + saturation = np.random.uniform(s_range[0], s_range[1], 100) + value = np.random.uniform(v_range[0], v_range[1], 100) + candidate_hsv = np.stack([hue, saturation, value], axis=1) + candidate_rgb = hsv2rgb(candidate_hsv) + + # Compute the minimum distance to any selected color in LAB space + distance = compute_cielab_distances(candidate_rgb, rgb_colors) + distance = np.min(distance, axis=1) + min_distance = np.max(distance) + min_distance_id = np.argmax(distance) + + if min_distance > max_distance: + max_distance = min_distance + best_color = candidate_rgb[min_distance_id] + + rgb_colors = np.vstack([rgb_colors, best_color]) + + + return ListedColormap(rgb_colors) + +def compute_cielab_distances(rgb_colors, compared_to=None): + """ + Convert RGB colors to CIELAB and compute the Delta E (CIEDE2000) distance matrix. + + Args: + rgb_colors (np.ndarray): Array of RGB colors. + compared_to (np.ndarray): Array of RGB colors to compare against. If None, compare to rgb_colors. + + Returns: + np.ndarray: nb_sample x nb_sample or nb_sample1 x nb_sample2 distance matrix. + """ + # Convert RGB to CIELAB + rgb_colors = np.clip(rgb_colors, 0, 1).astype(float) + lab_colors_1 = rgb2lab(rgb_colors) + + if compared_to is None: + lab_colors_2 = lab_colors_1 + else: + compared_to = np.clip(compared_to, 0, 1).astype(float) + lab_colors_2 = rgb2lab(compared_to) + + # Calculate the pairwise Delta E distances using broadcasting and vectorization + lab_colors_1 = lab_colors_1[:, np.newaxis, :] # Shape (n1, 1, 3) + lab_colors_2 = lab_colors_2[np.newaxis, :, :] # Shape (1, n2, 3) + + # Vectorized Delta E calculation + distance_matrix = deltaE_ciede2000(lab_colors_1, lab_colors_2, + kL=1, kC=1, kH=1) + + return distance_matrix \ No newline at end of file From f164a50363bbdbb813735f8563351f1c552f23e9 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Mon, 7 Oct 2024 18:21:55 -0400 Subject: [PATCH 22/28] fix import --- dwi_ml/viz/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dwi_ml/viz/utils.py b/dwi_ml/viz/utils.py index ac3a5fbd..b2cca729 100644 --- a/dwi_ml/viz/utils.py +++ b/dwi_ml/viz/utils.py @@ -1,5 +1,5 @@ from matplotlib.colors import ListedColormap -from skimage.color import hsv2rgb, rgb2lab +from skimage.color import hsv2rgb, rgb2lab, deltaE_ciede2000 import numpy as np From 3545c88347d876d134a982e22c20e74f0649d932 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Mon, 7 Oct 2024 20:24:35 -0400 Subject: [PATCH 23/28] pep8 --- dwi_ml/viz/utils.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/dwi_ml/viz/utils.py b/dwi_ml/viz/utils.py index b2cca729..b137b198 100644 --- a/dwi_ml/viz/utils.py +++ b/dwi_ml/viz/utils.py @@ -5,20 +5,22 @@ def generate_dissimilar_color_map(nb_distinct_colors: int): """ - Select `nb_distinct_colors` dissimilar colors by sampling HSV values and computing distances in the CIELAB space. + Select `nb_distinct_colors` dissimilar colors by sampling HSV values and + computing distances in the CIELAB space. - Args: + Parameters: nb_distinct_colors (int): nb_distinct of colors to select. - h_range (tuple): Range for the hue component (default is full range 0 to 1). - s_range (tuple): Range for the saturation component. - v_range (tuple): Range for the value component. Returns: np.ndarray: Array of selected RGB colors. """ - h_range=(0, 1) - s_range=(0.25, 1) - v_range=(0.25, 1) + + # h_range (tuple): Range for the hue component. + # s_range (tuple): Range for the saturation component. + # v_range (tuple): Range for the value component. + h_range = (0, 1) + s_range = (0.25, 1) + v_range = (0.25, 1) # Start with a random initial color rgb_colors = [[1, 0, 0]] @@ -27,7 +29,7 @@ def generate_dissimilar_color_map(nb_distinct_colors: int): best_color = None # Randomly generate a candidate color in HSV - #for _ in range(100): # Generate 100 candidates and pick the best one + # for _ in range(100): # Generate 100 candidates and pick the best one hue = np.random.uniform(h_range[0], h_range[1], 100) saturation = np.random.uniform(s_range[0], s_range[1], 100) value = np.random.uniform(v_range[0], v_range[1], 100) @@ -46,19 +48,22 @@ def generate_dissimilar_color_map(nb_distinct_colors: int): rgb_colors = np.vstack([rgb_colors, best_color]) - return ListedColormap(rgb_colors) + def compute_cielab_distances(rgb_colors, compared_to=None): """ - Convert RGB colors to CIELAB and compute the Delta E (CIEDE2000) distance matrix. + Convert RGB colors to CIELAB and compute + the Delta E (CIEDE2000) distance matrix. - Args: + Parameters: rgb_colors (np.ndarray): Array of RGB colors. - compared_to (np.ndarray): Array of RGB colors to compare against. If None, compare to rgb_colors. + compared_to (np.ndarray): Array of RGB colors to compare against. + If None, compare to rgb_colors. Returns: - np.ndarray: nb_sample x nb_sample or nb_sample1 x nb_sample2 distance matrix. + np.ndarray: nb_sample x nb_sample or \ + nb_sample1 x nb_sample2 distance matrix. """ # Convert RGB to CIELAB rgb_colors = np.clip(rgb_colors, 0, 1).astype(float) @@ -70,7 +75,8 @@ def compute_cielab_distances(rgb_colors, compared_to=None): compared_to = np.clip(compared_to, 0, 1).astype(float) lab_colors_2 = rgb2lab(compared_to) - # Calculate the pairwise Delta E distances using broadcasting and vectorization + # Calculate the pairwise Delta E distances + # using broadcasting and vectorization lab_colors_1 = lab_colors_1[:, np.newaxis, :] # Shape (n1, 1, 3) lab_colors_2 = lab_colors_2[np.newaxis, :, :] # Shape (1, n2, 3) @@ -78,4 +84,4 @@ def compute_cielab_distances(rgb_colors, compared_to=None): distance_matrix = deltaE_ciede2000(lab_colors_1, lab_colors_2, kL=1, kC=1, kH=1) - return distance_matrix \ No newline at end of file + return distance_matrix From 31c7b710c5119f7956c7e7bdcbd1b2d545695a21 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Mon, 7 Oct 2024 20:34:17 -0400 Subject: [PATCH 24/28] remove comment - jeremi review --- dwi_ml/viz/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dwi_ml/viz/utils.py b/dwi_ml/viz/utils.py index b137b198..fb87d72d 100644 --- a/dwi_ml/viz/utils.py +++ b/dwi_ml/viz/utils.py @@ -29,7 +29,7 @@ def generate_dissimilar_color_map(nb_distinct_colors: int): best_color = None # Randomly generate a candidate color in HSV - # for _ in range(100): # Generate 100 candidates and pick the best one + # Generate 100 candidates and pick the best one hue = np.random.uniform(h_range[0], h_range[1], 100) saturation = np.random.uniform(s_range[0], s_range[1], 100) value = np.random.uniform(v_range[0], v_range[1], 100) From 5f0fae5196b60218374b2387fa5726357f77b2ae Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Tue, 8 Oct 2024 09:06:27 -0400 Subject: [PATCH 25/28] set higher s range and v range --- dwi_ml/viz/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dwi_ml/viz/utils.py b/dwi_ml/viz/utils.py index fb87d72d..86132f73 100644 --- a/dwi_ml/viz/utils.py +++ b/dwi_ml/viz/utils.py @@ -19,8 +19,8 @@ def generate_dissimilar_color_map(nb_distinct_colors: int): # s_range (tuple): Range for the saturation component. # v_range (tuple): Range for the value component. h_range = (0, 1) - s_range = (0.25, 1) - v_range = (0.25, 1) + s_range = (0.8, 1) + v_range = (0.8, 1) # Start with a random initial color rgb_colors = [[1, 0, 0]] From 3ea27cb0766d2da3a5e4ca8cd6b9db476674461e Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Fri, 8 Nov 2024 11:14:34 -0500 Subject: [PATCH 26/28] incomplete: stick to master --- dwi_ml/data/dataset/streamline_containers.py | 25 ----------------- dwi_ml/data/hdf5/hdf5_creation.py | 1 - dwi_ml/training/trainers.py | 28 +++++++------------- 3 files changed, 10 insertions(+), 44 deletions(-) diff --git a/dwi_ml/data/dataset/streamline_containers.py b/dwi_ml/data/dataset/streamline_containers.py index 65d239fe..c72a89bc 100644 --- a/dwi_ml/data/dataset/streamline_containers.py +++ b/dwi_ml/data/dataset/streamline_containers.py @@ -43,7 +43,6 @@ def _load_all_streamlines_from_hdf(hdf_group: h5py.Group): streamlines._data = np.array(hdf_group['data']) streamlines._offsets = np.array(hdf_group['offsets']) streamlines._lengths = np.array(hdf_group['lengths']) - dps_dict = _load_data_per_streamline(hdf_group) # DPS hdf_dps_group = hdf_group['data_per_streamline'] @@ -75,30 +74,6 @@ def _load_connectivity_info(hdf_group: h5py.Group): return contains_connectivity, connectivity_nb_blocs, connectivity_labels -def _load_data_per_streamline(hdf_group, - dps_key: str = None) -> Union[np.ndarray, None]: - dps_dict = defaultdict(list) - # Load only related data key if specified - if 'data_per_streamline' not in hdf_group.keys(): - return dps_dict - - dps_group = hdf_group['data_per_streamline'] - if dps_key is not None: - # Make sure the related data key is in the hdf5 group - if not (dps_key in dps_group.keys()): - raise KeyError("The key '{}' is not in the hdf5 group. Found: {}" - .format(dps_key, dps_group.keys())) - - # Load the related data per streamline - dps_dict[dps_key] = dps_group[dps_key][:] - # Otherwise, load every dps. - else: - for dps_key in dps_group.keys(): - dps_dict[dps_key] = dps_group[dps_key][:] - - return dps_dict - - class _LazyStreamlinesGetter(object): def __init__(self, hdf_group): self.hdf_group = hdf_group diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index 375132b1..006b16fe 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -133,7 +133,6 @@ class HDF5Creator: See the doc for an example of config file. https://dwi-ml.readthedocs.io/en/latest/config_file.html """ - def __init__(self, root_folder: Path, out_hdf_filename: Path, training_subjs: List[str], validation_subjs: List[str], testing_subjs: List[str], groups_config: dict, diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index 9b19a8d6..15621360 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -50,7 +50,6 @@ class DWIMLAbstractTrainer: NOTE: TRAINER USES STREAMLINES COORDINATES IN VOXEL SPACE, CORNER ORIGIN. """ - def __init__(self, model: MainModelAbstract, experiments_path: str, experiment_name: str, batch_sampler: DWIMLBatchIDSampler, @@ -518,10 +517,8 @@ def _update_states_from_checkpoint(self, current_states): # A. Rng value. # RNG: # - numpy - self.batch_sampler.np_rng.set_state( - current_states['sampler_np_rng_state']) - self.batch_loader.np_rng.set_state( - current_states['loader_np_rng_state']) + self.batch_sampler.np_rng.set_state(current_states['sampler_np_rng_state']) + self.batch_loader.np_rng.set_state(current_states['loader_np_rng_state']) # - torch torch.set_rng_state(current_states['torch_rng_state']) if self.use_gpu: @@ -571,13 +568,10 @@ def _init_comet(self): display_summary_level=False) self.comet_exp.set_name(self.experiment_name) self.comet_exp.log_parameters(self.params_for_checkpoint) - self.comet_exp.log_parameters( - self.batch_sampler.params_for_checkpoint) - self.comet_exp.log_parameters( - self.batch_loader.params_for_checkpoint) + self.comet_exp.log_parameters(self.batch_sampler.params_for_checkpoint) + self.comet_exp.log_parameters(self.batch_loader.params_for_checkpoint) self.comet_exp.log_parameters(self.model.params_for_checkpoint) - self.comet_exp.log_parameters( - self.model.computed_params_for_display) + self.comet_exp.log_parameters(self.model.computed_params_for_display) self.comet_key = self.comet_exp.get_key() # Couldn't find how to set log level. Getting it directly. comet_log = logging.getLogger("comet_ml") @@ -888,8 +882,7 @@ def validate_one_epoch(self, epoch): # Validate this batch: forward propagation + loss with torch.no_grad(): - self.validate_one_batch( - data, epoch) + self.validate_one_batch(data, epoch) # Break if maximum number of epochs has been reached if batch_id == self.nb_batches_valid - 1: @@ -1022,7 +1015,7 @@ def run_one_batch(self, data): """ # Data interpolation has not been done yet. GPU computations are done # here in the main thread. - targets, ids_per_subj, data_per_streamline = data + targets, ids_per_subj = data # Dataloader always works on CPU. Sending to right device. # (model is already moved). @@ -1044,7 +1037,7 @@ def run_one_batch(self, data): # but ok, shouldn't be too heavy. Easier to deal with multiple # projects' requirements by sending whole streamlines rather # than only directions. - model_outputs = self.model(streamlines_f, data_per_streamline) + model_outputs = self.model(streamlines_f) del streamlines_f logger.debug('*** Computing loss') @@ -1157,7 +1150,7 @@ def run_one_batch(self, data): """ # Data interpolation has not been done yet. GPU computations are done # here in the main thread. - targets, ids_per_subj, data_per_streamline = data + targets, ids_per_subj = data # Dataloader always works on CPU. Sending to right device. # (model is already moved). @@ -1185,8 +1178,7 @@ def run_one_batch(self, data): # (batch loader will do it depending on training / valid) streamlines_f = self.batch_loader.add_noise_streamlines_forward( streamlines_f, self.device) - model_outputs = self.model( - batch_inputs, streamlines_f, data_per_streamline) + model_outputs = self.model(batch_inputs, streamlines_f) del streamlines_f logger.debug('*** Computing loss') From 3669453f43d6b5008ed9b8d84b7bd28229279722 Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Fri, 8 Nov 2024 11:15:28 -0500 Subject: [PATCH 27/28] incomplete: stick to master part 2 --- dwi_ml/training/trainers.py | 9 +++++---- dwi_ml/training/utils/monitoring.py | 1 - 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index 15621360..c7ed4537 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -1015,7 +1015,7 @@ def run_one_batch(self, data): """ # Data interpolation has not been done yet. GPU computations are done # here in the main thread. - targets, ids_per_subj = data + targets, ids_per_subj, data_per_streamline = data # Dataloader always works on CPU. Sending to right device. # (model is already moved). @@ -1037,7 +1037,7 @@ def run_one_batch(self, data): # but ok, shouldn't be too heavy. Easier to deal with multiple # projects' requirements by sending whole streamlines rather # than only directions. - model_outputs = self.model(streamlines_f) + model_outputs = self.model(streamlines_f, data_per_streamline) del streamlines_f logger.debug('*** Computing loss') @@ -1150,7 +1150,7 @@ def run_one_batch(self, data): """ # Data interpolation has not been done yet. GPU computations are done # here in the main thread. - targets, ids_per_subj = data + targets, ids_per_subj, data_per_streamline = data # Dataloader always works on CPU. Sending to right device. # (model is already moved). @@ -1178,7 +1178,8 @@ def run_one_batch(self, data): # (batch loader will do it depending on training / valid) streamlines_f = self.batch_loader.add_noise_streamlines_forward( streamlines_f, self.device) - model_outputs = self.model(batch_inputs, streamlines_f) + model_outputs = self.model( + batch_inputs, streamlines_f, data_per_streamline) del streamlines_f logger.debug('*** Computing loss') diff --git a/dwi_ml/training/utils/monitoring.py b/dwi_ml/training/utils/monitoring.py index b4ac5991..c4d05d27 100644 --- a/dwi_ml/training/utils/monitoring.py +++ b/dwi_ml/training/utils/monitoring.py @@ -252,7 +252,6 @@ class IterTimer(object): # next iter could be twice as long as usual: time.time() + iter_timer.mean * 2.0 + 30 > max_time """ - def __init__(self, history_len=5): self.history = deque(maxlen=history_len) self.iterable = None From 6bbc096c63e991439ef7508a52abe5a0ad9283c7 Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Fri, 8 Nov 2024 11:45:16 -0500 Subject: [PATCH 28/28] fix: dps batch loading and stick to master 3 --- dwi_ml/models/projects/ae_models.py | 1 - dwi_ml/models/projects/learn2track_model.py | 3 +-- dwi_ml/models/projects/transformer_models.py | 3 --- dwi_ml/training/batch_loaders.py | 5 +++++ 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index a9c8d9ad..a222b3dd 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -18,7 +18,6 @@ class ModelAE(MainModelAbstract): deterministic (3D vectors) or probabilistic (based on probability distribution parameters). """ - def __init__(self, experiment_name: str, step_size: float = None, diff --git a/dwi_ml/models/projects/learn2track_model.py b/dwi_ml/models/projects/learn2track_model.py index f4576cdc..63758bb6 100644 --- a/dwi_ml/models/projects/learn2track_model.py +++ b/dwi_ml/models/projects/learn2track_model.py @@ -289,8 +289,7 @@ def forward(self, x: List[torch.tensor], unsorted_indices = invert_permutation(sorted_indices) x = [x[i] for i in sorted_indices] if input_streamlines is not None: - input_streamlines = [input_streamlines[i] - for i in sorted_indices] + input_streamlines = [input_streamlines[i] for i in sorted_indices] # ==== 0. Previous dirs. n_prev_dirs = None diff --git a/dwi_ml/models/projects/transformer_models.py b/dwi_ml/models/projects/transformer_models.py index aa7b2847..7bc21433 100644 --- a/dwi_ml/models/projects/transformer_models.py +++ b/dwi_ml/models/projects/transformer_models.py @@ -117,7 +117,6 @@ class AbstractTransformerModel(ModelWithNeighborhood, ModelWithDirectionGetter, https://kazemnejad.com/blog/transformer_architecture_positional_encoding/ the embedding probably adapts to leave place for the positional encoding. """ - def __init__(self, experiment_name: str, # Target preprocessing params for the batch loader + tracker @@ -828,7 +827,6 @@ class OriginalTransformerModel(AbstractTransformerModelWithTarget): emb_choice_x """ - def __init__(self, input_embedded_size, n_layers_d: int, **kw): """ d_model = input_embedded_size = target_embedded_size. @@ -970,7 +968,6 @@ class TransformerSrcAndTgtModel(AbstractTransformerModelWithTarget): [ emb_choice_x ; emb_choice_y ] """ - def __init__(self, **kw): """ No additional params. d_model = input size + target size. diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 7cb228c0..d46b6274 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -302,6 +302,7 @@ def load_batch_streamlines( # the loaded, processed streamlines, not to the ids in the hdf5 file. final_s_ids_per_subj = defaultdict(slice) batch_streamlines = [] + batch_dps = defaultdict(list) for subj, s_ids in streamline_ids_per_subj: logger.debug( " Data loader: Processing data preparation for " @@ -332,6 +333,10 @@ def load_batch_streamlines( sft.to_corner() batch_streamlines.extend(sft.streamlines) + # Add data per streamline for the batch elements + for key, value in sft.data_per_streamline.items(): + batch_dps[key].extend(value) + batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines] data_per_streamline = _dps_to_tensors(sft.data_per_streamline)