From e6c7d6d69e8adf1e657c148e022d49538462ce46 Mon Sep 17 00:00:00 2001
From: EmmaRenauld
Date: Tue, 26 Sep 2023 11:27:54 -0400
Subject: [PATCH 1/4] Formalize checkpoints for Batch Sampler

---
 dwi_ml/training/batch_samplers.py                    | 12 ++++++++++++
 dwi_ml/training/utils/batch_samplers.py              |  3 +--
 .../l2t_resume_training_from_checkpoint.py           |  6 +++---
 scripts_python/tt_resume_training_from_checkpoint.py |  5 +++--
 4 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/dwi_ml/training/batch_samplers.py b/dwi_ml/training/batch_samplers.py
index 3c9ec8ad..30d8d5ef 100644
--- a/dwi_ml/training/batch_samplers.py
+++ b/dwi_ml/training/batch_samplers.py
@@ -33,6 +33,7 @@
 from torch.utils.data import Sampler
 
 from dwi_ml.data.dataset.multi_subject_containers import MultiSubjectDataset
+from dwi_ml.experiment_utils.prints import format_dict_to_str
 
 DEFAULT_CHUNK_SIZE = 256
 logger = logging.getLogger('batch_sampler_logger')
@@ -156,6 +157,17 @@ def params_for_checkpoint(self):
         }
         return params
 
+    @classmethod
+    def init_from_checkpoint(cls, dataset, checkpoint_state: dict,
+                             new_log_level):
+        batch_sampler = cls(dataset=dataset, log_level=new_log_level,
+                            **checkpoint_state)
+
+        logging.info("Batch sampler's user-defined parameters: " +
+                     format_dict_to_str(batch_sampler.params_for_checkpoint))
+
+        return batch_sampler
+
     def set_context(self, context):
         if self.context != context:
             if context == 'training':
diff --git a/dwi_ml/training/utils/batch_samplers.py b/dwi_ml/training/utils/batch_samplers.py
index f1943470..bd57c8d5 100644
--- a/dwi_ml/training/utils/batch_samplers.py
+++ b/dwi_ml/training/utils/batch_samplers.py
@@ -56,7 +56,6 @@ def prepare_batch_sampler(dataset, args, sub_loggers_level):
             cycles=args.cycles,
             rng=args.rng, log_level=sub_loggers_level)
 
-    logging.info("Batch sampler's user-defined parameters: " +
-                 format_dict_to_str(batch_sampler.params_for_checkpoint))
+
     return batch_sampler
diff --git a/scripts_python/l2t_resume_training_from_checkpoint.py b/scripts_python/l2t_resume_training_from_checkpoint.py
index 78c3ad9b..512b270e 100644
--- a/scripts_python/l2t_resume_training_from_checkpoint.py
+++ b/scripts_python/l2t_resume_training_from_checkpoint.py
@@ -13,9 +13,9 @@
 from dwi_ml.experiment_utils.timer import Timer
 from dwi_ml.io_utils import add_logging_arg
 from dwi_ml.models.projects.learn2track_model import Learn2TrackModel
+from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
 from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer
 from dwi_ml.training.utils.batch_loaders import prepare_batch_loader
-from dwi_ml.training.utils.batch_samplers import prepare_batch_sampler
 from dwi_ml.training.utils.experiment import add_args_resuming_experiment
 from dwi_ml.training.utils.trainer import run_experiment
@@ -57,8 +57,8 @@ def init_from_checkpoint(args, checkpoint_path):
         os.path.join(checkpoint_path, 'model'), sub_loggers_level)
 
     # Prepare batch sampler
-    _args = argparse.Namespace(**checkpoint_state['batch_sampler_params'])
-    batch_sampler = prepare_batch_sampler(dataset, _args, sub_loggers_level)
+    batch_sampler = DWIMLBatchIDSampler.init_from_checkpoint(
+        dataset, checkpoint_state['batch_sampler_params'], sub_loggers_level)
 
     # Prepare batch loader
     _args = argparse.Namespace(**checkpoint_state['batch_loader_params'])
diff --git a/scripts_python/tt_resume_training_from_checkpoint.py b/scripts_python/tt_resume_training_from_checkpoint.py
index 5307be44..047b6398 100644
--- a/scripts_python/tt_resume_training_from_checkpoint.py
+++ b/scripts_python/tt_resume_training_from_checkpoint.py
@@ -14,6 +14,7 @@
 from dwi_ml.io_utils import add_logging_arg, verify_which_model_in_path
 from dwi_ml.models.projects.transformer_models import \
     OriginalTransformerModel, TransformerSrcAndTgtModel, TransformerSrcOnlyModel
+from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
 from dwi_ml.training.projects.transformer_trainer import TransformerTrainer
 from dwi_ml.training.utils.batch_samplers import prepare_batch_sampler
 from dwi_ml.training.utils.batch_loaders import prepare_batch_loader
@@ -68,8 +69,8 @@ def init_from_checkpoint(args, checkpoint_path):
     model = cls.load_model_from_params_and_state(model_dir, sub_loggers_level)
 
     # Prepare batch sampler
-    _args = argparse.Namespace(**checkpoint_state['batch_sampler_params'])
-    batch_sampler = prepare_batch_sampler(dataset, _args, sub_loggers_level)
+    batch_sampler = DWIMLBatchIDSampler.init_from_checkpoint(
+        dataset, checkpoint_state['batch_sampler_params'], sub_loggers_level)
 
     # Prepare batch loader
     _args = argparse.Namespace(**checkpoint_state['batch_loader_params'])
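The contract this patch formalizes is a simple round trip: params_for_checkpoint returns exactly the kwargs that __init__ accepts (minus the dataset, which is reloaded separately when resuming), so init_from_checkpoint can rebuild the sampler by unpacking the saved dict. A minimal, self-contained sketch of that pattern; TinySampler is hypothetical, not part of dwi_ml:

    import logging

    class TinySampler:
        # Hypothetical stand-in for DWIMLBatchIDSampler; illustration only.
        def __init__(self, dataset, batch_size=100, rng=1234,
                     log_level=logging.WARNING):
            self.dataset = dataset
            self.batch_size = batch_size
            self.rng = rng
            logging.getLogger('tiny_sampler').setLevel(log_level)

        @property
        def params_for_checkpoint(self):
            # Everything needed to rebuild this object, except the dataset.
            return {'batch_size': self.batch_size, 'rng': self.rng}

        @classmethod
        def init_from_checkpoint(cls, dataset, checkpoint_state, new_log_level):
            # The saved dict maps 1:1 onto __init__ kwargs: just unpack it.
            return cls(dataset=dataset, log_level=new_log_level,
                       **checkpoint_state)

    # Round trip: save params from one instance, rebuild an equivalent one.
    s1 = TinySampler(dataset='fake_dataset', batch_size=50)
    s2 = TinySampler.init_from_checkpoint('fake_dataset',
                                          s1.params_for_checkpoint,
                                          logging.INFO)
    assert s2.batch_size == 50 and s2.rng == 1234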
From 8cf0e82311f5efc54c200f952dcb5e0a544e760c Mon Sep 17 00:00:00 2001
From: EmmaRenauld
Date: Tue, 26 Sep 2023 11:35:04 -0400
Subject: [PATCH 2/4] Formalize checkpoint for batch_loader

---
 dwi_ml/training/batch_loaders.py                      | 7 +++++++
 scripts_python/l2t_resume_training_from_checkpoint.py | 8 +++++---
 scripts_python/tt_resume_training_from_checkpoint.py  | 9 +++++----
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py
index 93574632..0173caaf 100644
--- a/dwi_ml/training/batch_loaders.py
+++ b/dwi_ml/training/batch_loaders.py
@@ -158,6 +158,13 @@ def params_for_checkpoint(self):
         }
         return params
 
+    @classmethod
+    def init_from_checkpoint(cls, dataset, model, checkpoint_state,
+                             new_log_level):
+        batch_loader = cls(dataset=dataset, model=model,
+                           log_level=new_log_level, **checkpoint_state)
+        return batch_loader
+
     def set_context(self, context: str):
         if self.context != context:
             if context == 'training':
diff --git a/scripts_python/l2t_resume_training_from_checkpoint.py b/scripts_python/l2t_resume_training_from_checkpoint.py
index 512b270e..d4f91b54 100644
--- a/scripts_python/l2t_resume_training_from_checkpoint.py
+++ b/scripts_python/l2t_resume_training_from_checkpoint.py
@@ -15,9 +15,10 @@
 from dwi_ml.models.projects.learn2track_model import Learn2TrackModel
 from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
 from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer
-from dwi_ml.training.utils.batch_loaders import prepare_batch_loader
 from dwi_ml.training.utils.experiment import add_args_resuming_experiment
 from dwi_ml.training.utils.trainer import run_experiment
+from dwi_ml.training.with_generation.batch_loader import \
+    DWIMLBatchLoaderWithConnectivity
 
 
 def prepare_arg_parser():
@@ -61,8 +62,9 @@ def init_from_checkpoint(args, checkpoint_path):
         dataset, checkpoint_state['batch_sampler_params'], sub_loggers_level)
 
     # Prepare batch loader
-    _args = argparse.Namespace(**checkpoint_state['batch_loader_params'])
-    batch_loader = prepare_batch_loader(dataset, model, _args, sub_loggers_level)
+    batch_loader = DWIMLBatchLoaderWithConnectivity.init_from_checkpoint(
+        dataset, model, checkpoint_state['batch_loader_params'],
+        sub_loggers_level)
 
     # Instantiate trainer
     with Timer("\nPreparing trainer", newline=True, color='red'):
diff --git a/scripts_python/tt_resume_training_from_checkpoint.py b/scripts_python/tt_resume_training_from_checkpoint.py
index 047b6398..a5cd91ea 100644
--- a/scripts_python/tt_resume_training_from_checkpoint.py
+++ b/scripts_python/tt_resume_training_from_checkpoint.py
@@ -16,10 +16,10 @@
     OriginalTransformerModel, TransformerSrcAndTgtModel, TransformerSrcOnlyModel
 from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
 from dwi_ml.training.projects.transformer_trainer import TransformerTrainer
-from dwi_ml.training.utils.batch_samplers import prepare_batch_sampler
-from dwi_ml.training.utils.batch_loaders import prepare_batch_loader
 from dwi_ml.training.utils.experiment import add_args_resuming_experiment
 from dwi_ml.training.utils.trainer import run_experiment
+from dwi_ml.training.with_generation.batch_loader import \
+    DWIMLBatchLoaderWithConnectivity
 
 
 def prepare_arg_parser():
@@ -73,8 +73,9 @@ def init_from_checkpoint(args, checkpoint_path):
         dataset, checkpoint_state['batch_sampler_params'], sub_loggers_level)
 
     # Prepare batch loader
-    _args = argparse.Namespace(**checkpoint_state['batch_loader_params'])
-    batch_loader = prepare_batch_loader(dataset, model, _args, sub_loggers_level)
+    batch_loader = DWIMLBatchLoaderWithConnectivity.init_from_checkpoint(
+        dataset, model, checkpoint_state['batch_loader_params'],
+        sub_loggers_level)
 
     # Instantiate trainer
     with Timer("\n\nPreparing trainer", newline=True, color='red'):
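With both classes now exposing init_from_checkpoint, the two resume scripts rebuild the sampler and loader identically. A sketch of that shared flow; the wrapper function and its default log level are assumptions, while the two classmethod calls are taken from the diffs above:

    import logging

    from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
    from dwi_ml.training.with_generation.batch_loader import \
        DWIMLBatchLoaderWithConnectivity

    def rebuild_from_checkpoint(dataset, model, checkpoint_state,
                                sub_loggers_level=logging.WARNING):
        # No more argparse.Namespace round trip: the saved sub-dicts go
        # straight to each class's init_from_checkpoint.
        batch_sampler = DWIMLBatchIDSampler.init_from_checkpoint(
            dataset, checkpoint_state['batch_sampler_params'],
            sub_loggers_level)
        batch_loader = DWIMLBatchLoaderWithConnectivity.init_from_checkpoint(
            dataset, model, checkpoint_state['batch_loader_params'],
            sub_loggers_level)
        return batch_sampler, batch_loader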
From a3a2c32b6b1c32ce5b69465c4e7c75b66fb05605 Mon Sep 17 00:00:00 2001
From: EmmaRenauld
Date: Tue, 26 Sep 2023 11:36:49 -0400
Subject: [PATCH 3/4] Add noise_loss param for deprecated batch loaders

---
 dwi_ml/training/batch_loaders.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py
index 0173caaf..2161b410 100644
--- a/dwi_ml/training/batch_loaders.py
+++ b/dwi_ml/training/batch_loaders.py
@@ -161,6 +161,12 @@ def params_for_checkpoint(self):
     @classmethod
     def init_from_checkpoint(cls, dataset, model, checkpoint_state,
                              new_log_level):
+        # Adding noise_gaussian_size_loss for deprecated batch loaders
+        if 'noise_gaussian_size_loss' not in checkpoint_state:
+            logging.warning("Deprecated batch loader. Did not contain a "
+                            "noise_gaussian_size_loss value. Setting to 0.0.")
+            checkpoint_state['noise_gaussian_size_loss'] = 0.0
+
         batch_loader = cls(dataset=dataset, model=model,
                            log_level=new_log_level, **checkpoint_state)
         return batch_loader
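The backward-compatibility idea in isolation: when a constructor kwarg is added after checkpoints have already been written in the wild, default the missing key (with a warning) before unpacking, so **checkpoint_state keeps working on old experiments. A standalone sketch; the example dict is illustrative, not a real checkpoint:

    import logging

    def upgrade_checkpoint_state(checkpoint_state):
        # Old checkpoints predate the noise_gaussian_size_loss option;
        # give them the neutral value so **checkpoint_state still matches
        # the current __init__ signature.
        if 'noise_gaussian_size_loss' not in checkpoint_state:
            logging.warning("Deprecated batch loader. Did not contain a "
                            "noise_gaussian_size_loss value. Setting to 0.0.")
            checkpoint_state['noise_gaussian_size_loss'] = 0.0
        return checkpoint_state

    old_state = {'some_older_param': 0.1}  # illustrative old checkpoint dict
    print(upgrade_checkpoint_state(old_state)['noise_gaussian_size_loss'])
    # 0.0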
From c5324dedd38cb69e1d9ffb277a7c0e314aadd169 Mon Sep 17 00:00:00 2001
From: EmmaRenauld
Date: Tue, 26 Sep 2023 17:22:25 -0400
Subject: [PATCH 4/4] Sneak in small fix in setup file

---
 dwi_ml/version.py | 3 ++-
 setup.py          | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/dwi_ml/version.py b/dwi_ml/version.py
index 6a0001d4..4ed0827f 100644
--- a/dwi_ml/version.py
+++ b/dwi_ml/version.py
@@ -82,6 +82,7 @@
 MINOR = _version_minor
 MICRO = _version_micro
 VERSION = __version__
-SCRIPTS = glob.glob("scripts_python/*.py") + glob.glob("bash_utilities/*.sh") + glob.glob("scripts_python/tests/scripts_on_test_model/*.py")
+PYTHON_SCRIPTS = glob.glob("scripts_python/*.py")
+BASH_SCRIPTS = glob.glob("bash_utilities/*.sh")
 
 PREVIOUS_MAINTAINERS=[]
diff --git a/setup.py b/setup.py
index 463ad088..192947b5 100644
--- a/setup.py
+++ b/setup.py
@@ -39,8 +39,9 @@
     entry_points={
         'console_scripts': ["{}=scripts_python.{}:main".format(
             os.path.basename(s),
-            os.path.basename(s).split(".")[0]) for s in SCRIPTS]
+            os.path.basename(s).split(".")[0]) for s in PYTHON_SCRIPTS]
     },
+    scripts=[s for s in BASH_SCRIPTS],
     data_files=[],
     include_package_data=True)
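As the diff suggests, the old SCRIPTS list fed .py, .sh, and test-model files alike into the console_scripts comprehension, but those entries must resolve to a Python module:function; the split routes bash utilities through setuptools' scripts= argument instead (and drops the test-model scripts from installation). What the comprehension produces for one illustrative path:

    import os

    PYTHON_SCRIPTS = ['scripts_python/l2t_resume_training_from_checkpoint.py']

    entries = ["{}=scripts_python.{}:main".format(
                   os.path.basename(s), os.path.basename(s).split(".")[0])
               for s in PYTHON_SCRIPTS]
    print(entries[0])
    # l2t_resume_training_from_checkpoint.py=scripts_python.l2t_resume_training_from_checkpoint:main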