Skip to content

Commit

Permalink
Enable MCore checkpointing optimizations (NVIDIA#9505)
Browse files Browse the repository at this point in the history
* Expose num processes in PyT Dist

Signed-off-by: Mikołaj Błaż <[email protected]>

* Add parallel save/load optimizations from MCore

Signed-off-by: Mikołaj Błaż <[email protected]>

* Remove async utils from MCore

Signed-off-by: Mikołaj Błaż <[email protected]>

* Enable DistOpt parallel R/W

Signed-off-by: Mikołaj Błaż <[email protected]>

* Enable PyT Dist caching

Signed-off-by: Mikołaj Błaż <[email protected]>

* Small fixes

Signed-off-by: Mikołaj Błaż <[email protected]>

* Make sure DistCkptIO is instantiated from config

Signed-off-by: Mikołaj Błaż <[email protected]>

* Bump MCore version to v0.7

Signed-off-by: Mikołaj Błaż <[email protected]>

* Print load strategy

Signed-off-by: Mikołaj Błaż <[email protected]>

* Forward MCore to model space DistOpt

Signed-off-by: Mikołaj Błaż <[email protected]>

* Add separate flag to control DistOpt parallel R/W

Signed-off-by: Mikołaj Błaż <[email protected]>

* Turn off parallel save by default

Signed-off-by: Mikołaj Błaż <[email protected]>

---------

Signed-off-by: Mikołaj Błaż <[email protected]>
  • Loading branch information
mikolajblaz authored Jul 5, 2024
1 parent 7e10458 commit 18ecd41
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 327 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.ci
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ WORKDIR /workspace
# Install NeMo requirements
ARG TE_TAG=bfe21c3d68b0a9951e5716fb520045db53419c5e
ARG MODELOPT_VERSION=0.13.0
ARG MCORE_TAG=02871b4df8c69fac687ab6676c4246e936ce92d0
ARG MCORE_TAG=0ab8dd4c7520408683fdb9f8ac119eff7d38fc0e
ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
RUN \
--mount=type=bind,source=requirements,target=requirements \
Expand Down
4 changes: 4 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ model:
dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint
dist_ckpt_parallel_load: False # if true, each worker will load part of the dist checkpoint and exchange with NCCL. Might use some extra GPU memory
dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format
dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves.
dist_ckpt_parallel_dist_opt: True # parallel save/load of a DistributedOptimizer. 'True' allows performant save and reshardable checkpoints. Set to 'False' only in order to minimize the number of checkpoint files.

## Activation Checkpointing
# NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/nlp/parts/megatron_trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]:
find_unused_parameters=False,
nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None),
sharp=self.cfg.model.get('sharp', False),
dist_ckpt_parallel_save=self.cfg.model.get('dist_ckpt_parallel_save', False),
dist_ckpt_parallel_save=self.cfg.model.get('dist_ckpt_parallel_dist_opt', True),
)

def _grad_scaler(self) -> GradScaler:
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/nlp/parts/nlp_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ def dummy():
model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
model.trainer.strategy.setup_environment()
sharded_state_dict = model.sharded_state_dict()
checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr'))
checkpoint_io = DistributedCheckpointIO.from_config(model.cfg, async_save=False)
checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir)

if HAVE_MODELOPT and hasattr(model, "get_model_module_list"):
Expand Down
4 changes: 1 addition & 3 deletions nemo/core/optim/mcore_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ def load_state_dict(self, state_dict):
def sharded_state_dict(
self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, dist_ckpt_parallel_save=False
):
# TODO(@akoumparouli, @mikolajblaz): switch to sharding_type once support for fully_sharded_model_space merged in mcore.
# sharding_type = 'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter'
sharding_type = 'dp_zero_gather_scatter'
sharding_type = 'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter'
return self.mcore_optimizer.sharded_state_dict(
model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type
)
Expand Down
106 changes: 83 additions & 23 deletions nemo/utils/callbacks/dist_ckpt_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,29 @@
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.dict_utils import extract_matching_values
from megatron.core.dist_checkpointing.mapping import ShardedBase
from megatron.core.dist_checkpointing.serialization import (
get_default_load_sharded_strategy,
get_default_save_sharded_strategy,
)
from megatron.core.dist_checkpointing.strategies import tensorstore

from nemo.utils.callbacks.torch_dist_async import AsyncCallsQueue, AsyncRequest, TorchDistAsyncSaveShardedStrategy
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest
from megatron.core.dist_checkpointing.strategies.base import SaveShardedStrategy
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
FullyParallelLoadStrategyWrapper,
FullyParallelSaveStrategyWrapper,
)
from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
from megatron.core.parallel_state import get_data_parallel_group

HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError) as IMPORT_ERROR_EXC:
except (ImportError, ModuleNotFoundError) as e:

HAVE_MEGATRON_CORE = False
IMPORT_ERROR = "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
IMPORT_ERROR = (
"megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
f" Exact error: {e}"
)


@contextmanager
Expand Down Expand Up @@ -87,7 +100,7 @@ class AsyncFinalizableCheckpointIO(_WrappingCheckpointIO):

def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO) -> None:
if not HAVE_MEGATRON_CORE:
raise ImportError(IMPORT_ERROR) from IMPORT_ERROR_EXC
raise ImportError(IMPORT_ERROR)
if not isinstance(checkpoint_io, AsyncCompatibleCheckpointIO):
raise ValueError(f'Incompatible wrapped checkpoint_io type: {type(checkpoint_io)}')

