From aaf48f27c523394911dae4502f89ead2111e6a18 Mon Sep 17 00:00:00 2001 From: AntoineTheb Date: Mon, 27 May 2024 10:27:37 -0400 Subject: [PATCH] ENH: add clear cache opt to model --- dwi_ml/models/main_models.py | 8 +++++--- dwi_ml/training/batch_loaders.py | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/dwi_ml/models/main_models.py b/dwi_ml/models/main_models.py index 23c1f717..942aba46 100644 --- a/dwi_ml/models/main_models.py +++ b/dwi_ml/models/main_models.py @@ -450,7 +450,8 @@ def forward(self, inputs, target_streamlines: List[torch.tensor]): class MainModelOneInput(MainModelAbstract): def prepare_batch_one_input(self, streamlines, subset: MultisubjectSubset, - subj_idx, input_group_idx, prepare_mask=False): + subj_idx, input_group_idx, prepare_mask=False, + clear_cache=True): """ These params are passed by either the batch loader or the propagator, which manage the data. @@ -491,10 +492,11 @@ def prepare_batch_one_input(self, streamlines, subset: MultisubjectSubset, if isinstance(self, ModelWithNeighborhood): # Adding neighborhood. subj_x_data, coords_torch = interpolate_volume_in_neighborhood( - data_tensor, flat_subj_x_coords, self.neighborhood_vectors) + data_tensor, flat_subj_x_coords, self.neighborhood_vectors, + clear_cache=clear_cache) else: subj_x_data, coords_torch = interpolate_volume_in_neighborhood( - data_tensor, flat_subj_x_coords, None) + data_tensor, flat_subj_x_coords, None, clear_cache=clear_cache) # Split the flattened signal back to streamlines lengths = [len(s) for s in streamlines] diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index cd9c1f46..59db8d82 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -346,6 +346,7 @@ def __init__(self, input_group_name, **kw): .format(input_group_name, self.dataset.volume_groups)) self.input_group_idx = idx + self.clear_cache = True @property def params_for_checkpoint(self): @@ -404,7 +405,7 @@ def load_batch_inputs(self, batch_streamlines: List[torch.tensor], # before adding streamline to batch. subbatch_x_data = self.model.prepare_batch_one_input( streamlines, self.context_subset, subj, - self.input_group_idx) + self.input_group_idx, clear_cache=self.clear_cache) batch_x_data.extend(subbatch_x_data)