From e18357e44d9c069ea617d3eaf41da3d2389812dd Mon Sep 17 00:00:00 2001
From: nikk-nikaznan <48319650+nikk-nikaznan@users.noreply.github.com>
Date: Thu, 27 Jun 2024 11:26:32 +0100
Subject: [PATCH] Fix missing line (#200)

* adding the lines back

* adding the optuna arg

* Add optuna flag if resuming from weights only

---------

Co-authored-by: sfmig <33267254+sfmig@users.noreply.github.com>
---
 crabs/detection_tracking/train_model.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/crabs/detection_tracking/train_model.py b/crabs/detection_tracking/train_model.py
index 67dd5026..3a4bc967 100644
--- a/crabs/detection_tracking/train_model.py
+++ b/crabs/detection_tracking/train_model.py
@@ -186,7 +186,23 @@ def core_training(self) -> lightning.Trainer:
             checkpoint_type = None
 
         # Get model
-        lightning_model = FasterRCNN(self.config, optuna_log=self.args.optuna)
+        if checkpoint_type == "weights":
+            # Note: weights-only checkpoint contains hyperparameters
+            # see https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
+            lightning_model = FasterRCNN.load_from_checkpoint(
+                self.checkpoint_path,
+                config=self.config,
+                optuna_log=self.args.optuna,
+                # overwrite checkpoint hyperparameters with config ones
+                # otherwise ckpt hyperparameters are logged to MLflow, but yaml hyperparameters are used
+            )
+        else:
+            lightning_model = FasterRCNN(
+                self.config, optuna_log=self.args.optuna
+            )
+
+        # Get trainer
+        trainer = self.setup_trainer()
 
         # Get trainer
         trainer = self.setup_trainer()
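
Note on the mechanism: the FasterRCNN.load_from_checkpoint(...) call in the hunk above relies on standard PyTorch Lightning behaviour, where keyword arguments passed to load_from_checkpoint take precedence over hyperparameters stored in the checkpoint via save_hyperparameters(). Below is a minimal, self-contained sketch of that behaviour; ToyModel and its learning_rate parameter are illustrative stand-ins, not part of this repository.

import lightning
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyModel(lightning.LightningModule):
    def __init__(self, learning_rate: float = 1e-3):
        super().__init__()
        # Records `learning_rate` under "hyper_parameters" in any checkpoint.
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(), lr=self.hparams.learning_rate
        )


if __name__ == "__main__":
    data = DataLoader(
        TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4
    )
    trainer = lightning.Trainer(
        max_epochs=1, logger=False, enable_checkpointing=False
    )
    trainer.fit(ToyModel(learning_rate=1e-3), data)
    trainer.save_checkpoint("toy.ckpt")

    # A keyword argument to load_from_checkpoint overrides the value stored
    # in the checkpoint -- the same mechanism the patch uses to prefer the
    # yaml config over the checkpointed hyperparameters.
    model = ToyModel.load_from_checkpoint("toy.ckpt", learning_rate=5e-4)
    assert model.hparams.learning_rate == 5e-4

Running the script trains for one epoch, writes toy.ckpt, and confirms that the keyword override (learning_rate=5e-4) wins over the checkpointed value (1e-3).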