From a3a2c32b6b1c32ce5b69465c4e7c75b66fb05605 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Tue, 26 Sep 2023 11:36:49 -0400 Subject: [PATCH] Add noise_loss param for deprecated batch loaders --- dwi_ml/training/batch_loaders.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 0173caaf..2161b410 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -161,6 +161,12 @@ def params_for_checkpoint(self): @classmethod def init_from_checkpoint(cls, dataset, model, checkpoint_state, new_log_level): + # Adding noise_gaussian_size_loss for deprecated batch loaders + if 'noise_gaussian_size_loss' not in checkpoint_state: + logging.warning("Deprecated batch loader. Did not contain a " + "noise_gaussian_size_loss value. Setting to 0.0.") + checkpoint_state['noise_gaussian_size_loss'] = 0.0 + batch_loader = cls(dataset=dataset, model=model, log_level=new_log_level, **checkpoint_state) return batch_loader