diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index eff22243..a222b3dd 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -35,6 +35,7 @@ def __init__(self, self.latent_space_dims = 32 self.pad = torch.nn.ReflectionPad1d(1) + self.post_encoding_hooks = [] def pre_pad(m): return torch.nn.Sequential(self.pad, m) @@ -104,6 +105,7 @@ def pre_pad(m): def forward(self, input_streamlines: List[torch.tensor], + data_per_streamline: dict = None ): """Run the model on a batch of sequences. @@ -113,6 +115,10 @@ def forward(self, Batch of streamlines. Only used if previous directions are added to the model. Used to compute directions; its last point will not be used. + data_per_streamline: dict of lists, optional + Dictionary containing additional data for each streamline. Each + key is the name of a data type, and each value is a list of length + `len(input_streamlines)` containing the data for each streamline. Returns ------- @@ -121,12 +127,20 @@ def forward(self, `get_tracking_directions()`. 
""" - x = self.decode(self.encode(input_streamlines)) + encoded = self.encode(input_streamlines) + + for hook in self.post_encoding_hooks: + hook(encoded, data_per_streamline) + + x = self.decode(encoded) return x def encode(self, x): - # x: list of tensors - x = torch.stack(x) + # X input shape is (batch_size, nb_points, 3) + if isinstance(x, list): + x = torch.stack(x) + + # input of the network should be (N, 3, nb_points) x = torch.swapaxes(x, 1, 2) h1 = F.relu(self.encod_conv1(x)) @@ -171,3 +185,6 @@ def compute_loss(self, model_outputs, targets, average_results=True): reconstruction_loss = torch.nn.MSELoss(reduction="sum") mse = reconstruction_loss(model_outputs, targets) return mse, 1 + + def register_hook_post_encoding(self, hook): + self.post_encoding_hooks.append(hook) diff --git a/dwi_ml/models/projects/learn2track_model.py b/dwi_ml/models/projects/learn2track_model.py index 9ba8074c..63758bb6 100644 --- a/dwi_ml/models/projects/learn2track_model.py +++ b/dwi_ml/models/projects/learn2track_model.py @@ -227,6 +227,7 @@ def computed_params_for_display(self): def forward(self, x: List[torch.tensor], input_streamlines: List[torch.tensor] = None, + data_per_streamline: List[torch.tensor] = {}, hidden_recurrent_states: List = None, return_hidden=False, point_idx: int = None): """Run the model on a batch of sequences. @@ -243,6 +244,10 @@ def forward(self, x: List[torch.tensor], Batch of streamlines. Only used if previous directions are added to the model. Used to compute directions; its last point will not be used. + data_per_streamline: dict of lists, optional + Dictionary containing additional data for each streamline. Each + key is the name of a data type, and each value is a list of length + `len(input_streamlines)` containing the data for each streamline. hidden_recurrent_states : list[states] The current hidden states of the (stacked) RNN model. 
def _dps_to_tensors(dps: dict, device='cpu'):
    """
    Convert every entry of a data_per_streamline (DPS) dict to a tensor.

    Parameters
    ----------
    dps: dict
        Keys are the DPS names. Each value is anything accepted by
        `torch.tensor` (list, numpy array, ...) or an existing tensor.
        None is treated as an empty dict.
    device: str or torch.device
        Device on which the tensors are created. Default: 'cpu'.

    Returns
    -------
    dps_tensors: dict
        New dict with the same keys; each value is a torch tensor on
        `device`. The input dict and its values are not modified.
    """
    if dps is None:
        return {}

    dps_tensors = {}
    for key, value in dps.items():
        if isinstance(value, torch.Tensor):
            # torch.tensor(<tensor>) emits a copy-construct UserWarning;
            # detach().clone() is the equivalent copy without the warning.
            dps_tensors[key] = value.detach().clone().to(device)
        else:
            dps_tensors[key] = torch.tensor(value, device=device)
    return dps_tensors
LOGGER = logging.getLogger(__name__)


def parse_bundle_mapping(bundles_mapping_file: str = None):
    """
    Read a text file mapping bundle names to bundle numbers.

    Each non-empty line must end with an integer bundle number. Everything
    before that number (whitespace-stripped) is the bundle name, e.g.
    ``AF_left 3``. Blank lines are ignored. If a number appears several
    times, the last line wins.

    Parameters
    ----------
    bundles_mapping_file: str, optional
        Path to the mapping file. If None, nothing is read.

    Returns
    -------
    bundle_mapping: dict or None
        {bundle_number (int): bundle_name (str)}, or None if no file was
        given.

    Raises
    ------
    ValueError
        If a non-empty line is not of the form "<name> <integer>".
    """
    if bundles_mapping_file is None:
        return None

    bundle_mapping = {}
    with open(bundles_mapping_file, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            entry = line.strip()
            if not entry:
                # Tolerate blank lines (e.g. a trailing newline).
                continue

            try:
                # Split on the LAST whitespace run so that bundle names
                # may themselves contain spaces.
                bundle_name, bundle_number = entry.rsplit(None, 1)
                bundle_mapping[int(bundle_number)] = bundle_name
            except ValueError as err:
                raise ValueError(
                    "Line {} of {} should be '<bundle_name> "
                    "<bundle_number>', got: {!r}".format(
                        line_no, bundles_mapping_file, entry)) from err
    return bundle_mapping


class TrainerWithBundleDPS(DWIMLAbstractTrainer):
    """
    Trainer that can additionally plot the model's latent space.

    When `viz_latent_space` is True, a hook is registered on the model (via
    `model.register_hook_post_encoding`) to collect the encodings produced
    while `model.context == 'training'`, and the visualizer's `plot` method
    is registered on `self.best_epoch_monitor` as an on-best-epoch hook.
    Collected data is cleared at the start of every training epoch.
    """

    def __init__(self,
                 model: MainModelAbstract, experiments_path: str,
                 experiment_name: str, batch_sampler: DWIMLBatchIDSampler,
                 batch_loader: DWIMLStreamlinesBatchLoader,
                 learning_rates: Union[List, float] = None,
                 weight_decay: float = 0.01,
                 optimizer: str = 'Adam', max_epochs: int = 10,
                 max_batches_per_epoch_training: int = 1000,
                 max_batches_per_epoch_validation: Union[int, None] = 1000,
                 patience: int = None, patience_delta: float = 1e-6,
                 nb_cpu_processes: int = 0, use_gpu: bool = False,
                 clip_grad: float = None,
                 comet_workspace: str = None, comet_project: str = None,
                 from_checkpoint: bool = False, log_level=logging.root.level,
                 viz_latent_space: bool = False, color_by: str = None,
                 bundles_mapping_file: str = None,
                 max_viz_subset_size: int = 1000):
        """
        All parameters up to `log_level` are forwarded unchanged, in the
        same order, to DWIMLAbstractTrainer.

        Parameters
        ----------
        viz_latent_space: bool
            If True, set up the latent space visualization. Figures are
            saved in <experiments_path>/<experiment_name>/
            latent_space_plots.
        color_by: str, optional
            Key of data_per_streamline used as label for each streamline.
            If None, or if the key is missing, no labels are used.
        bundles_mapping_file: str, optional
            File read with `parse_bundle_mapping`; the result is given to
            the visualizer as `bundle_mapping`.
        max_viz_subset_size: int
            Given to the visualizer as `max_subset_size`.
        """
        super().__init__(model, experiments_path, experiment_name,
                         batch_sampler, batch_loader, learning_rates,
                         weight_decay, optimizer, max_epochs,
                         max_batches_per_epoch_training,
                         max_batches_per_epoch_validation, patience,
                         patience_delta, nb_cpu_processes, use_gpu,
                         clip_grad, comet_workspace, comet_project,
                         from_checkpoint, log_level)

        self.color_by = color_by
        self.viz_latent_space = viz_latent_space
        if self.viz_latent_space:
            # Setup to visualize latent space
            save_dir = os.path.join(
                experiments_path, experiment_name, 'latent_space_plots')
            os.makedirs(save_dir, exist_ok=True)

            bundle_mapping = parse_bundle_mapping(bundles_mapping_file)
            self.ls_viz = BundlesLatentSpaceVisualizer(
                save_dir,
                prefix_numbering=True,
                max_subset_size=max_viz_subset_size,
                bundle_mapping=bundle_mapping)
            # Ensures the "color_by not found" warning is logged only once.
            self.warning_printed = False

            # Collect the encodings from within the model's forward().
            model.register_hook_post_encoding(self._handle_latent_encodings)

            # Plot the latent space after each best epoch.
            # Called after running training & validation epochs.
            self.best_epoch_monitor.register_on_best_epoch_hook(
                self.ls_viz.plot)

    def _handle_latent_encodings(self, encoding, data_per_streamline):
        """
        Post-encoding hook: send the encodings (and their labels, if
        available) to the latent space visualizer.

        Parameters
        ----------
        encoding:
            The encoded batch; passed as is to `ls_viz.add_data_to_plot`.
        data_per_streamline: dict or None
            If it contains the key `self.color_by`, that entry (squeezed
            on dim 1) is used as labels. Otherwise no labels are used.
        """
        # Only accumulate data during training
        if self.model.context != 'training':
            return

        labels = None
        if self.color_by is not None:
            # data_per_streamline may be None (no DPS given to forward()).
            if data_per_streamline is not None \
                    and self.color_by in data_per_streamline:
                labels = data_per_streamline[self.color_by].squeeze(1)
            elif not self.warning_printed:
                LOGGER.warning(
                    f"Coloring by {self.color_by} not found in "
                    "data_per_streamline.")
                self.warning_printed = True

        self.ls_viz.add_data_to_plot(encoding, labels=labels)

    def train_one_epoch(self, epoch):
        """
        Clear the data collected for the latent space plot (if enabled),
        then run the parent's training epoch.
        """
        if self.viz_latent_space:
            # Before starting another training epoch, make sure the data
            # is cleared. This is important to avoid accumulating data.
            # We have to do it here. Since the on_new_best_epoch is called
            # after the validation epoch, we can't do it there.
            # Also, we won't always have the best epoch, if not, we still
            # need to clear the data.
            self.ls_viz.reset_data()

        super().train_one_epoch(epoch)
- targets, ids_per_subj = data + targets, ids_per_subj, data_per_streamline = data # Dataloader always works on CPU. Sending to right device. # (model is already moved). @@ -1037,7 +1037,7 @@ def run_one_batch(self, data): # but ok, shouldn't be too heavy. Easier to deal with multiple # projects' requirements by sending whole streamlines rather # than only directions. - model_outputs = self.model(streamlines_f) + model_outputs = self.model(streamlines_f, data_per_streamline) del streamlines_f logger.debug('*** Computing loss') @@ -1150,7 +1150,7 @@ def run_one_batch(self, data): """ # Data interpolation has not been done yet. GPU computations are done # here in the main thread. - targets, ids_per_subj = data + targets, ids_per_subj, data_per_streamline = data # Dataloader always works on CPU. Sending to right device. # (model is already moved). @@ -1178,7 +1178,8 @@ def run_one_batch(self, data): # (batch loader will do it depending on training / valid) streamlines_f = self.batch_loader.add_noise_streamlines_forward( streamlines_f, self.device) - model_outputs = self.model(batch_inputs, streamlines_f) + model_outputs = self.model( + batch_inputs, streamlines_f, data_per_streamline) del streamlines_f logger.debug('*** Computing loss') diff --git a/dwi_ml/training/trainers_withGV.py b/dwi_ml/training/trainers_withGV.py index a0aebfcb..c42ee214 100644 --- a/dwi_ml/training/trainers_withGV.py +++ b/dwi_ml/training/trainers_withGV.py @@ -242,7 +242,7 @@ def gv_phase_one_batch(self, data, compute_all_scores=False): seeds and first few segments. Expected results are the batch's validation streamlines. """ - real_lines, ids_per_subj = data + real_lines, ids_per_subj, data_per_streamline = data # Possibly sending again to GPU even if done in the local loss # computation, but easier with current implementation. 
diff --git a/dwi_ml/training/utils/monitoring.py b/dwi_ml/training/utils/monitoring.py index 79086528..c4d05d27 100644 --- a/dwi_ml/training/utils/monitoring.py +++ b/dwi_ml/training/utils/monitoring.py @@ -158,6 +158,14 @@ def __init__(self, name, patience: int, patience_delta: float = 1e-6): self.best_value = None self.best_epoch = None self.n_bad_epochs = None + self.on_best_epoch_hooks = [] + + def register_on_best_epoch_hook(self, hook): + self.on_best_epoch_hooks.append(hook) + + def _call_on_best_epoch_hooks(self, new_best_epoch): + for hook in self.on_best_epoch_hooks: + hook(new_best_epoch) def update(self, loss, epoch): """ @@ -178,12 +186,14 @@ def update(self, loss, epoch): self.best_value = loss self.best_epoch = epoch self.n_bad_epochs = 0 + self._call_on_best_epoch_hooks(epoch) return False elif loss < self.best_value - self.min_eps: # Improving from at least eps. self.best_value = loss self.best_epoch = epoch self.n_bad_epochs = 0 + self._call_on_best_epoch_hooks(epoch) return False else: # Not improving enough diff --git a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py index eb15b5b6..e0a1df2a 100644 --- a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py +++ b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py @@ -84,7 +84,7 @@ def compute_loss(self, model_outputs, target_streamlines=None, else: return torch.zeros(n, device=self.device), 1 - def forward(self, inputs: list, streamlines): + def forward(self, inputs: list, streamlines, data_per_streamline): # Not using streamlines. Pretending to use inputs. 
# ---------------------------------------------------------------------------
# Content of dwi_ml/viz/latent_streamlines.py
# ---------------------------------------------------------------------------
LOGGER = logging.getLogger(__name__)

DEFAULT_BUNDLE_NAME = 'UNK'


class ColorManager(object):
    """
    Utility class to manage the color of the bundles in the latent space.
    This way, we can have a consistent color for each bundle across
    different plots.
    """

    def __init__(self, max_num_bundles: int = 40):
        """
        Parameters
        ----------
        max_num_bundles: int
            Number of distinct colors generated for the color map.
        """
        # label -> color, filled lazily in order of first request.
        self.bundle_color_map = {}
        self.color_map = generate_dissimilar_color_map(max_num_bundles)

    def get_color(self, label: str):
        """
        Get the color associated to a label. On the first request for a
        label, the next unused index of the color map is assigned to it.
        """
        if label not in self.bundle_color_map:
            self.bundle_color_map[label] = \
                self.color_map(len(self.bundle_color_map))

        return self.bundle_color_map[label]


class BundlesLatentSpaceVisualizer(object):
    """
    Utility class that wraps a t-SNE projection of the latent
    space for multiple bundles. The usage of this class is
    intended as follows:
        1. Create an instance of this class,
        2. Add the latent space streamlines for each bundle
            using "add_data_to_plot" with its corresponding label.
        3. Fit and plot the t-SNE projection using the "plot" method.

    t-SNE projection can only leverage the fit_transform() with all
    the data that needs to be projected at the same time since it aims
    to preserve the local structure of the data.
    """

    def __init__(self,
                 save_dir: str,
                 fig_size: Union[List, Tuple] = (10, 8),
                 random_state: int = 42,
                 max_subset_size: int = None,
                 prefix_numbering: bool = False,
                 reset_warning: bool = True,
                 bundle_mapping: dict = None
                 ):
        """
        Parameters
        ----------
        save_dir: str
            Existing directory in which the figures are saved.
        fig_size: List[int] or Tuple[int]
            2-valued figure size (x, y)
        random_state: int
            Random state for t-SNE.
        max_subset_size: int, optional
            In case of performance issues, you can limit the number of
            streamlines to plot for each bundle. None means no limit.
        prefix_numbering: bool
            If True, the epoch given to "plot" is appended to the name of
            the saved figure.
        reset_warning: bool
            If True, a warning will be displayed when calling "plot"
            several times without calling "reset_data" in between to clear
            the data.
        bundle_mapping: dict
            Mapping of the bundle names to the labels to display on the
            plot. (e.g. key: bundle index, value: bundle name)

        Raises
        ------
        ValueError
            If save_dir is not an existing directory, or if
            max_subset_size is given but not greater than 0.
        """
        self.save_dir = save_dir

        # Make sure that self.save_dir is a directory and exists.
        if not os.path.isdir(self.save_dir):
            raise ValueError("The save_dir should be a directory.")

        self.fig_size = fig_size
        self.random_state = random_state
        self.max_subset_size = max_subset_size
        self._check_max_subset_size()

        self.prefix_numbering = prefix_numbering
        self.reset_warning = reset_warning
        self.bundle_mapping = bundle_mapping

        self.current_plot_number = 0
        self.should_call_reset_before_plot = False

        self.tsne = TSNE(n_components=2, random_state=self.random_state)
        self.bundles = {}
        self.bundle_color_manager = ColorManager()

        self.fig, self.ax = None, None

    def _check_max_subset_size(self):
        """
        Raise a ValueError if max_subset_size is set but not > 0.
        None (no limit) is accepted.
        """
        if self.max_subset_size is not None \
                and not (self.max_subset_size > 0):
            raise ValueError(
                "A max_subset_size of an integer value greater "
                "than 0 is required.")

    def reset_data(self):
        """
        Reset the data to plot. If you call plot several times without
        calling this method, the data will be accumulated.
        """
        # Not sure if resetting the TSNE object is necessary.
        self.tsne = TSNE(n_components=2, random_state=self.random_state)
        self.bundles = {}
        self.should_call_reset_before_plot = False

    def add_data_to_plot(self, data: np.ndarray, labels: List[str]):
        """
        Add unprojected data (no t-SNE, no PCA, etc.).
        This should be directly the output of the model.

        Parameters
        ----------
        data: np.ndarray or torch.Tensor
            Unprojected latent space streamlines (N, latent_space_dim).
        labels: np.ndarray, torch.Tensor, list or None
            Label of each streamline. If None, all the data is stored in
            the DEFAULT_BUNDLE_NAME bundle.

        NOTE(review): data stored under a label REPLACES whatever was
        previously stored under that label; it is not concatenated. Confirm
        this is intended when this method is called once per batch.
        """
        latent_space_streamlines = self._to_numpy(data)
        if labels is None:
            self.bundles[DEFAULT_BUNDLE_NAME] = latent_space_streamlines
            return

        # Work on a numpy array so that lists and tensors (on any device)
        # can be compared and indexed the same way.
        labels = np.asarray(self._to_numpy(labels))

        remaining_indices = np.arange(len(labels))
        for label in np.unique(labels):
            is_label = labels[remaining_indices] == label
            label_data = latent_space_streamlines[
                remaining_indices[is_label]]
            self.bundles[label] = self._resample_max_subset_size(label_data)

            remaining_indices = remaining_indices[~is_label]

        # Streamlines whose label never compares equal to any unique value
        # (e.g. NaN labels) end up here.
        if len(remaining_indices) > 0:
            LOGGER.warning(
                "Some streamlines were not considered in the bundles, "
                "some labels are missing.\n"
                "Added them to the {} bundle."
                .format(DEFAULT_BUNDLE_NAME))
            self.bundles[DEFAULT_BUNDLE_NAME] = \
                latent_space_streamlines[remaining_indices]

    def add_bundle_to_plot(self, data: np.ndarray, label: str = '_'):
        """
        Add unprojected data (no t-SNE, no PCA, etc.) for a single bundle.
        This should be directly the output of the model.

        Parameters
        ----------
        data: np.ndarray or torch.Tensor
            Unprojected latent space streamlines (N, latent_space_dim).
        label: str
            Name of the bundle. Used for the legend.
        """
        latent_space_streamlines = self._to_numpy(data)
        latent_space_streamlines = self._resample_max_subset_size(
            latent_space_streamlines)

        self.bundles[label] = latent_space_streamlines

    def plot(self, epoch: int, figure_name_prefix: str = 'lt_space'):
        """
        Fit and plot the t-SNE projection of the latent space streamlines.
        This should be called once after adding all the data to plot using
        "add_data_to_plot". The figure is saved in save_dir. If no data was
        added, a warning is logged and nothing is plotted.

        Parameters
        ----------
        epoch: int
            Epoch number, shown in the title of the figure and, if
            prefix_numbering is enabled, appended to the file name.
        figure_name_prefix: str
            Name of the figure to be saved. This is just the prefix of the
            full file name as it will be suffixed with the epoch number if
            enabled.
        """
        if not self.bundles:
            # Nothing to concatenate nor to project.
            LOGGER.warning(
                "No latent space data was added; skipping the plot for "
                "epoch {}.".format(epoch))
            return

        if self.should_call_reset_before_plot and self.reset_warning:
            LOGGER.warning(
                "You plotted another time without resetting the data. "
                "The data will be accumulated, which might lead to "
                "unexpected results.")
            self.should_call_reset_before_plot = False
        elif not self.current_plot_number > 0:
            # Only enable the flag for the first plot.
            # So that the warning above is only displayed once.
            self.should_call_reset_before_plot = True

        # Start by making sure the number of streamlines doesn't
        # exceed the threshold (no-op if max_subset_size is None).
        for bname in list(self.bundles):
            self.bundles[bname] = self._resample_max_subset_size(
                self.bundles[bname])

        nb_streamlines = sum(b.shape[0] for b in self.bundles.values())
        LOGGER.info(
            "Plotting a total of {} streamlines".format(nb_streamlines))

        # Build the indices for each bundle to recover the streamlines after
        # the t-SNE projection.
        bundles_indices = {}
        current_start = 0
        for (bname, bdata) in self.bundles.items():
            bundles_indices[bname] = np.arange(
                current_start, current_start + bdata.shape[0])
            current_start += bdata.shape[0]

        LOGGER.info("Fitting TSNE projection.")
        all_streamlines = np.concatenate(list(self.bundles.values()), axis=0)
        all_projected_streamlines = self.tsne.fit_transform(all_streamlines)

        if self.fig is None or self.ax is None:
            self.fig, self.ax = self._init_figure()

        self.ax.clear()

        for bname in self.bundles:
            proj_data = all_projected_streamlines[bundles_indices[bname]]
            blabel = self.bundle_mapping.get(
                bname, bname) if self.bundle_mapping else bname

            self._plot_bundle(
                self.ax, proj_data[:, 0], proj_data[:, 1], blabel)

        self.ax.set_title("Best epoch {}".format(epoch))
        self._set_legend(self.ax, len(self.bundles))

        if self.prefix_numbering:
            filename = '{}_{}.png'.format(
                figure_name_prefix, epoch)
        else:
            filename = '{}.png'.format(figure_name_prefix)

        filename = os.path.join(self.save_dir, filename)
        self.fig.savefig(filename)

        self.current_plot_number += 1

    def _set_legend(self, ax, nb_bundles, order=True):
        """
        Add a legend at the right of the axes, only if there is more than
        one bundle. If order is True, entries are sorted by label.
        """
        if nb_bundles > 1:
            handles, labels = ax.get_legend_handles_labels()
            if not handles:
                # Nothing to show; zip(*[]) below could not be unpacked.
                return
            if order:
                labels, handles = zip(
                    *sorted(zip(labels, handles), key=lambda t: t[0]))
            ax.legend(handles, labels, fontsize=6,
                      loc='center left', bbox_to_anchor=(1, 0.5))

    def _plot_bundle(self, ax, dim1, dim2, blabel):
        """Scatter plot of one bundle, colored through the color manager."""
        ax.scatter(
            dim1,
            dim2,
            label=blabel,
            alpha=0.9,
            edgecolors='black',
            linewidths=0.5,
            color=self.bundle_color_manager.get_color(blabel)
        )

    def _init_figure(self):
        """Create the figure and axes, leaving room for the legend."""
        LOGGER.info("Init new figure for BundlesLatentSpaceVisualizer.")
        fig, ax = plt.subplots(1, 1)
        ax.set_title("Best epoch (?)")
        if self.fig_size is not None:
            fig.set_figwidth(self.fig_size[0])
            fig.set_figheight(self.fig_size[1])

        # Shrink the axes width to 80% (the legend is placed outside, at
        # the right of the axes).
        box_0 = ax.get_position()
        ax.set_position(
            [box_0.x0, box_0.y0, box_0.width * 0.8, box_0.height])

        return fig, ax

    def _to_numpy(self, data):
        """Convert a torch tensor to numpy. Other types are returned as is."""
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        else:
            return data

    def _resample_max_subset_size(self, data: np.ndarray):
        """
        Resample the data to the max_subset_size: if data has more than
        max_subset_size rows, a random subset (without replacement) of
        max_subset_size rows is returned. Otherwise, or if max_subset_size
        is None, data is returned as is.
        """
        if self.max_subset_size is None:
            return data

        self._check_max_subset_size()

        # Only sample if we need to reduce the number of latent
        # streamlines to show on the plot.
        if len(data) > self.max_subset_size:
            sample_indices = np.random.choice(
                len(data),
                self.max_subset_size, replace=False)
            return data[sample_indices]

        return data


# ---------------------------------------------------------------------------
# Content of dwi_ml/viz/utils.py
# ---------------------------------------------------------------------------
def generate_dissimilar_color_map(nb_distinct_colors: int):
    """
    Select `nb_distinct_colors` dissimilar colors by sampling HSV values and
    computing distances in the CIELAB space.

    Colors are added one at a time: 100 random candidates are drawn and the
    candidate whose closest already-selected color is the farthest away is
    kept.

    Parameters
    ----------
    nb_distinct_colors: int
        Number of colors to select. At least one color (red) is always
        returned.

    Returns
    -------
    color_map: ListedColormap
        Color map built from the selected RGB colors.
    """
    # Sampling ranges for the hue, saturation and value components.
    h_range = (0, 1)
    s_range = (0.8, 1)
    v_range = (0.8, 1)
    nb_candidates = 100

    # The first color is always pure red (fixed, not random).
    rgb_colors = np.asarray([[1., 0., 0.]])
    while len(rgb_colors) < nb_distinct_colors:
        # Randomly generate the candidate colors in HSV
        hue = np.random.uniform(h_range[0], h_range[1], nb_candidates)
        saturation = np.random.uniform(s_range[0], s_range[1], nb_candidates)
        value = np.random.uniform(v_range[0], v_range[1], nb_candidates)
        candidate_rgb = hsv2rgb(np.stack([hue, saturation, value], axis=1))

        # For each candidate: distance to its closest selected color.
        distance = compute_cielab_distances(candidate_rgb, rgb_colors)
        distance_to_closest = np.min(distance, axis=1)

        # Keep the candidate that is the farthest from the selected colors.
        best_candidate = np.argmax(distance_to_closest)
        rgb_colors = np.vstack([rgb_colors, candidate_rgb[best_candidate]])

    return ListedColormap(rgb_colors)


def compute_cielab_distances(rgb_colors, compared_to=None):
    """
    Convert RGB colors to CIELAB and compute
    the Delta E (CIEDE2000) distance matrix.

    Parameters
    ----------
    rgb_colors: np.ndarray
        Array of RGB colors. Values are clipped to [0, 1].
    compared_to: np.ndarray, optional
        Array of RGB colors to compare against (also clipped to [0, 1]).
        If None, compare to rgb_colors.

    Returns
    -------
    distance_matrix: np.ndarray
        nb_sample x nb_sample or nb_sample1 x nb_sample2 distance matrix.
    """
    # Convert RGB to CIELAB
    rgb_colors = np.clip(rgb_colors, 0, 1).astype(float)
    lab_colors_1 = rgb2lab(rgb_colors)

    if compared_to is None:
        lab_colors_2 = lab_colors_1
    else:
        compared_to = np.clip(compared_to, 0, 1).astype(float)
        lab_colors_2 = rgb2lab(compared_to)

    # Calculate the pairwise Delta E distances
    # using broadcasting and vectorization
    lab_colors_1 = lab_colors_1[:, np.newaxis, :]  # Shape (n1, 1, 3)
    lab_colors_2 = lab_colors_2[np.newaxis, :, :]  # Shape (1, n2, 3)

    # Vectorized Delta E calculation
    distance_matrix = deltaE_ciede2000(lab_colors_1, lab_colors_2,
                                       kL=1, kC=1, kH=1)

    return distance_matrix
latent space visualization.\n") + p.add_argument('--viz_color_by', type=str, default=None, + help="Name of the group in hdf5 to color by." + "In the HDF5, the coloring group should be under" + "data_per_streamline/") + p.add_argument('--viz_max_bundle_size', type=int, default=1000, + help="Maximum number of streamlines per bundle to " + "visualize in latent space. Will perform a random\n" + "selection if the number of streamlines is higher than " + "this value." + ) + p.add_argument('--viz_bundles_mapping', type=str, default=None, + help="Path to a txt file mapping bundles to a new name.\n" + "Each line of that file should be: ") add_memory_args(p, add_lazy_options=True, add_rng=True) add_verbose_arg(p) @@ -51,6 +66,9 @@ def prepare_arg_parser(): def init_from_args(args, sub_loggers_level): torch.manual_seed(args.rng) # Set torch seed + viz_latent_space = args.viz_latent_space + viz_color_by = args.viz_color_by + # Prepare the dataset dataset = prepare_multisubjectdataset(args, load_testing=False, log_level=sub_loggers_level) @@ -87,7 +105,7 @@ def init_from_args(args, sub_loggers_level): # Instantiate trainer with Timer("\n\nPreparing trainer", newline=True, color='red'): lr = format_lr(args.learning_rate) - trainer = DWIMLAbstractTrainer( + trainer = TrainerWithBundleDPS( model=model, experiments_path=args.experiments_path, experiment_name=args.experiment_name, batch_sampler=batch_sampler, batch_loader=batch_loader, @@ -103,7 +121,10 @@ def init_from_args(args, sub_loggers_level): from_checkpoint=False, clip_grad=args.clip_grad, # MEMORY nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, - log_level=sub_loggers_level) + log_level=sub_loggers_level, + viz_latent_space=viz_latent_space, color_by=viz_color_by, + bundles_mapping_file=args.viz_bundles_mapping, + max_viz_subset_size=args.viz_max_bundle_size) logging.info("Trainer params : " + format_dict_to_str(trainer.params_for_checkpoint))