
Get ckpt type if ckpt is passed
sfmig committed Jun 27, 2024
1 parent 9d41ed4 commit 4668b58
Showing 2 changed files with 32 additions and 40 deletions.
crabs/detection_tracking/detection_utils.py (39 changes: 16 additions & 23 deletions)
@@ -245,29 +245,22 @@ def slurm_logs_as_artifacts(logger, slurm_job_id):
 
 
 def get_checkpoint_type(checkpoint_path: Optional[str]) -> Optional[str]:
-    """Get checkpoint type (full or weights) from the checkpoint path.
-    If checkpoint_path is None, the checkpoint type is None.
-    """
-    # Get checkpoint type
-    if checkpoint_path:
-        checkpoint = torch.load(checkpoint_path)  # fails if path doesn't exist
-        if all(
-            [
-                param in checkpoint
-                for param in ["optimizer_states", "lr_schedulers"]
-            ]
-        ):
-            checkpoint_type = "full"  # for resuming training
-            logging.info(
-                f"Resuming training from checkpoint at: {checkpoint_path}"
-            )
-        else:
-            checkpoint_type = "weights"  # for fine tuning
-            logging.info(
-                f"Fine-tuning training from checkpoint at: {checkpoint_path}"
-            )
+    """Get checkpoint type (full or weights) from the checkpoint path."""
+    checkpoint = torch.load(checkpoint_path)  # fails if path doesn't exist
+    if all(
+        [
+            param in checkpoint
+            for param in ["optimizer_states", "lr_schedulers"]
+        ]
+    ):
+        checkpoint_type = "full"  # for resuming training
+        logging.info(
+            f"Resuming training from checkpoint at: {checkpoint_path}"
+        )
     else:
-        checkpoint_type = None
+        checkpoint_type = "weights"  # for fine tuning
+        logging.info(
+            f"Fine-tuning training from checkpoint at: {checkpoint_path}"
+        )
 
     return checkpoint_type
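
Note on the key check above: a checkpoint saved with save_weights_only=False keeps optimizer and LR-scheduler state alongside the weights, while one saved with save_weights_only=True does not, so the presence of "optimizer_states" and "lr_schedulers" is what separates the two cases. A minimal sketch of that distinction using dummy checkpoints (the file names and dict contents are illustrative, not taken from the repository):

import torch

# Illustrative dummy checkpoints, reduced to the keys that
# get_checkpoint_type inspects.
full_ckpt = {
    "state_dict": {},
    "optimizer_states": [{}],
    "lr_schedulers": [{}],
}
weights_ckpt = {"state_dict": {}}

torch.save(full_ckpt, "full.ckpt")
torch.save(weights_ckpt, "weights.ckpt")

for path in ("full.ckpt", "weights.ckpt"):
    checkpoint = torch.load(path)
    is_full = all(
        param in checkpoint
        for param in ["optimizer_states", "lr_schedulers"]
    )
    print(path, "->", "full" if is_full else "weights")
# expected: full.ckpt -> full, weights.ckpt -> weights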
crabs/detection_tracking/train_model.py (33 changes: 16 additions & 17 deletions)
@@ -165,35 +165,34 @@ def core_training(self) -> lightning.Trainer:
             self.seed_n,
         )
 
-        # Get checkpoint type: "full", "weights" or None
-        checkpoint_type = get_checkpoint_type(self.checkpoint_path)
-
         # Get model
-        if checkpoint_type == "weights":
-            lightning_model = FasterRCNN.load_from_checkpoint(
-                self.checkpoint_path,
-                config=self.config,  # overwrite hparams from ckpt with config
-                optuna_log=self.args.optuna,
-            )
-        else:
+        if not self.checkpoint_path:
             lightning_model = FasterRCNN(
                 self.config, optuna_log=self.args.optuna
             )
+        else:
+            checkpoint_type = get_checkpoint_type(self.checkpoint_path)
+            if checkpoint_type == "weights":
+                lightning_model = FasterRCNN.load_from_checkpoint(
+                    self.checkpoint_path,
+                    config=self.config,  # overwrite hparams from ckpt with config
+                    optuna_log=self.args.optuna,
+                )
+                # a 'weights' checkpoint is one saved with `save_weights_only=True`
 
         # Get trainer
         trainer = self.setup_trainer()
 
         # Run training
-        # Resume from full checkpoint if available
-        # (automatically restores model, epoch, step, LR schedulers, etc...)
-        # https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
         trainer.fit(
             lightning_model,
             data_module,
-            ckpt_path=self.checkpoint_path
-            if checkpoint_type == "full"
-            else None,
-            # needs to having been saved with `save_weights_only=False`
+            ckpt_path=(
+                self.checkpoint_path if checkpoint_type == "full" else None
+            ),
+            # a 'full' checkpoint is one saved with `save_weights_only=False`
+            # (automatically restores model, epoch, step, LR schedulers, etc...)
+            # see https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
         )
 
         return trainer
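
As the inline comments above note, whether a checkpoint ends up 'full' or 'weights' is decided at save time by the ModelCheckpoint callback's save_weights_only flag. A hedged sketch of how the two kinds could be produced (not the repository's actual trainer setup; the callback variable names are illustrative):

from lightning.pytorch.callbacks import ModelCheckpoint

# save_weights_only=False (the default) -> a "full" checkpoint: weights plus
# optimizer and LR-scheduler state, usable as ckpt_path in trainer.fit()
# to resume training.
full_ckpt_callback = ModelCheckpoint(save_weights_only=False)

# save_weights_only=True -> a "weights" checkpoint: model weights only,
# suited to fine-tuning via load_from_checkpoint(), not to resuming.
weights_only_callback = ModelCheckpoint(save_weights_only=True)

# e.g. trainer = lightning.Trainer(callbacks=[full_ckpt_callback])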
