Skip to content

Commit

Permalink
Initial version for multinode auto_runner and ensembler (#6272)
Browse files Browse the repository at this point in the history
Fixes #6191 #6259 .

### Description
Big changes to the autorunner to enable multi-node training and a
multi-node, multi-GPU ensembler.
Multiple changes:
1. Add set_device_info() to create a self.device_dict that defines device
information (CUDA_VISIBLE_DEVICES, NUM_NODES, etc.) for all parts of the
autorunner, including the data analyzer, trainer, and ensembler. No global env
variable is set; all device info comes from self.device_dict. Corresponding
changes were made to bundle_gen.
2. To enable multi-GPU/multi-node training for the ensembler (called from a
subprocess), we need to separate the ensembler from the autorunner (so the
autorunner can run it in a subprocess). Created a new EnsembleRunner class
(similar to BundleGen) and moved all ensemble-related functions from the
autorunner to this class. Local multi-GPU ensembling passed.

Passed some quick local testing. Details still need to be fixed and tests added.
Created this PR for an initial design-pattern discussion. Slack me if there
is any major concern about the change.
@mingxin-zheng @wyli

---------

Signed-off-by: heyufan1995 <[email protected]>
  • Loading branch information
heyufan1995 authored Apr 14, 2023
1 parent 3633b1c commit 825b8db
Show file tree
Hide file tree
Showing 7 changed files with 485 additions and 164 deletions.
8 changes: 7 additions & 1 deletion monai/apps/auto3dseg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
from .auto_runner import AutoRunner
from .bundle_gen import BundleAlgo, BundleGen
from .data_analyzer import DataAnalyzer
from .ensemble_builder import AlgoEnsemble, AlgoEnsembleBestByFold, AlgoEnsembleBestN, AlgoEnsembleBuilder
from .ensemble_builder import (
AlgoEnsemble,
AlgoEnsembleBestByFold,
AlgoEnsembleBestN,
AlgoEnsembleBuilder,
EnsembleRunner,
)
from .hpo_gen import NNIGen, OptunaGen
from .utils import export_bundle_algo_history, import_bundle_algo_history
3 changes: 2 additions & 1 deletion monai/apps/auto3dseg/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from monai.apps.auto3dseg.auto_runner import AutoRunner
from monai.apps.auto3dseg.bundle_gen import BundleAlgo, BundleGen
from monai.apps.auto3dseg.data_analyzer import DataAnalyzer
from monai.apps.auto3dseg.ensemble_builder import AlgoEnsembleBuilder
from monai.apps.auto3dseg.ensemble_builder import AlgoEnsembleBuilder, EnsembleRunner
from monai.apps.auto3dseg.hpo_gen import NNIGen, OptunaGen

if __name__ == "__main__":
Expand All @@ -27,6 +27,7 @@
"BundleGen": BundleGen,
"BundleAlgo": BundleAlgo,
"AlgoEnsembleBuilder": AlgoEnsembleBuilder,
"EnsembleRunner": EnsembleRunner,
"AutoRunner": AutoRunner,
"NNIGen": NNIGen,
"OptunaGen": OptunaGen,
Expand Down
222 changes: 110 additions & 112 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,17 @@
from time import sleep
from typing import Any, cast

import numpy as np
import torch

from monai.apps.auto3dseg.bundle_gen import BundleGen
from monai.apps.auto3dseg.data_analyzer import DataAnalyzer
from monai.apps.auto3dseg.ensemble_builder import (
AlgoEnsemble,
AlgoEnsembleBestByFold,
AlgoEnsembleBestN,
AlgoEnsembleBuilder,
)
from monai.apps.auto3dseg.ensemble_builder import EnsembleRunner
from monai.apps.auto3dseg.hpo_gen import NNIGen
from monai.apps.auto3dseg.utils import export_bundle_algo_history, import_bundle_algo_history
from monai.apps.utils import get_logger
from monai.auto3dseg.utils import algo_to_pickle
from monai.bundle import ConfigParser
from monai.transforms import SaveImage
from monai.utils.enums import AlgoKeys
from monai.utils.module import look_up_option, optional_import
from monai.utils import AlgoKeys, has_option, look_up_option, optional_import

logger = get_logger(module_name=__name__)

Expand Down Expand Up @@ -232,6 +224,7 @@ def __init__(
self.data_src_cfg_name = os.path.join(self.work_dir, "input.yaml")
self.algos = algos
self.templates_path_or_url = templates_path_or_url
self.kwargs = deepcopy(kwargs)

if input is None and os.path.isfile(self.data_src_cfg_name):
input = self.data_src_cfg_name
Expand Down Expand Up @@ -285,16 +278,11 @@ def __init__(
self.ensemble = ensemble # last step, no need to check

self.set_training_params()
self.set_device_info()
self.set_prediction_params()
self.set_analyze_params()

self.save_image = self.set_image_save_transform(kwargs)

self.ensemble_method: AlgoEnsemble
self.ensemble_method_name: str | None = None

self.set_ensemble_method()
self.set_num_fold(num_fold=num_fold)
self.set_ensemble_method("AlgoEnsembleBestByFold")

self.gpu_customization = False
self.gpu_customization_specs: dict[str, Any] = {}
Expand Down Expand Up @@ -461,18 +449,11 @@ def set_num_fold(self, num_fold: int = 5) -> None:
Args:
num_fold: a positive integer to define the number of folds.
Notes:
If the ensemble method is ``AlgoEnsembleBestByFold``, this function automatically updates the ``n_fold``
parameter in the ``ensemble_method`` to avoid inconsistency between the training and the ensemble.
"""

if num_fold <= 0:
raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}")

self.num_fold = num_fold
if self.ensemble_method_name == "AlgoEnsembleBestByFold":
self.ensemble_method.n_fold = self.num_fold # type: ignore

def set_training_params(self, params: dict[str, Any] | None = None) -> None:
"""
Expand All @@ -488,6 +469,95 @@ def set_training_params(self, params: dict[str, Any] | None = None) -> None:
"""
self.train_params = deepcopy(params) if params is not None else {}
if "CUDA_VISIBLE_DEVICES" in self.train_params:
warnings.warn(
"CUDA_VISIBLE_DEVICES is deprecated from 'set_training_params'. Use 'set_device_info' intead.",
DeprecationWarning,
)

def set_device_info(
    self,
    cuda_visible_devices: list[int] | tuple[int, ...] | str | None = None,
    num_nodes: int | None = None,
    mn_start_method: str | None = None,
    cmd_prefix: str | None = None,
) -> None:
    """
    Set the device-related info used by the data analyzer, trainer, and ensembler.

    Args:
        cuda_visible_devices: GPU ids for data analysis, training, and ensembling.
            Either a list/tuple of GPU ids such as ``[0, 1, 2, 3]`` or a string ``"0,1,2,3"``.
            Defaults to env ``CUDA_VISIBLE_DEVICES``, or all devices available if that is unset.
        num_nodes: number of nodes for training and ensembling.
            Defaults to env ``NUM_NODES``, or 1 if ``NUM_NODES`` is unset.
        mn_start_method: multi-node start method. The autorunner uses this method to start
            multi-node processes. Defaults to env ``MN_START_METHOD``, or ``'bcprun'`` if unset.
        cmd_prefix: command line prefix for subprocesses running in BundleAlgo and EnsembleRunner.
            Defaults to env ``CMD_PREFIX`` or ``None``. Examples:

            - single GPU/CPU or multinode bcprun: ``"python "`` or ``"/opt/conda/bin/python3.8 "``,
            - single node multi-GPU: ``"torchrun --nnodes=1 --nproc_per_node=2 "``.

            If you define this prefix, make sure ``--nproc_per_node`` matches
            ``cuda_visible_devices`` or ``os.environ['CUDA_VISIBLE_DEVICES']``, and always set
            ``--nnodes=1``. Use ``num_nodes`` for multi-node runs.
    """
    self.device_setting: dict[str, Any] = {}
    if cuda_visible_devices is None:  # fall back to the environment before probing torch
        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if cuda_visible_devices is None:  # still None after reading the environ: use all devices
        self.device_setting["CUDA_VISIBLE_DEVICES"] = ",".join([str(x) for x in range(torch.cuda.device_count())])
        self.device_setting["n_devices"] = torch.cuda.device_count()
    elif isinstance(cuda_visible_devices, str):
        self.device_setting["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
        self.device_setting["n_devices"] = len(cuda_visible_devices.split(","))
    elif isinstance(cuda_visible_devices, (list, tuple)):
        self.device_setting["CUDA_VISIBLE_DEVICES"] = ",".join([str(x) for x in cuda_visible_devices])
        self.device_setting["n_devices"] = len(cuda_visible_devices)
    else:
        # logger.warn is a deprecated alias of logger.warning
        logger.warning(f"Wrong format of cuda_visible_devices {cuda_visible_devices}, devices not set")

    if num_nodes is None:
        num_nodes = int(os.environ.get("NUM_NODES", 1))
    self.device_setting["NUM_NODES"] = num_nodes

    if mn_start_method is None:
        mn_start_method = os.environ.get("MN_START_METHOD", "bcprun")
    self.device_setting["MN_START_METHOD"] = mn_start_method

    if cmd_prefix is None:
        cmd_prefix = os.environ.get("CMD_PREFIX")
    self.device_setting["CMD_PREFIX"] = cmd_prefix

    if cmd_prefix is not None:
        # fixed typo in the log message: "overide" -> "override"
        logger.info(f"Using user defined command running prefix {cmd_prefix}, will override other settings")

def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFold", **kwargs: Any) -> None:
    """
    Set the ensemble method name and its keyword arguments; they are stored and later
    forwarded to the ensembling step.

    Args:
        ensemble_method_name: the name of the ensemble method. Only two methods are supported
            "AlgoEnsembleBestN" and "AlgoEnsembleBestByFold".
        kwargs: the keyword arguments used to define the ensemble method. Currently only ``n_best`` for
            ``AlgoEnsembleBestN`` is supported.
    """
    # look_up_option validates the name against the supported set before storing it
    self.ensemble_method_name = look_up_option(
        ensemble_method_name, supported=["AlgoEnsembleBestN", "AlgoEnsembleBestByFold"]
    )
    # merged into self.kwargs so other setters can contribute to the same kwargs pool
    self.kwargs.update(kwargs)

def set_image_save_transform(self, **kwargs: Any) -> None:
    """
    Set the ensemble output transform parameters; they are stored and later forwarded
    to the ensembling step.

    Args:
        kwargs: image writing parameters for the ensemble inference. The kwargs format follows SaveImage
            transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage.
    """
    # merged into self.kwargs so other setters can contribute to the same kwargs pool
    self.kwargs.update(kwargs)

def set_prediction_params(self, params: dict[str, Any] | None = None) -> None:
"""
Expand Down Expand Up @@ -547,10 +617,7 @@ def set_hpo_params(self, params: dict[str, Any] | None = None) -> None:
Users can set ``nni_dry_run`` to ``True`` in the ``params`` to enable the dry-run mode for the NNI backend.
"""
if params is None:
self.hpo_params = self.train_params
else:
self.hpo_params = params
self.hpo_params = self.train_params if params is None else params

def set_nni_search_space(self, search_space):
"""
Expand All @@ -569,58 +636,6 @@ def set_nni_search_space(self, search_space):
self.search_space = search_space
self.hpo_tasks = value_combinations

def set_image_save_transform(self, kwargs):
    """
    Build and return the SaveImage transform used to write ensemble predictions.

    Args:
        kwargs: image writing parameters for the ensemble inference. The kwargs format follows SaveImage
            transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage .
    """
    # EAFP: take the caller-supplied output_dir if present, otherwise derive one
    try:
        output_dir = kwargs.pop("output_dir")
    except KeyError:
        output_dir = os.path.join(self.work_dir, "ensemble_output")
        logger.info(f"The output_dir is not specified. {output_dir} will be used to save ensemble predictions")

    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
        logger.info(f"Directory {output_dir} is created to save ensemble predictions")

    self.output_dir = output_dir
    # remaining recognized options, with their defaults; anything left in kwargs passes through
    save_params = {
        "output_dir": output_dir,
        "output_postfix": kwargs.pop("output_postfix", "ensemble"),
        "output_dtype": kwargs.pop("output_dtype", np.uint8),
        "resample": kwargs.pop("resample", False),
    }
    return SaveImage(**save_params, **kwargs)

def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFold", **kwargs: Any) -> None:
    """
    Set the bundle ensemble method.

    Args:
        ensemble_method_name: the name of the ensemble method. Only two methods are supported "AlgoEnsembleBestN"
            and "AlgoEnsembleBestByFold".
        kwargs: the keyword arguments used to define the ensemble method. Currently only ``n_best`` for
            ``AlgoEnsembleBestN`` is supported.
    """
    # validate the requested name before constructing the corresponding ensembler
    name = look_up_option(ensemble_method_name, supported=["AlgoEnsembleBestN", "AlgoEnsembleBestByFold"])
    self.ensemble_method_name = name
    if name == "AlgoEnsembleBestN":
        # any falsy n_best (absent, False, 0) falls back to the default of 2
        n_best = kwargs.pop("n_best", False) or 2
        self.ensemble_method = AlgoEnsembleBestN(n_best=n_best)
    elif name == "AlgoEnsembleBestByFold":
        self.ensemble_method = AlgoEnsembleBestByFold(n_fold=self.num_fold)
    else:
        raise NotImplementedError(f"Ensemble method {self.ensemble_method_name} is not implemented.")

def _train_algo_in_sequence(self, history: list[dict[str, Any]]) -> None:
"""
Train the Algos in a sequential scheme. The order of training is randomized.
Expand All @@ -637,7 +652,10 @@ def _train_algo_in_sequence(self, history: list[dict[str, Any]]) -> None:
"""
for algo_dict in history:
algo = algo_dict[AlgoKeys.ALGO]
algo.train(self.train_params)
if has_option(algo.train, "device_setting"):
algo.train(self.train_params, self.device_setting)
else:
algo.train(self.train_params)
acc = algo.get_score()

algo_meta_data = {str(AlgoKeys.SCORE): acc}
Expand Down Expand Up @@ -773,7 +791,7 @@ def run(self):

if auto_train_choice:
skip_algos = [h[AlgoKeys.ID] for h in history if h[AlgoKeys.IS_TRAINED]]
if len(skip_algos) > 0:
if skip_algos:
logger.info(
f"Skipping already trained algos {skip_algos}."
"Set option train=True to always retrain all algos."
Expand All @@ -792,34 +810,14 @@ def run(self):

# step 4: model ensemble and write the prediction to disks.
if self.ensemble:
history = import_bundle_algo_history(self.work_dir, only_trained=False)

history_untrained = [h for h in history if not h[AlgoKeys.IS_TRAINED]]
if len(history_untrained) > 0:
warnings.warn(
f"Ensembling step will skip {[h['name'] for h in history_untrained]} untrained algos."
"Generally it means these algos did not complete training."
)
history = [h for h in history if h[AlgoKeys.IS_TRAINED]]

if len(history) == 0:
raise ValueError(
f"Could not find any trained algos in {self.work_dir}. "
"Possibly the required training step was not completed."
)

builder = AlgoEnsembleBuilder(history, self.data_src_cfg_name)
builder.set_ensemble_method(self.ensemble_method)

ensembler = builder.get_ensemble()
preds = ensembler(pred_param=self.pred_params)
if len(preds) > 0:
logger.info("Auto3Dseg picked the following networks to ensemble:")
for algo in ensembler.get_algo_ensemble():
logger.info(algo[AlgoKeys.ID])

for pred in preds:
self.save_image(pred)
logger.info(f"Auto3Dseg ensemble prediction outputs are saved in {self.output_dir}.")

ensemble_runner = EnsembleRunner(
data_src_cfg_name=self.data_src_cfg_name,
work_dir=self.work_dir,
num_fold=self.num_fold,
ensemble_method_name=self.ensemble_method_name,
mgpu=int(self.device_setting["n_devices"]) > 1,
**self.kwargs, # for set_image_save_transform
**self.pred_params,
) # for inference
ensemble_runner.run(self.device_setting)
logger.info("Auto3Dseg pipeline is completed successfully.")
Loading

0 comments on commit 825b8db

Please sign in to comment.