Small edits to training (#235)
* Move run name setting to constructor

* Add logs to screen on mlflow metadata and on dataset

* Fix typo in detector class

* Reorder and clarify mlruns folder help in CLI

* Small edits to evaluate script for consistency

* Additions for consistency with evaluation script

* Factor out logging to screen

* Remove refactor comments
sfmig authored Nov 5, 2024
1 parent c992a79 commit a98b2a2
Showing 5 changed files with 101 additions and 64 deletions.
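The main refactor is factoring the dataset and MLflow screen-logging out of `DetectorEvaluate.__init__` into two shared helpers in `crabs/detector/utils/detection.py`, so that `DetectorTrain` can reuse them. The helpers only read attributes off whatever object they are handed, so any object exposing the same six fields works. A minimal sketch of the pattern (the attribute values are hypothetical, and `SimpleNamespace` is only a stand-in for a detector instance):

```python
import logging
from types import SimpleNamespace

from crabs.detector.utils.detection import (
    log_dataset_metadata_as_info,
    log_mlflow_metadata_as_info,
)

# The helpers log at INFO level, so the root logger must allow it
# (this commit moves exactly this call into the CLI entry points).
logging.getLogger().setLevel(logging.INFO)

# Stand-in for a DetectorTrain / DetectorEvaluate instance.
detector_like = SimpleNamespace(
    images_dirs=["/data/dataset1/frames"],          # hypothetical paths
    annotation_files=["/data/dataset1/coco.json"],
    seed_n=42,
    experiment_name="Sept2023",                     # hypothetical run metadata
    run_name="run_example",
    mlflow_folder="./ml-runs",
)

log_dataset_metadata_as_info(detector_like)   # prints the dataset block
log_mlflow_metadata_as_info(detector_like)    # prints the MLflow block
```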
24 changes: 9 additions & 15 deletions crabs/detector/evaluate_model.py
@@ -12,6 +12,8 @@
from crabs.detector.datamodules import CrabsDataModule
from crabs.detector.models import FasterRCNN
from crabs.detector.utils.detection import (
log_dataset_metadata_as_info,
log_mlflow_metadata_as_info,
set_mlflow_run_name,
setup_mlflow_logger,
slurm_logs_as_artifacts,
@@ -26,8 +26,6 @@
)
from crabs.detector.utils.visualization import save_images_with_boxes

logging.getLogger().setLevel(logging.INFO)


class DetectorEvaluate:
"""Interface for evaluating an object detector.
@@ -90,18 +90,10 @@ def __init__(self, args: argparse.Namespace) -> None:
self.limit_test_batches = args.limit_test_batches

# Log dataset information to screen
logging.info("Dataset")
logging.info(f"Images directories: {self.images_dirs}")
logging.info(f"Annotation files: {self.annotation_files}")
logging.info(f"Seed: {self.seed_n}")
logging.info("---------------------------------")
log_dataset_metadata_as_info(self)

# Log MLflow information to screen
logging.info("MLflow logs for current job")
logging.info(f"Experiment name: {self.experiment_name}")
logging.info(f"Run name: {self.run_name}")
logging.info(f"Folder: {Path(self.mlflow_folder).resolve()}")
logging.info("---------------------------------")
log_mlflow_metadata_as_info(self)

def setup_trainer(self):
"""Set up trainer object with logging for testing."""
@@ -195,12 +187,12 @@ def evaluate_model(self) -> None:


def main(args) -> None:
"""Run detector testing.
"""Run detector evaluation.
Parameters
----------
args : argparse
Arguments or configuration settings for testing.
args : argparse.Namespace
An object containing the parsed command-line arguments.
Returns
-------
@@ -346,6 +338,8 @@ def evaluate_parse_args(args):

def app_wrapper():
"""Wrap function to run the evaluation."""
logging.getLogger().setLevel(logging.INFO)

torch.set_float32_matmul_precision("medium")

eval_args = evaluate_parse_args(sys.argv[1:])
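Worth noting in this file: the `logging.getLogger().setLevel(logging.INFO)` call moved from module level, where it ran as a side effect of any import of `evaluate_model`, into `app_wrapper`, so only the CLI entry point mutates global logging state. The shape of the pattern, sketched with a placeholder body:

```python
import logging


def app_wrapper():
    """Wrap function to run the evaluation."""
    # Configure the root logger at the entry point rather than at
    # import time, so importing this module has no side effects.
    logging.getLogger().setLevel(logging.INFO)
    ...  # parse CLI args and run evaluation, as in the diff above
```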
104 changes: 64 additions & 40 deletions crabs/detector/train_model.py
@@ -1,6 +1,7 @@
"""Train FasterRCNN model for object detection."""

import argparse
import logging
import os
import sys
from pathlib import Path
@@ -14,6 +15,8 @@
from crabs.detector.datamodules import CrabsDataModule
from crabs.detector.models import FasterRCNN
from crabs.detector.utils.detection import (
log_dataset_metadata_as_info,
log_mlflow_metadata_as_info,
prep_annotation_files,
prep_img_directories,
set_mlflow_run_name,
@@ -27,7 +30,7 @@
)


class DectectorTrain:
class DetectorTrain:
"""Training class for detector algorithm.
Parameters
@@ -56,6 +59,7 @@ def __init__(self, args: argparse.Namespace):

# MLflow
self.experiment_name = args.experiment_name
self.run_name = set_mlflow_run_name()
self.mlflow_folder = args.mlflow_folder

# Debugging
@@ -65,15 +69,19 @@
# Restart from checkpoint
self.checkpoint_path = args.checkpoint_path

# Log dataset information to screen
log_dataset_metadata_as_info(self)

# Log MLflow information to screen
log_mlflow_metadata_as_info(self)

def load_config_yaml(self):
"""Load yaml file that contains config parameters."""
with open(self.config_file) as f:
self.config = yaml.safe_load(f)

def setup_trainer(self):
"""Set up trainer with logging and checkpointing."""
self.run_name = set_mlflow_run_name()

# Setup logger with checkpointing
mlf_logger = setup_mlflow_logger(
experiment_name=self.experiment_name,
@@ -84,6 +92,15 @@ def setup_trainer(self):
# pass the checkpointing config if defined
)

# Add dataset section to MLflow hyperparameters
mlf_logger.log_hyperparams(
{
"dataset/images_dir": self.images_dirs,
"dataset/annotation_files": self.annotation_files,
"dataset/seed": self.seed_n,
}
)

# Define checkpointing callback for trainer
config_ckpt = self.config.get("checkpoint_saving")
if config_ckpt:
@@ -117,7 +134,7 @@ def setup_trainer(self):
def optuna_objective_fn(self, trial: optuna.Trial) -> float:
"""Objective function for Optuna.
When used with Optuna, it wil maximise precision and recall on the
When used with Optuna, it will maximise precision and recall on the
validation set.
Parameters
@@ -176,6 +193,7 @@ def core_training(self) -> lightning.Trainer:
)

# Get model
# (from a previous checkpoint if required)
if not self.checkpoint_path:
lightning_model = FasterRCNN(
self.config, optuna_log=self.args.optuna
@@ -231,12 +249,13 @@ def train_model(self):

# if this is a slurm job: add slurm logs as artifacts
slurm_job_id = os.environ.get("SLURM_JOB_ID")
if slurm_job_id:
slurm_job_name = os.environ.get("SLURM_JOB_NAME")
if slurm_job_id and (slurm_job_name != "bash"):
slurm_logs_as_artifacts(trainer.logger, slurm_job_id)


def main(args) -> None:
"""Run training process.
"""Run detector training.
Parameters
----------
@@ -248,7 +267,7 @@ def main(args) -> None:
None
"""
trainer = DectectorTrain(args)
trainer = DetectorTrain(args)
trainer.train_model()


@@ -272,6 +291,12 @@ def train_parse_args(args):
"dataset/annotations."
),
)
parser.add_argument(
"--seed_n",
type=int,
default=42,
help="Seed for dataset splits. Default: 42",
)
parser.add_argument(
"--config_file",
type=str,
@@ -282,6 +307,12 @@
"crabs-exploration/crabs/detection_tracking/config/faster_rcnn.yaml"
),
)
parser.add_argument(
"--checkpoint_path",
type=str,
default=None,
help=("Path to checkpoint to resume training. " "Default: None."),
)
parser.add_argument(
"--accelerator",
type=str,
@@ -306,37 +337,26 @@
),
)
parser.add_argument(
"--seed_n",
type=int,
default=42,
help="Seed for dataset splits. Default: 42",
"--mlflow_folder",
type=str,
default="./ml-runs",
help=(
"Path to MLflow directory where to log the training data. "
"Default: 'ml-runs' directory under the current working directory."
),
)
parser.add_argument(
"--fast_dev_run",
"--no_data_augmentation",
action="store_true",
help="Debugging option to run training for one batch and one epoch",
)
parser.add_argument(
"--limit_train_batches",
type=float,
default=1.0,
help=(
"Debugging option to run training on a fraction of the "
"training set. "
"Default: 1.0 (the full training set)"
"Ignore the data augmentation transforms "
"defined in the config file"
),
)
parser.add_argument(
"--mlflow_folder",
type=str,
default="./ml-runs",
help=("Path to MLflow directory. Default: ./ml-runs"),
)
parser.add_argument(
"--checkpoint_path",
type=str,
default=None,
help=("Path to checkpoint for resume training"),
"--log_data_augmentation",
action="store_true",
help=("Log data augmentation transforms to " "MLflow as artifacts"),
)
parser.add_argument(
"--optuna",
@@ -347,23 +367,27 @@
),
)
parser.add_argument(
"--no_data_augmentation",
"--fast_dev_run",
action="store_true",
help=(
"Ignore the data augmentation transforms "
"defined in the config file"
),
help="Debugging option to run training for one batch and one epoch",
)
parser.add_argument(
"--log_data_augmentation",
action="store_true",
help=("Log data augmentation transforms to " "MLflow as artifacts"),
"--limit_train_batches",
type=float,
default=1.0,
help=(
"Debugging option to run training on a fraction of the "
"training set. "
"Default: 1.0 (the full training set)"
),
)
return parser.parse_args(args)


def app_wrapper():
"""Wrap function to run the training application."""
"""Wrap function to run the training."""
logging.getLogger().setLevel(logging.INFO)

torch.set_float32_matmul_precision("medium")

train_args = train_parse_args(sys.argv[1:])
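Two smaller behavioural changes in this file are easy to miss. First, the new `log_hyperparams` call records the image directories, annotation files, and seed as MLflow run parameters; the `dataset/` key prefix acts as a simple namespace grouping them together. Second, the SLURM guard now also checks the job name, presumably to skip interactive sessions: `srun --pty bash` typically reports `SLURM_JOB_NAME` as `bash`. The guard restated as a standalone predicate (a sketch, not code from the repo):

```python
import os


def is_batch_slurm_job() -> bool:
    """Return True only inside a non-interactive SLURM job.

    Interactive sessions (e.g. started via `srun --pty bash`) typically
    report SLURM_JOB_NAME as "bash" and are therefore excluded.
    """
    slurm_job_id = os.environ.get("SLURM_JOB_ID")
    slurm_job_name = os.environ.get("SLURM_JOB_NAME")
    return bool(slurm_job_id) and slurm_job_name != "bash"
```

The argparse changes, by contrast, are a reordering (dataset and MLflow flags first, debugging flags last) plus clarified help strings; the defaults shown in the diff are unchanged, which a quick check against the parser confirms (the dataset path below is hypothetical):

```python
from crabs.detector.train_model import train_parse_args

args = train_parse_args(["--dataset_dirs", "/data/sep2023"])
assert args.seed_n == 42
assert args.checkpoint_path is None
assert args.mlflow_folder == "./ml-runs"
assert args.limit_train_batches == 1.0
```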
19 changes: 19 additions & 0 deletions crabs/detector/utils/detection.py
@@ -2,6 +2,7 @@

import argparse
import datetime
import logging
import os
from pathlib import Path
from typing import Any, Optional
@@ -315,3 +316,21 @@ def bbox_tensors_to_COCO_dict(
}

return coco_dict


def log_dataset_metadata_as_info(detector_interface):
"""Print dataset metadata as logging info."""
logging.info("Dataset")
logging.info(f"Images directories: {detector_interface.images_dirs}")
logging.info(f"Annotation files: {detector_interface.annotation_files}")
logging.info(f"Seed: {detector_interface.seed_n}")
logging.info("---------------------------------")


def log_mlflow_metadata_as_info(detector_interface):
"""Print MLflow metadata as logging info."""
logging.info("MLflow logs for current job")
logging.info(f"Experiment name: {detector_interface.experiment_name}")
logging.info(f"Run name: {detector_interface.run_name}")
logging.info(f"Folder: {Path(detector_interface.mlflow_folder).resolve()}")
logging.info("---------------------------------")
6 changes: 3 additions & 3 deletions tests/test_unit/test_optuna.py
@@ -3,7 +3,7 @@

import pytest

from crabs.detector.train_model import DectectorTrain
from crabs.detector.train_model import DetectorTrain
from crabs.detector.utils.hpo import compute_optimal_hyperparameters


@@ -37,8 +37,8 @@ def args():

@pytest.fixture
def detector_train(args, config):
with patch.object(DectectorTrain, "load_config_yaml", MagicMock()):
train_instance = DectectorTrain(args=args)
with patch.object(DetectorTrain, "load_config_yaml", MagicMock()):
train_instance = DetectorTrain(args=args)
print(config)
train_instance.config = config
return train_instance
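As an aside on the fixture above: patching `load_config_yaml` with a `MagicMock` lets `DetectorTrain` be constructed without a real config file on disk, after which the test injects the config dict directly. Condensed, the pattern looks like this (assuming an `args` namespace and `config` dict as provided by the surrounding fixtures):

```python
from unittest.mock import MagicMock, patch

from crabs.detector.train_model import DetectorTrain

with patch.object(DetectorTrain, "load_config_yaml", MagicMock()):
    train_instance = DetectorTrain(args=args)  # skips reading the YAML file
train_instance.config = config                 # inject the config directly
```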
12 changes: 6 additions & 6 deletions tests/test_unit/test_train_model.py
@@ -19,13 +19,13 @@
)
def test_prep_img_directories(dataset_dirs: list):
"""Test parsing of image directories when training a model."""
from crabs.detector.train_model import DectectorTrain
from crabs.detector.train_model import DetectorTrain

# prepare parser
train_args = train_parse_args(["--dataset_dirs"] + dataset_dirs)

# instantiate detector
detector = DectectorTrain(train_args)
detector = DetectorTrain(train_args)

# check image directories are parsed correctly
list_imgs_dirs = [str(Path(d) / "frames") for d in dataset_dirs]
@@ -47,7 +47,7 @@ def test_prep_annotation_files_single_dataset(annotation_files, expected):
"""Test parsing of annotation files when training a model on a single
dataset.
"""
from crabs.detector.train_model import DectectorTrain
from crabs.detector.train_model import DetectorTrain

# prepare CLI arguments
cli_inputs = ["--dataset_dirs", DATASET_1]
@@ -59,7 +59,7 @@ def test_prep_annotation_files_single_dataset(annotation_files, expected):
train_args = train_parse_args(cli_inputs + annotation_files)

# instantiate detector
detector = DectectorTrain(train_args)
detector = DetectorTrain(train_args)

# check annotation files are as expected
assert detector.annotation_files == expected
@@ -89,7 +89,7 @@ def test_prep_annotation_files_multiple_datasets(annotation_files, expected):
"""Test parsing of annotation files when training
a model on two datasets.
"""
from crabs.detector.train_model import DectectorTrain
from crabs.detector.train_model import DetectorTrain

# prepare CLI arguments considering multiple dataset
cli_inputs = ["--dataset_dirs", DATASET_1, DATASET_2]
@@ -101,7 +101,7 @@ def test_prep_annotation_files_multiple_datasets(annotation_files, expected):
train_args = train_parse_args(cli_inputs + annotation_files)

# instantiate detector
detector = DectectorTrain(train_args)
detector = DetectorTrain(train_args)

# check annotation files are as expected
assert detector.annotation_files == expected
