diff --git a/crabs/detection_tracking/config/faster_rcnn.yaml b/crabs/detection_tracking/config/faster_rcnn.yaml
index cbdfaad3..66893684 100644
--- a/crabs/detection_tracking/config/faster_rcnn.yaml
+++ b/crabs/detection_tracking/config/faster_rcnn.yaml
@@ -37,16 +37,29 @@ batch_size_test: 4
 # -------------------
 # Data augmentation
 # -------------------
-transform_brightness: 0.5
-transform_hue: 0.3
-gaussian_blur_params:
+gaussian_blur:
   kernel_size:
     - 5
     - 9
   sigma:
     - 0.1
     - 5.0
-
+color_jitter:
+  brightness: 0.5
+  hue: 0.3
+random_horizontal_flip:
+  p: 0.5
+random_rotation:
+  degrees: [-10.0, 10.0]
+random_adjust_sharpness:
+  p: 0.5
+  sharpness_factor: 0.5
+random_autocontrast:
+  p: 0.5
+random_equalize:
+  p: 0.5
+clamp_and_sanitize_bboxes:
+  min_size: 1.0
 # ----------------------------
 # Hyperparameter optimisation
 # -----------------------------
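For context: the datamodule changes below pick up each of these snake_case keys, convert the name to CamelCase, and instantiate the matching `torchvision.transforms.v2` class with the nested values as keyword arguments. A minimal sketch of that mapping (the `config` dict here is a hand-picked subset of the file above, not the full config):

```python
import torchvision.transforms.v2 as transforms


def snake_to_camel_case(snake_str: str) -> str:
    # e.g. "random_horizontal_flip" -> "RandomHorizontalFlip"
    return "".join(x.capitalize() for x in snake_str.lower().split("_"))


# hand-picked subset of the config, as it would be parsed from the YAML
config = {
    "gaussian_blur": {"kernel_size": [5, 9], "sigma": [0.1, 5.0]},
    "random_horizontal_flip": {"p": 0.5},
}

for name, kwargs in config.items():
    # look up the v2 transform class by its CamelCase name and
    # instantiate it with the config entries as keyword arguments
    transform_cls = getattr(transforms, snake_to_camel_case(name))
    print(transform_cls(**kwargs))
```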
diff --git a/crabs/detection_tracking/datamodules.py b/crabs/detection_tracking/datamodules.py
index 6e882a43..d2153169 100644
--- a/crabs/detection_tracking/datamodules.py
+++ b/crabs/detection_tracking/datamodules.py
@@ -24,12 +24,65 @@ def __init__(
         list_annotation_files: list[str],
         config: dict,
         split_seed: Optional[int] = None,
+        no_data_augmentation: bool = False,
     ):
         super().__init__()
         self.list_img_dirs = list_img_dirs
         self.list_annotation_files = list_annotation_files
         self.split_seed = split_seed
         self.config = config
+        self.no_data_augmentation = no_data_augmentation
+
+    def _transform_str_to_operator(self, transform_str):
+        """Get a transform operator from its name in snake case."""
+
+        def snake_to_camel_case(snake_str):
+            return "".join(
+                x.capitalize() for x in snake_str.lower().split("_")
+            )
+
+        transform_callable = getattr(
+            transforms, snake_to_camel_case(transform_str)
+        )
+
+        return transform_callable(**self.config[transform_str])
+
+    def _compute_list_of_transforms(self) -> list[transforms.Transform]:
+        """Read the transforms defined in the config and compile them into a list."""
+
+        # Initialise list
+        train_data_augm: list[transforms.Transform] = []
+
+        # Add the standard transforms that are defined in the config
+        for transform_str in [
+            "gaussian_blur",
+            "color_jitter",
+            "random_horizontal_flip",
+            "random_rotation",
+            "random_adjust_sharpness",
+            "random_autocontrast",
+            "random_equalize",
+        ]:
+            if transform_str in self.config:
+                transform_operator = self._transform_str_to_operator(
+                    transform_str
+                )
+                train_data_augm.append(transform_operator)
+
+        # Add the clamp and sanitize bboxes transforms, if defined
+        # See https://pytorch.org/vision/main/generated/torchvision.transforms.v2.SanitizeBoundingBoxes.html#torchvision.transforms.v2.SanitizeBoundingBoxes
+        if "clamp_and_sanitize_bboxes" in self.config:
+            # Clamp bounding boxes
+            train_data_augm.append(transforms.ClampBoundingBoxes())
+
+            # Sanitize
+            sanitize = transforms.SanitizeBoundingBoxes(
+                min_size=self.config["clamp_and_sanitize_bboxes"]["min_size"],
+                labels_getter=None,  # only bboxes are sanitized
+            )
+            train_data_augm.append(sanitize)
+
+        return train_data_augm
 
     def _get_train_transform(self) -> torchvision.transforms:
         """Define data augmentation transforms for the train set.
@@ -38,17 +91,22 @@ def _get_train_transform(self) -> torchvision.transforms:
         https://pytorch.org/vision/stable/transforms.html#v1-or-v2-which-one-should-i-use
         https://pytorch.org/vision/main/auto_examples/transforms/plot_transforms_e2e.html#transforms
 
+        ToDtype is the recommended replacement for ConvertImageDtype(dtype):
+        https://pytorch.org/vision/0.17/generated/torchvision.transforms.v2.ToDtype.html#torchvision.transforms.v2.ToDtype
+
         """
-        jitter = transforms.ColorJitter(
-            brightness=self.config["transform_brightness"],
-            hue=self.config["transform_hue"],
-        )
-        gauss = transforms.GaussianBlur(
-            kernel_size=self.config["gaussian_blur_params"]["kernel_size"],
-            sigma=self.config["gaussian_blur_params"]["sigma"],
-        )
-        todtype = transforms.ToDtype(torch.float32, scale=True)
-        train_transforms = [transforms.ToImage(), jitter, gauss, todtype]
+        # Compute the list of data augmentation transforms to apply
+        if self.no_data_augmentation:
+            train_data_augm = []
+        else:
+            train_data_augm = self._compute_list_of_transforms()
+
+        # Define a Compose transform with them
+        train_transforms = [
+            transforms.ToImage(),
+            *train_data_augm,
+            transforms.ToDtype(torch.float32, scale=True),
+        ]
         return transforms.Compose(train_transforms)
 
     def _get_test_val_transform(self) -> torchvision.transforms:
diff --git a/crabs/detection_tracking/detection_utils.py b/crabs/detection_tracking/detection_utils.py
index 12ef1eff..f86f8c71 100644
--- a/crabs/detection_tracking/detection_utils.py
+++ b/crabs/detection_tracking/detection_utils.py
@@ -1,9 +1,11 @@
 import argparse
 import datetime
+import logging
 import os
 from pathlib import Path
-from typing import Any
+from typing import Any, Optional
 
+import torch
 from lightning.pytorch.loggers import MLFlowLogger
 
 DEFAULT_ANNOTATIONS_FILENAME = "VIA_JSON_combined_coco_gen.json"
@@ -240,3 +242,35 @@ def slurm_logs_as_artifacts(logger, slurm_job_id):
             logger.run_id,
             f"{log_filename}.{ext}",
         )
+
+
+def log_data_augm_as_artifacts(logger, data_module):
+    """Log the data augmentation transforms as artifacts in MLflow."""
+    for transform_str in ["train_transform", "test_val_transform"]:
+        logger.experiment.log_text(
+            text=str(getattr(data_module, f"_get_{transform_str}")()),
+            artifact_file=f"{transform_str}.txt",
+            run_id=logger.run_id,
+        )
+
+
+def get_checkpoint_type(checkpoint_path: Optional[str]) -> Optional[str]:
+    """Get the checkpoint type ('full' or 'weights') from the checkpoint path."""
+    checkpoint = torch.load(checkpoint_path)  # fails if path doesn't exist
+    if all(
+        param in checkpoint
+        for param in ["optimizer_states", "lr_schedulers"]
+    ):
+        checkpoint_type = "full"  # for resuming training
+        logging.info(
+            f"Resuming training from checkpoint at: {checkpoint_path}"
+        )
+    else:
+        checkpoint_type = "weights"  # for fine-tuning
+        logging.info(
+            f"Fine-tuning from checkpoint at: {checkpoint_path}"
+        )
+
+    return checkpoint_type
diff --git a/crabs/detection_tracking/evaluate_model.py b/crabs/detection_tracking/evaluate_model.py
index 069c9ee9..96fee9ab 100644
--- a/crabs/detection_tracking/evaluate_model.py
+++ b/crabs/detection_tracking/evaluate_model.py
@@ -106,10 +106,10 @@ def evaluate_model(self) -> None:
         """
         # Create datamodule
         data_module = CrabsDataModule(
-            self.images_dirs,
-            self.annotation_files,
-            self.config,
-            self.seed_n,
+            list_img_dirs=self.images_dirs,
+            list_annotation_files=self.annotation_files,
+            split_seed=self.seed_n,
+            config=self.config,
         )
 
         # Get trained model
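A quick sketch of what `get_checkpoint_type` keys on (the checkpoint path here is hypothetical): a Lightning checkpoint saved with `save_weights_only=False` carries optimizer and scheduler state alongside the model weights, while a weights-only checkpoint does not.

```python
import torch

from crabs.detection_tracking.detection_utils import get_checkpoint_type

# peek inside a Lightning checkpoint (hypothetical path)
checkpoint = torch.load("last.ckpt")
print(sorted(checkpoint.keys()))
# a 'full' checkpoint (save_weights_only=False) contains 'optimizer_states'
# and 'lr_schedulers' alongside 'state_dict'; a 'weights' checkpoint
# (save_weights_only=True) lacks those two keys

print(get_checkpoint_type("last.ckpt"))  # -> 'full' or 'weights'
```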
diff --git a/crabs/detection_tracking/train_model.py b/crabs/detection_tracking/train_model.py
index a88389f8..400d7ba2 100644
--- a/crabs/detection_tracking/train_model.py
+++ b/crabs/detection_tracking/train_model.py
@@ -1,5 +1,4 @@
 import argparse
-import logging
 import os
 import sys
 from pathlib import Path
@@ -12,6 +11,8 @@
 from crabs.detection_tracking.datamodules import CrabsDataModule
 from crabs.detection_tracking.detection_utils import (
+    get_checkpoint_type,
+    log_data_augm_as_artifacts,
     prep_annotation_files,
     prep_img_directories,
     set_mlflow_run_name,
@@ -57,6 +58,7 @@ def __init__(self, args):
         self.fast_dev_run = args.fast_dev_run
         self.limit_train_batches = args.limit_train_batches
 
+        # Restart from checkpoint
         self.checkpoint_path = args.checkpoint_path
 
     def load_config_yaml(self):
@@ -158,67 +160,44 @@ def core_training(self) -> lightning.Trainer:
         """
         # Create data module
         data_module = CrabsDataModule(
-            self.images_dirs,
-            self.annotation_files,
-            self.config,
-            self.seed_n,
+            list_img_dirs=self.images_dirs,
+            list_annotation_files=self.annotation_files,
+            split_seed=self.seed_n,
+            config=self.config,
+            no_data_augmentation=self.args.no_data_augmentation,
         )
 
-        # Get checkpoint type
-        if self.checkpoint_path and os.path.exists(self.checkpoint_path):
-            checkpoint = torch.load(self.checkpoint_path)
-            if all(
-                [
-                    param in checkpoint
-                    for param in ["optimizer_states", "lr_schedulers"]
-                ]
-            ):
-                checkpoint_type = "full"  # for resuming training
-                logging.info(
-                    f"Resuming training from checkpoint at: {self.checkpoint_path}"
-                )
-            else:
-                checkpoint_type = "weights"  # for fine tuning
-                logging.info(
-                    f"Fine-tuning training from checkpoint at: {self.checkpoint_path}"
-                )
-        else:
-            checkpoint_type = None
-
         # Get model
-        if checkpoint_type == "weights":
-            # Note: weights-only checkpoint contains hyperparameters
-            # see https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
-            lightning_model = FasterRCNN.load_from_checkpoint(
-                self.checkpoint_path,
-                config=self.config,
-                optuna_log=self.args.optuna,
-                # overwrite checkpoint hyperparameters with config ones
-                # otherwise ckpt hyperparameters are logged to MLflow, but yaml hyperparameters are used
-            )
-        else:
+        if not self.checkpoint_path:
             lightning_model = FasterRCNN(
                 self.config, optuna_log=self.args.optuna
             )
+            checkpoint_type = None
+        else:
+            checkpoint_type = get_checkpoint_type(self.checkpoint_path)
+            if checkpoint_type == "weights":
+                # a 'weights' checkpoint is one saved with `save_weights_only=True`
+                lightning_model = FasterRCNN.load_from_checkpoint(
+                    self.checkpoint_path,
+                    config=self.config,  # overwrite hparams from ckpt with config
+                    optuna_log=self.args.optuna,
+                )
+            else:
+                # a 'full' checkpoint: build a fresh model here and let
+                # `trainer.fit` below restore its state via `ckpt_path`
+                lightning_model = FasterRCNN(
+                    self.config, optuna_log=self.args.optuna
+                )
 
         # Get trainer
         trainer = self.setup_trainer()
+        if self.args.log_data_augmentation:
+            log_data_augm_as_artifacts(trainer.logger, data_module)
 
         # Run training
-        # Resume from full checkpoint if available
-        # (automatically restores model, epoch, step, LR schedulers, etc...)
-        # https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
-        if checkpoint_type == "full":
-            trainer.fit(
-                lightning_model,
-                data_module,
-                ckpt_path=self.checkpoint_path,  # needs to having been saved with `save_weights_only=False`
-            )
-        else:  # for "weights" or no checkpoint
-            trainer.fit(
-                lightning_model,
-                data_module,
-            )
+        trainer.fit(
+            lightning_model,
+            data_module,
+            # a 'full' checkpoint is one saved with `save_weights_only=False`;
+            # passing it as `ckpt_path` makes the trainer restore the model,
+            # epoch, step, LR schedulers, etc. See
+            # https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
+            ckpt_path=(
+                self.checkpoint_path if checkpoint_type == "full" else None
+            ),
+        )
 
         return trainer
@@ -344,6 +323,16 @@
         action="store_true",
         help="Run a hyperparameter optimisation using Optuna prior to training the model",
     )
+    parser.add_argument(
+        "--no_data_augmentation",
+        action="store_true",
+        help="Ignore the data augmentation transforms defined in the config file",
+    )
+    parser.add_argument(
+        "--log_data_augmentation",
+        action="store_true",
+        help="Log the data augmentation transforms of the datamodule as MLflow artifacts",
+    )
     return parser.parse_args(args)
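To make the two remaining checkpoint flows explicit, here is a sketch of what the trainer does in each case. It is not runnable as-is: the `FasterRCNN` import path is assumed, `config` and `data_module` are placeholders, and the checkpoint paths are hypothetical.

```python
import lightning

from crabs.detection_tracking.models import FasterRCNN  # import path assumed

trainer = lightning.Trainer(max_epochs=10)

# 'weights' checkpoint: load weights (and hparams) into a fresh model,
# then start a brand-new training run on top of them (fine-tuning)
model = FasterRCNN.load_from_checkpoint("weights.ckpt", config=config)
trainer.fit(model, data_module)

# 'full' checkpoint: instantiate the model from config, then let fit()
# restore weights, optimizer state, LR schedulers, epoch and step (resuming)
model = FasterRCNN(config)
trainer.fit(model, data_module, ckpt_path="full.ckpt")
```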
diff --git a/notebooks/notebook_data_augm.py b/notebooks/notebook_data_augm.py
new file mode 100644
index 00000000..d8be52a1
--- /dev/null
+++ b/notebooks/notebook_data_augm.py
@@ -0,0 +1,45 @@
+# %%
+import yaml  # type: ignore
+
+from crabs.detection_tracking.datamodules import CrabsDataModule
+from crabs.detection_tracking.visualization import plot_sample
+
+# %%%%%%%%%%%%%%%%%%%
+# Input data
+IMG_DIR = "/home/sminano/swc/project_crabs/data/sep2023-full/frames"
+ANNOT_FILE = "/home/sminano/swc/project_crabs/data/sep2023-full/annotations/VIA_JSON_combined_coco_gen.json"
+CONFIG = "/home/sminano/swc/project_crabs/crabs-exploration/crabs/detection_tracking/config/faster_rcnn.yaml"
+SPLIT_SEED = 42
+
+# %%%%%%%%%%%%%%%%%%%%
+# Read config as dict
+with open(CONFIG, "r") as f:
+    config_dict = yaml.safe_load(f)
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+# Create datamodule for the input data
+dm = CrabsDataModule(
+    list_img_dirs=[IMG_DIR],
+    list_annotation_files=[ANNOT_FILE],
+    config=config_dict,
+    split_seed=SPLIT_SEED,
+)
+
+# %%%%%%%%%%%%%%%%%%%%%%%%
+# Setup for train / test
+dm.prepare_data()
+dm.setup("fit")
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%
+# After this, dm.train_dataset should have transforms
+# (but not dm.test_dataset)
+print(dm.train_transform)
+print(dm.val_transform)
+print(dm.test_transform)
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%
+# Visualize a train sample
+train_dataset = dm.train_dataset
+train_sample = train_dataset[0]
+plot_sample([train_sample])
+
+# %%
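One detail worth flagging before the tests below: torchvision v2 transforms do not define `__eq__`, so two identically-configured instances compare unequal, and the tests compare their `__dict__` attributes instead. A minimal illustration:

```python
import torchvision.transforms.v2 as transforms

t1 = transforms.RandomHorizontalFlip(p=0.5)
t2 = transforms.RandomHorizontalFlip(p=0.5)

print(t1 == t2)  # False: no __eq__, falls back to object identity
print(t1.__dict__ == t2.__dict__)  # True: same attribute values
```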
diff --git a/tests/test_unit/test_datamodules.py b/tests/test_unit/test_datamodules.py
index 7a4478d2..02d9c76e 100644
--- a/tests/test_unit/test_datamodules.py
+++ b/tests/test_unit/test_datamodules.py
@@ -1,79 +1,170 @@
 import random
+from pathlib import Path
 
 import pytest
 import torch
 import torchvision.transforms.v2 as transforms
+import yaml  # type: ignore
 
 from crabs.detection_tracking.datamodules import CrabsDataModule
 
+DEFAULT_CONFIG = (
+    Path(__file__).parents[2]
+    / "crabs"
+    / "detection_tracking"
+    / "config"
+    / "faster_rcnn.yaml"
+)
+
 
 @pytest.fixture
-def train_config():
-    return {
-        "train_fraction": 0.8,
-        "val_over_test_fraction": 0,
-        "transform_brightness": 0.5,
-        "transform_hue": 0.3,
-        "gaussian_blur_params": {"kernel_size": [5, 9], "sigma": [0.1, 5.0]},
-    }
+def default_train_config():
+    config_file = DEFAULT_CONFIG
+    with open(config_file, "r") as f:
+        return yaml.safe_load(f)
+
+
+@pytest.fixture
+def expected_data_augm_transforms():
+    return transforms.Compose(
+        [
+            transforms.ToImage(),
+            transforms.GaussianBlur(kernel_size=[5, 9], sigma=[0.1, 5.0]),
+            transforms.ColorJitter(brightness=(0.5, 1.5), hue=(-0.3, 0.3)),
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.RandomRotation(
+                degrees=[-10.0, 10.0],
+                interpolation=transforms.InterpolationMode.NEAREST,
+                expand=False,
+                fill=0,
+            ),
+            transforms.RandomAdjustSharpness(p=0.5, sharpness_factor=0.5),
+            transforms.RandomAutocontrast(p=0.5),
+            transforms.RandomEqualize(p=0.5),
+            transforms.ClampBoundingBoxes(),
+            transforms.SanitizeBoundingBoxes(min_size=1.0, labels_getter=None),
+            transforms.ToDtype(torch.float32, scale=True),
+        ]
+    )
+
+
+@pytest.fixture
+def expected_no_data_augm_transforms():
+    return transforms.Compose(
+        [
+            transforms.ToImage(),
+            transforms.ToDtype(torch.float32, scale=True),
+        ]
+    )
 
 
 @pytest.fixture
-def crabs_data_module(train_config):
+def crabs_data_module_with_data_augm(default_train_config):
     return CrabsDataModule(
         list_img_dirs=["dir1", "dir2"],
         list_annotation_files=["anno1", "anno2"],
-        config=train_config,
+        config=default_train_config,
         split_seed=123,
+        no_data_augmentation=False,
     )
 
 
 @pytest.fixture
-def expected_transforms_train_set(train_config):
-    return [
-        transforms.ToImage(),
-        transforms.ColorJitter(
-            brightness=train_config["transform_brightness"],
-            hue=train_config["transform_hue"],
-        ),
-        transforms.GaussianBlur(
-            kernel_size=train_config["gaussian_blur_params"]["kernel_size"],
-            sigma=train_config["gaussian_blur_params"]["sigma"],
-        ),
-        transforms.ToDtype(torch.float32, scale=True),
-    ]
+def crabs_data_module_without_data_augm(default_train_config):
+    return CrabsDataModule(
+        list_img_dirs=["dir1", "dir2"],
+        list_annotation_files=["anno1", "anno2"],
+        config=default_train_config,
+        split_seed=123,
+        no_data_augmentation=True,
+    )
+
+
+def compare_transforms_attrs_excluding(transform1, transform2, keys_to_skip):
+    """Compare the attributes of two transforms, excluding the given keys."""
+
+    transform1_attrs_without_fns = {
+        key: val
+        for key, val in transform1.__dict__.items()
+        if key not in keys_to_skip
+    }
+
+    transform2_attrs_without_fns = {
+        key: val
+        for key, val in transform2.__dict__.items()
+        if key not in keys_to_skip
+    }
+
+    return transform1_attrs_without_fns == transform2_attrs_without_fns
 
 
-def test_get_train_transform(crabs_data_module, expected_transforms_train_set):
-    train_transform = crabs_data_module._get_train_transform()
-    assert isinstance(train_transform, transforms.Compose)
-    assert len(train_transform.transforms) == len(
-        expected_transforms_train_set
-    )
-    for transform, expected_transform in zip(
-        train_transform.transforms, expected_transforms_train_set
-    ):
-        assert isinstance(transform, type(expected_transform))
+@pytest.mark.parametrize(
+    "crabs_data_module, expected_train_transforms",
+    [
+        ("crabs_data_module_with_data_augm", "expected_data_augm_transforms"),
+        (
+            "crabs_data_module_without_data_augm",
+            "expected_no_data_augm_transforms",
+        ),
+    ],
+)
+def test_get_train_transform(
+    crabs_data_module, expected_train_transforms, request
+):
+    crabs_data_module = request.getfixturevalue(crabs_data_module)
+    expected_train_transforms = request.getfixturevalue(
+        expected_train_transforms
+    )
 
+    train_transforms = crabs_data_module._get_train_transform()
 
-@pytest.fixture
-def expected_transforms_test_set():
-    return [
-        transforms.ToImage(),
-        transforms.ToDtype(torch.float32, scale=True),
-    ]
+    assert isinstance(train_transforms, transforms.Compose)
+
+    # assert each transform in the Compose has the expected attributes
+    for train_tr, expected_train_tr in zip(
+        train_transforms.transforms,
+        expected_train_transforms.transforms,
+    ):
+        # we skip the `_labels_getter` attribute of `SanitizeBoundingBoxes`,
+        # since it may hold a lambda and two distinct function objects
+        # never compare equal
+        assert compare_transforms_attrs_excluding(
+            transform1=train_tr,
+            transform2=expected_train_tr,
+            keys_to_skip=["_labels_getter"],
+        )
+
+
+@pytest.mark.parametrize(
+    "crabs_data_module, expected_test_val_transforms",
+    [
+        (
+            "crabs_data_module_with_data_augm",
+            "expected_no_data_augm_transforms",
+        ),
+        (
+            "crabs_data_module_without_data_augm",
+            "expected_no_data_augm_transforms",
+        ),
+    ],
+)
+def test_get_test_val_transform(
+    crabs_data_module, expected_test_val_transforms, request
+):
+    crabs_data_module = request.getfixturevalue(crabs_data_module)
+    expected_test_val_transforms = request.getfixturevalue(
+        expected_test_val_transforms
+    )
 
+    test_val_transforms = crabs_data_module._get_test_val_transform()
 
-def test_get_test_transform(crabs_data_module, expected_transforms_test_set):
-    test_transform = crabs_data_module._get_test_val_transform()
-    assert isinstance(test_transform, transforms.Compose)
+    assert isinstance(test_val_transforms, transforms.Compose)
 
-    assert len(test_transform.transforms) == len(expected_transforms_test_set)
-    for transform, expected_transform in zip(
-        test_transform.transforms, expected_transforms_test_set
+    # assert each transform in the Compose has the expected attributes
+    for test_val_tr, expected_test_val_tr in zip(
+        test_val_transforms.transforms,
+        expected_test_val_transforms.transforms,
     ):
-        assert isinstance(transform, type(expected_transform))
+        assert test_val_tr.__dict__ == expected_test_val_tr.__dict__
 
 
 @pytest.fixture
@@ -97,7 +188,15 @@ def dummy_dataset():
     return images, annotations
 
 
-def test_collate_fn(crabs_data_module, dummy_dataset):
+@pytest.mark.parametrize(
+    "crabs_data_module",
+    [
+        "crabs_data_module_with_data_augm",
+        "crabs_data_module_without_data_augm",
+    ],
+)
+def test_collate_fn(crabs_data_module, dummy_dataset, request):
+    crabs_data_module = request.getfixturevalue(crabs_data_module)
     collated_data = crabs_data_module._collate_fn(dummy_dataset)
 
     assert len(collated_data) == len(dummy_dataset[0])  # images
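As a closing aside, the parametrize-over-fixture-names pattern used in these tests resolves fixtures at runtime via `request.getfixturevalue`. A tiny self-contained example of the pattern (the fixtures here are made up; runnable with pytest):

```python
import pytest


@pytest.fixture
def small_value():
    return 1


@pytest.fixture
def large_value():
    return 100


# parametrize over fixture *names*, then resolve them at runtime
@pytest.mark.parametrize("value_fixture", ["small_value", "large_value"])
def test_value_is_positive(value_fixture, request):
    value = request.getfixturevalue(value_fixture)
    assert value > 0
```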