From 312dcbb36dde10090c2df5474c3c6d7b1ec70549 Mon Sep 17 00:00:00 2001 From: Erik Schomburg Date: Fri, 15 Nov 2024 17:48:59 -0500 Subject: [PATCH] add enable_/disable_training util functions to wasp_em_fine_tuning/train_utils.py --- wasp_em_fine_tuning/train_image_predictor.py | 2 +- wasp_em_fine_tuning/train_utils.py | 41 ++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 wasp_em_fine_tuning/train_utils.py diff --git a/wasp_em_fine_tuning/train_image_predictor.py b/wasp_em_fine_tuning/train_image_predictor.py index 6ccd5ddf..86fb5c6e 100644 --- a/wasp_em_fine_tuning/train_image_predictor.py +++ b/wasp_em_fine_tuning/train_image_predictor.py @@ -24,9 +24,9 @@ from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor -from sam2.utils.training import enable_training, disable_training from data_utils import SegmentationImageSampler, get_batch_with_prompts +from train_utils import enable_training, disable_training RngInitType = int | str | np.random.Generator diff --git a/wasp_em_fine_tuning/train_utils.py b/wasp_em_fine_tuning/train_utils.py new file mode 100644 index 00000000..55563cc4 --- /dev/null +++ b/wasp_em_fine_tuning/train_utils.py @@ -0,0 +1,41 @@ +import logging +import os +import torch + +TRAINING_MODE_VAR = "SAM2_TRAINING_ENABLED" +TRUE = "true" +FALSE = "false" + + +def is_training_enabled() -> bool: + return os.environ.get(TRAINING_MODE_VAR, FALSE).lower() == TRUE + + +def enable_training(): + logging.info("Enabling training mode") + os.environ[TRAINING_MODE_VAR] = TRUE + + +def disable_training(): + logging.info("Disabling training mode") + os.environ[TRAINING_MODE_VAR] = FALSE + + +class no_grad_if_not_training(torch.no_grad): + def __enter__(self) -> None: + if not is_training_enabled(): + super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + if not is_training_enabled(): + super().__exit__(exc_type, exc_val, exc_tb) + + +class inference_mode_if_not_training(torch.inference_mode): + def __enter__(self) -> None: + if not is_training_enabled(): + super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + if not is_training_enabled(): + super().__exit__(exc_type, exc_val, exc_tb)