
Merge pull request #205 from EmmaRenauld/small_fix_loss
Managing checkpoint officially for batch_loader and batch_sampler
EmmaRenauld authored Sep 27, 2023
2 parents 0fbd2c7 + c5324de commit ceab621
Showing 7 changed files with 46 additions and 16 deletions.
13 changes: 13 additions & 0 deletions dwi_ml/training/batch_loaders.py
@@ -158,6 +158,19 @@ def params_for_checkpoint(self):
         }
         return params
 
+    @classmethod
+    def init_from_checkpoint(cls, dataset, model, checkpoint_state,
+                             new_log_level):
+        # Adding noise_gaussian_size_loss for deprecated batch loaders
+        if 'noise_gaussian_size_loss' not in checkpoint_state:
+            logging.warning("Deprecated batch loader. Did not contain a "
+                            "noise_gaussian_size_loss value. Setting to 0.0.")
+            checkpoint_state['noise_gaussian_size_loss'] = 0.0
+
+        batch_loader = cls(dataset=dataset, model=model,
+                           log_level=new_log_level, **checkpoint_state)
+        return batch_loader
+
     def set_context(self, context: str):
         if self.context != context:
             if context == 'training':
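
For reference, a minimal sketch (not part of this diff) of how a resuming script can call the new classmethod on the connectivity-aware subclass used later in this commit. The dataset, model and saved_state objects are placeholders; in the real scripts they come from the trainer's checkpoint loading code.

import logging

from dwi_ml.training.with_generation.batch_loader import \
    DWIMLBatchLoaderWithConnectivity


def rebuild_batch_loader(dataset, model, saved_state: dict):
    # saved_state is the dict stored at save time (params_for_checkpoint).
    # Checkpoints written before this change may lack
    # 'noise_gaussian_size_loss'; init_from_checkpoint patches it to 0.0
    # and logs a warning instead of crashing.
    return DWIMLBatchLoaderWithConnectivity.init_from_checkpoint(
        dataset, model, saved_state, logging.INFO)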
12 changes: 12 additions & 0 deletions dwi_ml/training/batch_samplers.py
@@ -33,6 +33,7 @@
 from torch.utils.data import Sampler
 
 from dwi_ml.data.dataset.multi_subject_containers import MultiSubjectDataset
+from dwi_ml.experiment_utils.prints import format_dict_to_str
 
 DEFAULT_CHUNK_SIZE = 256
 logger = logging.getLogger('batch_sampler_logger')
@@ -156,6 +157,17 @@ def params_for_checkpoint(self):
         }
         return params
 
+    @classmethod
+    def init_from_checkpoint(cls, dataset, checkpoint_state: dict,
+                             new_log_level):
+        batch_sampler = cls(dataset=dataset, log_level=new_log_level,
+                            **checkpoint_state)
+
+        logging.info("Batch sampler's user-defined parameters: " +
+                     format_dict_to_str(batch_sampler.params_for_checkpoint))
+
+        return batch_sampler
+
     def set_context(self, context):
         if self.context != context:
             if context == 'training':
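
The idea behind both new classmethods is a plain dict round trip: whatever the params_for_checkpoint property returned when the checkpoint was written is splatted back into the constructor at resume time. A hedged sketch (the actual checkpoint I/O lives in the trainer and is not shown here; dataset is an already-loaded MultiSubjectDataset):

import logging

from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler


def save_and_restore_sampler(sampler: DWIMLBatchIDSampler, dataset):
    # At save time, the trainer stores this plain dict in its checkpoint.
    saved_state = sampler.params_for_checkpoint

    # At resume time, the dict is splatted back into the class, which also
    # logs the restored parameters (see the logging.info call above).
    return DWIMLBatchIDSampler.init_from_checkpoint(
        dataset, saved_state, logging.DEBUG)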
3 changes: 1 addition & 2 deletions dwi_ml/training/utils/batch_samplers.py
@@ -56,7 +56,6 @@ def prepare_batch_sampler(dataset, args, sub_loggers_level):
             cycles=args.cycles,
             rng=args.rng, log_level=sub_loggers_level)
 
-    logging.info("Batch sampler's user-defined parameters: " +
-                 format_dict_to_str(batch_sampler.params_for_checkpoint))
+
 
     return batch_sampler
3 changes: 2 additions & 1 deletion dwi_ml/version.py
@@ -82,6 +82,7 @@
 MINOR = _version_minor
 MICRO = _version_micro
 VERSION = __version__
-SCRIPTS = glob.glob("scripts_python/*.py") + glob.glob("bash_utilities/*.sh") + glob.glob("scripts_python/tests/scripts_on_test_model/*.py")
+PYTHON_SCRIPTS = glob.glob("scripts_python/*.py")
+BASH_SCRIPTS = glob.glob("bash_utilities/*.sh")
 
 PREVIOUS_MAINTAINERS=[]
14 changes: 8 additions & 6 deletions scripts_python/l2t_resume_training_from_checkpoint.py
@@ -13,11 +13,12 @@
 from dwi_ml.experiment_utils.timer import Timer
 from dwi_ml.io_utils import add_logging_arg
 from dwi_ml.models.projects.learn2track_model import Learn2TrackModel
+from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
 from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer
-from dwi_ml.training.utils.batch_loaders import prepare_batch_loader
-from dwi_ml.training.utils.batch_samplers import prepare_batch_sampler
 from dwi_ml.training.utils.experiment import add_args_resuming_experiment
 from dwi_ml.training.utils.trainer import run_experiment
+from dwi_ml.training.with_generation.batch_loader import \
+    DWIMLBatchLoaderWithConnectivity
 
 
 def prepare_arg_parser():
@@ -57,12 +58,13 @@ def init_from_checkpoint(args, checkpoint_path):
         os.path.join(checkpoint_path, 'model'), sub_loggers_level)
 
     # Prepare batch sampler
-    _args = argparse.Namespace(**checkpoint_state['batch_sampler_params'])
-    batch_sampler = prepare_batch_sampler(dataset, _args, sub_loggers_level)
+    batch_sampler = DWIMLBatchIDSampler.init_from_checkpoint(
+        dataset, checkpoint_state['batch_sampler_params'], sub_loggers_level)
 
     # Prepare batch loader
-    _args = argparse.Namespace(**checkpoint_state['batch_loader_params'])
-    batch_loader = prepare_batch_loader(dataset, model, _args, sub_loggers_level)
+    batch_loader = DWIMLBatchLoaderWithConnectivity.init_from_checkpoint(
+        dataset, model, checkpoint_state['batch_loader_params'],
+        sub_loggers_level)
 
     # Instantiate trainer
     with Timer("\nPreparing trainer", newline=True, color='red'):
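
Both resuming scripts previously rebuilt an argparse.Namespace from the saved dicts and re-ran them through the CLI helpers (prepare_batch_sampler / prepare_batch_loader); they now hand the dicts straight to the classes. A hypothetical illustration of the slice of checkpoint_state these calls rely on; only the two keys are taken from the diff, everything else about the checkpoint's content is an assumption and is omitted:

checkpoint_state = {
    # Whatever the sampler's params_for_checkpoint returned at save time:
    'batch_sampler_params': {},
    # Whatever the loader's params_for_checkpoint returned at save time:
    'batch_loader_params': {},
    # ... plus other trainer state, not relevant to this change.
}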
14 changes: 8 additions & 6 deletions scripts_python/tt_resume_training_from_checkpoint.py
@@ -14,11 +14,12 @@
 from dwi_ml.io_utils import add_logging_arg, verify_which_model_in_path
 from dwi_ml.models.projects.transformer_models import \
     OriginalTransformerModel, TransformerSrcAndTgtModel, TransformerSrcOnlyModel
+from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
 from dwi_ml.training.projects.transformer_trainer import TransformerTrainer
-from dwi_ml.training.utils.batch_samplers import prepare_batch_sampler
-from dwi_ml.training.utils.batch_loaders import prepare_batch_loader
 from dwi_ml.training.utils.experiment import add_args_resuming_experiment
 from dwi_ml.training.utils.trainer import run_experiment
+from dwi_ml.training.with_generation.batch_loader import \
+    DWIMLBatchLoaderWithConnectivity
 
 
 def prepare_arg_parser():
@@ -68,12 +69,13 @@ def init_from_checkpoint(args, checkpoint_path):
     model = cls.load_model_from_params_and_state(model_dir, sub_loggers_level)
 
     # Prepare batch sampler
-    _args = argparse.Namespace(**checkpoint_state['batch_sampler_params'])
-    batch_sampler = prepare_batch_sampler(dataset, _args, sub_loggers_level)
+    batch_sampler = DWIMLBatchIDSampler.init_from_checkpoint(
+        dataset, checkpoint_state['batch_sampler_params'], sub_loggers_level)
 
     # Prepare batch loader
-    _args = argparse.Namespace(**checkpoint_state['batch_loader_params'])
-    batch_loader = prepare_batch_loader(dataset, model, _args, sub_loggers_level)
+    batch_loader = DWIMLBatchLoaderWithConnectivity.init_from_checkpoint(
+        dataset, model, checkpoint_state['batch_loader_params'],
+        sub_loggers_level)
 
     # Instantiate trainer
     with Timer("\n\nPreparing trainer", newline=True, color='red'):
3 changes: 2 additions & 1 deletion setup.py
@@ -39,8 +39,9 @@
     entry_points={
         'console_scripts': ["{}=scripts_python.{}:main".format(
             os.path.basename(s),
-            os.path.basename(s).split(".")[0]) for s in SCRIPTS]
+            os.path.basename(s).split(".")[0]) for s in PYTHON_SCRIPTS]
     },
+    scripts=[s for s in BASH_SCRIPTS],
     data_files=[],
     include_package_data=True)
 
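
As a sanity check on the comprehension above, this is what one entry resolves to, taking a script touched by this PR as the example. Bash utilities cannot be exposed as console_scripts entry points (those must point to a Python callable), hence the separate scripts=[s for s in BASH_SCRIPTS] list.

import os

# One element of PYTHON_SCRIPTS, as globbed in dwi_ml/version.py:
s = "scripts_python/l2t_resume_training_from_checkpoint.py"

entry = "{}=scripts_python.{}:main".format(
    os.path.basename(s),
    os.path.basename(s).split(".")[0])

# -> "l2t_resume_training_from_checkpoint.py=scripts_python.l2t_resume_training_from_checkpoint:main"
print(entry)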
