Fix missing line (#200)
* adding the lines back

* adding the optuna arg

* Add optuna flag if resuming from weights only

---------

Co-authored-by: sfmig <[email protected]>
nikk-nikaznan and sfmig authored Jun 27, 2024
1 parent 2bf7705 commit e18357e
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion crabs/detection_tracking/train_model.py
@@ -186,7 +186,23 @@ def core_training(self) -> lightning.Trainer:
         checkpoint_type = None
 
         # Get model
-        lightning_model = FasterRCNN(self.config, optuna_log=self.args.optuna)
+        if checkpoint_type == "weights":
+            # Note: weights-only checkpoint contains hyperparameters
+            # see https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
+            lightning_model = FasterRCNN.load_from_checkpoint(
+                self.checkpoint_path,
+                config=self.config,
+                optuna_log=self.args.optuna,
+                # overwrite checkpoint hyperparameters with config ones
+                # otherwise ckpt hyperparameters are logged to MLflow, but yaml hyperparameters are used
+            )
+        else:
+            lightning_model = FasterRCNN(
+                self.config, optuna_log=self.args.optuna
+            )
+
+        # Get trainer
+        trainer = self.setup_trainer()
 
         # Get trainer
         trainer = self.setup_trainer()
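The in-diff comment ("overwrite checkpoint hyperparameters with config ones") relies on documented PyTorch Lightning behaviour: keyword arguments passed to LightningModule.load_from_checkpoint take precedence over the hyperparameters that save_hyperparameters() stored in the checkpoint. Below is a minimal, self-contained sketch of that behaviour, independent of the crabs code; ToyModel, lr, and toy.ckpt are illustrative assumptions, not names from the repository.

# Minimal sketch (assumed names, not the crabs code): kwargs passed to
# load_from_checkpoint override hyperparameters stored in the checkpoint.
import torch
import lightning
from torch.utils.data import DataLoader, TensorDataset


class ToyModel(lightning.LightningModule):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        # save_hyperparameters() stores the __init__ arguments in the
        # checkpoint, which is why even a "weights-only" checkpoint
        # carries hyperparameters
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(4, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.hparams.lr)


if __name__ == "__main__":
    data = DataLoader(
        TensorDataset(torch.randn(8, 4), torch.randn(8, 2)), batch_size=4
    )
    trainer = lightning.Trainer(
        max_epochs=1, logger=False, enable_checkpointing=False
    )
    trainer.fit(ToyModel(lr=5e-4), data)
    trainer.save_checkpoint("toy.ckpt")

    # No overrides: hyperparameters come from the checkpoint (lr=5e-4)
    restored = ToyModel.load_from_checkpoint("toy.ckpt")

    # With overrides: the kwarg takes precedence over the stored value,
    # mirroring config=self.config / optuna_log=self.args.optuna above
    restored = ToyModel.load_from_checkpoint("toy.ckpt", lr=1e-2)
    print(restored.hparams.lr)  # 0.01

In the diff, FasterRCNN plays the role of ToyModel: passing config=self.config and optuna_log=self.args.optuna when resuming from a weights-only checkpoint keeps the hyperparameters that are logged to MLflow consistent with the YAML config values that are actually used.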