Expand Down Expand Up @@ -177,22 +190,38 @@ class DistributedCheckpointIO(AsyncCompatibleCheckpointIO):
always loads on device). Defaults to True.
async_save (bool): whether to save asynchronously. Should be set to True if
this class will be wrapped with AsyncFinalizableCheckpointIO.
torch_dist_multiproc (int, optional): number of extra processes per rank
used during ckpt save with PyTorch distributed format. Defaults to None,
which means using an MCore default (2).
parallel_save (bool): parallelizes the save across ranks. Defaults to True
parallel_load (bool): parallelizes the load across ranks (followed by params all gather).
Defaults to False due to some extra memory usage requirement.
"""

def __init__(
self,
save_ckpt_format: str,
load_directly_on_device: bool = True,
async_save: bool = False,
torch_dist_multiproc: Optional[int] = None,
assume_constant_structure: bool = False,
parallel_save: bool = True,
parallel_load: bool = False,
):
super().__init__()
if not HAVE_MEGATRON_CORE:
raise ImportError(IMPORT_ERROR) from IMPORT_ERROR_EXC
raise ImportError(IMPORT_ERROR)

self.save_ckpt_format = save_ckpt_format
self.load_directly_on_device = load_directly_on_device
self.async_save = async_save
self.save_sharded_strategy = self._determine_dist_ckpt_save_strategy()
self.torch_dist_multiproc = torch_dist_multiproc
self.assume_constant_structure = assume_constant_structure
self.parallel_save = parallel_save
self.parallel_load = parallel_load

self._save_sharded_strategy = None
self.validated_consistency = False

@classmethod
def from_config(cls, model_cfg: dict, async_save: bool = False):
Expand All @@ -208,6 +237,9 @@ def from_config(cls, model_cfg: dict, async_save: bool = False):
save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'),
load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True),
async_save=async_save,
torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None),
parallel_save=model_cfg.get('dist_ckpt_parallel_save', True),
parallel_load=model_cfg.get('dist_ckpt_parallel_load', False),
)

@_debug_time('DistributedCheckpointIO.save_checkpoint')
Expand All @@ -224,16 +256,15 @@ def save_checkpoint(
fs = get_filesystem(path)
fs.makedirs(path, exist_ok=True)

dist_checkpointing.save(
sharded_state_dict=checkpoint, checkpoint_dir=path, sharded_strategy=self.save_sharded_strategy
validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure)
self.validated_consistency = True
return dist_checkpointing.save(
sharded_state_dict=checkpoint,
checkpoint_dir=path,
sharded_strategy=self.save_sharded_strategy,
validate_access_integrity=validate_sharding_integrity,
async_sharded_save=self.async_save,
)
if not self.async_save:
return None
# NOTE: this logic will be simplified in MCore v0.7
assert self.save_sharded_strategy.async_request is not None
async_request = self.save_sharded_strategy.async_request
self.save_sharded_strategy.async_request = None
return async_request

@_debug_time('DistributedCheckpointIO.load_checkpoint')
def load_checkpoint(
Expand Down Expand Up @@ -267,6 +298,16 @@ def load_checkpoint(
else:
sharded_strategy = None

if self.parallel_load:
if sharded_strategy is None:
sharded_strategy = get_default_load_sharded_strategy(path)
sharded_strategy = FullyParallelLoadStrategyWrapper(
sharded_strategy, get_data_parallel_group(with_context_parallel=True)
)

if sharded_strategy is not None:
logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.')

if not strict:
sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)

Expand Down Expand Up @@ -309,17 +350,36 @@ def remove_checkpoint(self, path: _PATH) -> None:
"""
shutil.rmtree(path, ignore_errors=True)

@property
def save_sharded_strategy(self) -> 'SaveShardedStrategy':
if self._save_sharded_strategy is None:
self._save_sharded_strategy = self._determine_dist_ckpt_save_strategy()
return self._save_sharded_strategy

def _determine_dist_ckpt_save_strategy(self):
"""Determine the saving strategy based on constructor args.
If self.async_save is True instantiates an async PyT Dist strategy,
otherwise relies on MCore to create a proper strategy based on ckpt format.
Relies on the default MCore strategy unless extra PyT Distributed format arguments
are passed in config or in case of a fully parallel save in which case
a parallelization wrapper is applied.
"""
save_strategy = (self.save_ckpt_format, 1)
if self.async_save:
if save_strategy[0] != 'torch_dist':
raise ValueError('Async dist-ckpt save supported only for torch_dist format')
save_strategy = TorchDistAsyncSaveShardedStrategy('torch_dist', 1)
if self.async_save and self.save_ckpt_format != 'torch_dist':
raise ValueError('Async dist-ckpt save supported only for torch_dist format')

torch_dist_kwargs = {} if self.torch_dist_multiproc is None else dict(thread_count=self.torch_dist_multiproc)
if self.save_ckpt_format == 'torch_dist' and torch_dist_kwargs:
save_strategy = TorchDistSaveShardedStrategy(self.save_ckpt_format, 1, **torch_dist_kwargs)
else:
save_strategy = get_default_save_sharded_strategy(self.save_ckpt_format, 1)

# MCore v0.8 introduces `use_cached_ckpt_structure` attribute
if hasattr(save_strategy, 'use_cached_ckpt_structure'):
save_strategy.use_cached_ckpt_structure = self.assume_constant_structure

if self.parallel_save:
save_strategy = FullyParallelSaveStrategyWrapper(
save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure
)

logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
return save_strategy
Loading

0 comments on commit 18ecd41

Please sign in to comment.