Skip to content

Commit

Permalink
Small refactoring to load from ckpt (#203)
Browse files Browse the repository at this point in the history
* Move checkpoint type computation to utils

* Refactor checkpointing in training script

* Get ckpt type if ckpt is passed
  • Loading branch information
sfmig authored Jun 28, 2024
1 parent 81db31e commit e247e89
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions crabs/detection_tracking/detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,6 @@ def slurm_logs_as_artifacts(logger, slurm_job_id):
)


def log_data_augm_as_artifacts(logger, data_module):
"""Log data augmentation transforms as artifacts in MLflow."""
for transform_str in ["train_transform", "test_val_transform"]:
logger.experiment.log_text(
text=str(getattr(data_module, f"_get_{transform_str}")()),
artifact_file=f"{transform_str}.txt",
run_id=logger.run_id,
)


def get_checkpoint_type(checkpoint_path: Optional[str]) -> Optional[str]:
"""Get checkpoint type (full or weights) from the checkpoint path."""
checkpoint = torch.load(checkpoint_path) # fails if path doesn't exist
Expand All @@ -274,3 +264,13 @@ def get_checkpoint_type(checkpoint_path: Optional[str]) -> Optional[str]:
)

return checkpoint_type


def log_data_augm_as_artifacts(logger, data_module):
"""Log data augmentation transforms as artifacts in MLflow."""
for transform_str in ["train_transform", "test_val_transform"]:
logger.experiment.log_text(
text=str(getattr(data_module, f"_get_{transform_str}")()),
artifact_file=f"{transform_str}.txt",
run_id=logger.run_id,
)

0 comments on commit e247e89

Please sign in to comment.