diff --git a/crabs/detection_tracking/detection_utils.py b/crabs/detection_tracking/detection_utils.py index f86f8c71..4618b72d 100644 --- a/crabs/detection_tracking/detection_utils.py +++ b/crabs/detection_tracking/detection_utils.py @@ -244,16 +244,6 @@ def slurm_logs_as_artifacts(logger, slurm_job_id): ) -def log_data_augm_as_artifacts(logger, data_module): - """Log data augmentation transforms as artifacts in MLflow.""" - for transform_str in ["train_transform", "test_val_transform"]: - logger.experiment.log_text( - text=str(getattr(data_module, f"_get_{transform_str}")()), - artifact_file=f"{transform_str}.txt", - run_id=logger.run_id, - ) - - def get_checkpoint_type(checkpoint_path: Optional[str]) -> Optional[str]: """Get checkpoint type (full or weights) from the checkpoint path.""" checkpoint = torch.load(checkpoint_path) # fails if path doesn't exist @@ -274,3 +264,13 @@ def get_checkpoint_type(checkpoint_path: Optional[str]) -> Optional[str]: ) return checkpoint_type + + +def log_data_augm_as_artifacts(logger, data_module): + """Log data augmentation transforms as artifacts in MLflow.""" + for transform_str in ["train_transform", "test_val_transform"]: + logger.experiment.log_text( + text=str(getattr(data_module, f"_get_{transform_str}")()), + artifact_file=f"{transform_str}.txt", + run_id=logger.run_id, + )