From 0f157abd4813bf488488adc52d3172742fa58b9c Mon Sep 17 00:00:00 2001 From: Anna Shors <71393111+ashors1@users.noreply.github.com> Date: Thu, 4 Jul 2024 01:00:38 -0700 Subject: [PATCH 01/13] [NeMo-UX] Dataloading enhancements and bug fixes (#9595) * fix dataloading + checkpoint restore * clean up data sampler * fix typo * support passing multiple paths to data module * fix validation dataloader * fix dataloader len when using gradient accumulation * fix progress bar * Apply isort and black reformatting Signed-off-by: ashors1 * fix step count in loggers * fix blended dataset * address comments * address comment * move step logging into strategy * Apply isort and black reformatting Signed-off-by: ashors1 --------- Signed-off-by: ashors1 Co-authored-by: Marc Romeyn Co-authored-by: ashors1 --- nemo/collections/llm/gpt/data/pre_training.py | 65 ++++++++++++++++--- nemo/collections/llm/gpt/model/base.py | 1 - nemo/lightning/data.py | 7 +- nemo/lightning/pytorch/callbacks/progress.py | 8 +-- .../lightning/pytorch/plugins/data_sampler.py | 7 +- nemo/lightning/pytorch/strategies.py | 5 ++ 6 files changed, 72 insertions(+), 21 deletions(-) diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py index 18ce781f1409..247ee1a1521a 100644 --- a/nemo/collections/llm/gpt/data/pre_training.py +++ b/nemo/collections/llm/gpt/data/pre_training.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import pytorch_lightning as pl from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -17,7 +17,8 @@ class PreTrainingDataModule(pl.LightningDataModule): def __init__( self, - path: Path, + paths: Path | List[Path], + weights: Optional[List[float]] = None, seq_length: int = 2048, tokenizer: Optional["TokenizerSpec"] = None, micro_batch_size: int = 4, @@ -37,7 +38,13 @@ def __init__( index_mapping_dir: Optional[str] = None, ) -> None: super().__init__() - self.path = path + if not isinstance(paths, (list, tuple)): + paths = [paths] + if weights is not None: + assert len(weights) == len(paths) + + self.paths = paths + self.weights = weights self.seq_length = seq_length self.tokenizer = tokenizer self.num_train_samples = num_train_samples @@ -52,6 +59,7 @@ def __init__( self.seed = seed self.split = split self.index_mapping_dir = index_mapping_dir + self.init_global_step = 0 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer @@ -76,13 +84,13 @@ def setup(self, stage: str = "") -> None: assert max_train_steps > 0, "Please specify trainer.max_steps" eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches test_iters = self.trainer.limit_test_batches - num_train_samples = max_train_steps * self.data_sampler.global_batch_size - num_val_samples = eval_iters * self.data_sampler.global_batch_size - num_test_samples = test_iters * self.data_sampler.global_batch_size + num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size) + num_val_samples = int(eval_iters * self.data_sampler.global_batch_size) + num_test_samples = int(test_iters * self.data_sampler.global_batch_size) if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float): # This is to make sure we only have one epoch on every validation iteration - num_val_samples = 1 + num_val_samples = None train_valid_test_num_samples = [num_train_samples, num_val_samples, 
num_test_samples] self._train_ds, self._validation_ds, self._test_ds = BlendedMegatronDatasetBuilder( @@ -119,6 +127,7 @@ def test_dataloader(self) -> EVAL_DATALOADERS: return self._create_dataloader(self._test_ds) def _create_dataloader(self, dataset, **kwargs) -> DataLoader: + self.init_global_step = self.trainer.global_step return DataLoader( dataset, num_workers=self.num_workers, @@ -133,7 +142,7 @@ def gpt_dataset_config(self) -> "GPTDatasetConfig": from megatron.core.datasets.gpt_dataset import GPTDatasetConfig return GPTDatasetConfig( - blend=[[str(self.path)], [1.0]], + blend=[[str(path) for path in self.paths], self.weights], random_seed=self.seed, sequence_length=self.seq_length, tokenizer=self.tokenizer, @@ -143,3 +152,43 @@ def gpt_dataset_config(self) -> "GPTDatasetConfig": reset_attention_mask=self.reset_attention_mask, eod_mask_loss=self.eod_mask_loss, ) + + def state_dict(self) -> Dict[str, Any]: + """Called when saving a checkpoint, implement to generate and save datamodule state. + + Returns: + A dictionary containing datamodule state. + + """ + consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step) + return {'consumed_samples': consumed_samples} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Called when loading a checkpoint, implement to reload datamodule state given datamodule stat + + Args: + state_dict: the datamodule state returned by ``state_dict``. + + """ + try: + from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR + except ModuleNotFoundError: + from nemo.lightning.apex_utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR + consumed_samples = state_dict['consumed_samples'] + self.data_sampler.init_consumed_samples = consumed_samples + self.data_sampler.prev_consumed_samples = consumed_samples + num_microbatch_calculator = _GLOBAL_NUM_MICROBATCHES_CALCULATOR # noqa: SLF001 + + num_microbatch_calculator.update( + consumed_samples=consumed_samples, + consistency_check=False, + ) + current_global_batch_size = num_microbatch_calculator.current_global_batch_size + '''pl_module.log( + "global_batch_size", + current_global_batch_size, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + )''' + self.if_first_step = 1 diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index d6bf876f0a3d..9b7f4e4ab0c8 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -156,7 +156,6 @@ def forward_step(self, batch) -> torch.Tensor: def training_step(self, batch, batch_idx=None) -> torch.Tensor: # In mcore the loss-function is part of the forward-pass (when labels are provided) - return self.forward_step(batch) def validation_step(self, batch, batch_idx=None) -> torch.Tensor: diff --git a/nemo/lightning/data.py b/nemo/lightning/data.py index adfc0aa14d29..d83f5ba3b728 100644 --- a/nemo/lightning/data.py +++ b/nemo/lightning/data.py @@ -183,9 +183,12 @@ def __len__(self): num_available_samples: int = self.total_samples - self.consumed_samples if self.global_batch_size is not None: if self.drop_last: - return num_available_samples // self.global_batch_size + num_global_batches = num_available_samples // self.global_batch_size else: - return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size + num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size + # return len of dataloader in terms of micro batches to avoid 
discrepancy between len of dataloader and + # num of batches fetched (as training step fetches in terms of micro batches) + return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size) else: return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1 diff --git a/nemo/lightning/pytorch/callbacks/progress.py b/nemo/lightning/pytorch/callbacks/progress.py index 9d4d9b385da8..17178618852f 100644 --- a/nemo/lightning/pytorch/callbacks/progress.py +++ b/nemo/lightning/pytorch/callbacks/progress.py @@ -26,19 +26,13 @@ def init_train_tqdm(self): return self.bar def on_train_epoch_start(self, trainer, *_): - if trainer.max_steps > 0 and (trainer.ckpt_path is not None): + if trainer.max_steps > 0: # and (trainer.ckpt_path is not None): # while resuming from a ckpt use trainer.max_steps as the total for progress bar as trainer.num_training_batches # is truncated to max_steps - step being resumed at num_training_batches = trainer.max_steps else: num_training_batches = trainer.num_training_batches - # from nemo.utils import AppState - # app_state = AppState() - # app_state. - - num_training_batches = num_training_batches // calculate_data_parallel_groups() - self.train_progress_bar.reset(num_training_batches) self.train_progress_bar.initial = 0 self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") diff --git a/nemo/lightning/pytorch/plugins/data_sampler.py b/nemo/lightning/pytorch/plugins/data_sampler.py index c6ff3b7ccaaa..378375e3bc0c 100644 --- a/nemo/lightning/pytorch/plugins/data_sampler.py +++ b/nemo/lightning/pytorch/plugins/data_sampler.py @@ -23,14 +23,15 @@ def __init__( global_batch_size: int = 8, rampup_batch_size: Optional[List[int]] = None, dataloader_type: Literal["single", "cyclic"] = "single", + init_consumed_samples: int = 0, ): self.seq_len = seq_len self.micro_batch_size = micro_batch_size self.global_batch_size = global_batch_size self.rampup_batch_size = rampup_batch_size self.dataloader_type = dataloader_type - self.init_consumed_samples: int = 0 - self.prev_consumed_samples = 0 + self.init_consumed_samples = init_consumed_samples + self.prev_consumed_samples = self.init_consumed_samples self.if_first_step = 0 self.prev_global_batch_size = None @@ -47,7 +48,7 @@ def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0 micro_batch_size=self.micro_batch_size, global_batch_size=self.global_batch_size, rampup_batch_size=self.rampup_batch_size, - consumed_samples=consumed_samples, + consumed_samples=self.init_consumed_samples, dataloader_type=self.dataloader_type, ) diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index 6095ee04a02a..99e7245d60dd 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -352,6 +352,11 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP batch_size=1, ) + self.lightning_module.log( + 'step', + self.trainer.global_step, + ) + if self.log_memory_usage: max_memory_reserved = torch.cuda.max_memory_reserved() memory_allocated = torch.cuda.memory_allocated() From 32286ed430a8bb6af97688f3b68be5fd2af1101e Mon Sep 17 00:00:00 2001 From: Sara Rabhi Date: Thu, 4 Jul 2024 10:04:45 -0400 Subject: [PATCH 02/13] Fix serialization of AutoResume (#9616) * fix serialization of autoresume * update undefined variables --- nemo/lightning/resume.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/nemo/lightning/resume.py 
b/nemo/lightning/resume.py index fc4f7ec9fab8..f762d345ed3b 100644 --- a/nemo/lightning/resume.py +++ b/nemo/lightning/resume.py @@ -4,8 +4,10 @@ import lightning_fabric as fl import pytorch_lightning as pl +from nemo.lightning import io from nemo.utils import logging from nemo.utils.app_state import AppState +from nemo.utils.model_utils import uninject_model_parallel_rank class Resume: @@ -22,7 +24,7 @@ def setup(self, model, trainer: Union[pl.Trainer, fl.Fabric]): trainer.checkpoint_callback.last_model_path = ckpt_path -class AutoResume(Resume): +class AutoResume(Resume, io.IOMixin): """Class that handles the logic for setting checkpoint paths and restoring from checkpoints in NeMo. """ @@ -101,15 +103,15 @@ def nemo_path(self, model=None) -> Optional[Path]: warn = f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir :{checkpoint_dir}. " if checkpoint is None: warn += "Training from scratch." - elif checkpoint == resume_from_checkpoint: - warn += f"Training from {resume_from_checkpoint}." + elif checkpoint == self.path: + warn += f"Training from {self.path}." logging.warning(warn) else: raise NotFoundError( f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Cannot resume." ) elif len(end_checkpoints) > 0: - if resume_past_end: + if self.resume_past_end: if len(end_checkpoints) > 1: if 'mp_rank' in str(end_checkpoints[0]): checkpoint = end_checkpoints[0] From bf8273790170cfd4147d5e02bce0c5135e7eefee Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Thu, 4 Jul 2024 11:51:42 -0700 Subject: [PATCH 03/13] Chat template support for megatron_gpt_eval.py (#9354) * Bump PTL version (#9557) Signed-off-by: Abhishree Signed-off-by: Alexandros Koumparoulis * [Resiliency] Straggler detection (#9473) * Initial straggler det impl Signed-off-by: Jacek Bieniusiewicz * Apply isort and black reformatting Signed-off-by: jbieniusiewi * Fixed CI code checks Signed-off-by: Jacek Bieniusiewicz * Apply isort and black reformatting Signed-off-by: jbieniusiewi * Removed unused import Signed-off-by: Jacek Bieniusiewicz * remove submodule Signed-off-by: Maanu Grover * Updated documentation; Updated callback params; Cosmetic changes Signed-off-by: Jacek Bieniusiewicz * Apply isort and black reformatting Signed-off-by: jbieniusiewi * Fixed straggler det config; Added basic test Signed-off-by: Jacek Bieniusiewicz * Apply isort and black reformatting Signed-off-by: jbieniusiewi * Fixes in test_straggler_det.py Signed-off-by: Jacek Bieniusiewicz * Updated straggler callback API Signed-off-by: Jacek Bieniusiewicz * Apply isort and black reformatting Signed-off-by: jbieniusiewi * stop_if_detected=False by default Signed-off-by: Jacek Bieniusiewicz --------- Signed-off-by: Jacek Bieniusiewicz Signed-off-by: jbieniusiewi Signed-off-by: Maanu Grover Co-authored-by: jbieniusiewi Co-authored-by: Maanu Grover Signed-off-by: Alexandros Koumparoulis * move model loading to separate function; call toContainer once; pad using closed formula Signed-off-by: Alexandros Koumparoulis * read prompts from file Signed-off-by: Alexandros Koumparoulis * If input prompt contains dict, apply model.tokenizer.chat_template Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa Signed-off-by: Alexandros Koumparoulis * apply @Gal Leibovich's patch Taken from: https://github.com/NVIDIA/NeMo/commit/17572905344db4692583e72799d55801a8860f35 
Signed-off-by: Alexandros Koumparoulis * rename prompts_file to prompts_jsonl Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa Signed-off-by: Alexandros Koumparoulis * add chat_template param Signed-off-by: Alexandros Koumparoulis * Add ChatTemplateMixin to SentencePieceTokenizer Signed-off-by: Alexandros Koumparoulis * add chat-template to text-gen-strat Signed-off-by: Alexandros Koumparoulis * move load prompts to separate file Signed-off-by: Alexandros Koumparoulis * remove chat-template from text-gen-utils Signed-off-by: Alexandros Koumparoulis * make chat-template more generic Signed-off-by: Alexandros Koumparoulis * add assert message Signed-off-by: Alexandros Koumparoulis * small refactor for chat_template_mixin Signed-off-by: Alexandros Koumparoulis * undo ckpt conv changes Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa Signed-off-by: Alexandros Koumparoulis * move rounding to function Signed-off-by: Alexandros Koumparoulis * fix Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa --------- Signed-off-by: Abhishree Signed-off-by: Alexandros Koumparoulis Signed-off-by: Jacek Bieniusiewicz Signed-off-by: jbieniusiewi Signed-off-by: Maanu Grover Signed-off-by: akoumpa Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Co-authored-by: jbieniusiewi <152396322+jbieniusiewi@users.noreply.github.com> Co-authored-by: jbieniusiewi Co-authored-by: Maanu Grover Co-authored-by: akoumpa --- docs/source/core/exp_manager.rst | 42 ++++ .../conf/megatron_gpt_inference.yaml | 1 + .../language_modeling/megatron_gpt_eval.py | 77 +++++--- .../common/tokenizers/chat_template_mixin.py | 179 ++++++++++++++++++ .../tokenizers/sentencepiece_tokenizer.py | 18 +- .../language_modeling/megatron_base_model.py | 1 + .../common/text_generation_strategy.py | 9 +- .../modules/common/text_generation_utils.py | 45 ++--- .../nlp/modules/common/tokenizer_utils.py | 17 +- 9 files changed, 334 insertions(+), 55 deletions(-) create mode 100644 nemo/collections/common/tokenizers/chat_template_mixin.py diff --git a/docs/source/core/exp_manager.rst b/docs/source/core/exp_manager.rst index e813b8f16ac4..ce5f7a9cb087 100644 --- a/docs/source/core/exp_manager.rst +++ b/docs/source/core/exp_manager.rst @@ -248,6 +248,48 @@ You might also want to adjust the callback parameters: Straggler detection might involve inter-rank synchronization, and should be invoked with reasonable frequency (e.g. every few minutes). +.. _exp_manager_straggler_det_support-label: + +.. note:: + Stragglers Detection feature is included in the optional NeMo resiliency package. + +Distributed training can be affected by stragglers, which are slow workers that slow down the overall training process. +NeMo provides a straggler detection feature that can identify slower GPUs. + +This feature is implemented in the ``StragglerDetectionCallback``, which is disabled by default. + +The callback computes normalized GPU performance scores, which are scalar values ranging from 0.0 (worst) to 1.0 (best). +A performance score can be interpreted as the ratio of current performance to reference performance. + +There are two types of performance scores provided by the callback: + - Relative GPU performance score: The best-performing GPU in the workload is used as a reference. 
+ - Individual GPU performance score: The best historical performance of the GPU is used as a reference. + +Examples: + - If the relative performance score is 0.5, it means that a GPU is twice slower than the fastest GPU. + - If the individual performance score is 0.5, it means that a GPU is twice slower than its best observed performance. + +If a GPU performance score drops below the specified threshold, it is identified as a straggler. + +To enable straggler detection, add ``create_straggler_detection_callback: True`` under exp_manager in the config YAML file. +You might also want to adjust the callback parameters: + +.. code-block:: yaml + + exp_manager: + ... + create_straggler_detection_callback: True + straggler_detection_callback_params: + report_time_interval: 300 # Interval [seconds] of the straggler check + calc_relative_gpu_perf: True # Calculate relative GPU performance + calc_individual_gpu_perf: True # Calculate individual GPU performance + num_gpu_perf_scores_to_log: 5 # Log 5 best and 5 worst GPU performance scores, even if no stragglers are detected + gpu_relative_perf_threshold: 0.7 # Threshold for relative GPU performance scores + gpu_individual_perf_threshold: 0.7 # Threshold for individual GPU performance scores + stop_if_detected: True # Terminate the workload if stragglers are detected + +Straggler detection might involve inter-rank synchronization, and should be invoked with reasonable frequency (e.g. every few minutes). + Fault Tolerance --------------- diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml index 2570251bcdee..ce8311daf95c 100644 --- a/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml @@ -31,6 +31,7 @@ hparams_file: null # model configuration file, only used for PTL checkpoint load prompts: # prompts for GPT inference - "Q: How are you?" - "Q: How big is the universe?" +prompts_jsonl: null server: False # whether launch the API server port: 5555 # the port number for the inference server web_server: False # whether launch the web inference server diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py index f3413a5fa92e..362a2ae3e298 100644 --- a/examples/nlp/language_modeling/megatron_gpt_eval.py +++ b/examples/nlp/language_modeling/megatron_gpt_eval.py @@ -14,6 +14,7 @@ import asyncio import datetime +import json import os import threading from functools import partial @@ -166,20 +167,7 @@ def remove_padded_prompts(response, nb_paddings): return result -@hydra_runner(config_path="conf", config_name="megatron_gpt_inference") -def main(cfg) -> None: - - callbacks = [] - # enable_progress_bar is True by default. 
If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks - if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar: - callbacks.append(CustomProgressBar()) - # trainer required for restoring model parallel models - trainer = Trainer( - strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)), - **cfg.trainer, - callbacks=callbacks, - ) - +def load_model_from_config(trainer, cfg): if cfg.gpt_model_file is not None: if ( cfg.tensor_model_parallel_size < 0 @@ -285,7 +273,50 @@ def main(cfg) -> None: model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer) else: raise ValueError("need at least a nemo file or checkpoint dir") + return model + + +def load_prompts(cfg): + prompts = [] + if (cfg_prompts := getattr(cfg, 'prompts', None)) is not None: + prompts = OmegaConf.to_container(cfg_prompts) + if (prompts_jsonl := getattr(cfg, 'prompts_jsonl', None)) is not None: + with open(prompts_jsonl, 'rt') as fp: + try: + prompts += list(map(json.loads, map(str.rstrip, fp))) + except: + prompts += list(map(str.rstrip, fp)) + # Make sure non-empty input + assert len(prompts) > 0, "Expected at least one prompt" + # Make sure all have the same type + assert all( + map(lambda x: isinstance(x, type(prompts[0])), prompts) + ), "Expected all prompts to have the same datatype" + return prompts + + +def round_to_mult(n, mult=8): + """ + Rounds number n to be a multiple of mult + """ + return ((n + mult - 1) // mult) * mult + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_inference") +def main(cfg) -> None: + + callbacks = [] + # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks + if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar: + callbacks.append(CustomProgressBar()) + # trainer required for restoring model parallel models + trainer = Trainer( + strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)), + **cfg.trainer, + callbacks=callbacks, + ) + model = load_model_from_config(trainer, cfg) model.freeze() # Have to turn off activations_checkpoint_method for inference @@ -311,17 +342,17 @@ def main(cfg) -> None: "end_strings": cfg.inference.end_strings, } + prompts = load_prompts(cfg) + fp8_enabled = hasattr(model.cfg, "fp8") and (model.cfg.fp8 == True) - if fp8_enabled: - nb_paddings = 0 - while len(cfg.prompts) % 8 != 0: - cfg.prompts.append("") - nb_paddings += 1 + if fp8_enabled and len(prompts) > 0: + padded_len = round_to_mult(len(prompts), 8) + nb_paddings = padded_len - len(prompts) + if nb_paddings > 0: + prompts += [''] * nb_paddings # pad prompts with empty strings up to a multiple of 8 for fp8 # First method of running text generation, call model.generate method - response = model.generate( - inputs=OmegaConf.to_container(cfg.prompts), length_params=length_params, sampling_params=sampling_params - ) + response = model.generate(inputs=prompts, length_params=length_params, sampling_params=sampling_params) if fp8_enabled: response = remove_padded_prompts(response, nb_paddings) @@ -331,7 +362,7 @@ def main(cfg) -> None: # Second method of running text generation, call trainer.predict [recommended] bs = 8 if fp8_enabled else 2 - ds = RequestDataSet(OmegaConf.to_container(cfg.prompts)) + ds = RequestDataSet(prompts) request_dl = DataLoader(dataset=ds, batch_size=bs) config = OmegaConf.to_container(cfg.inference) model.set_inference_config(config) diff --git 
a/nemo/collections/common/tokenizers/chat_template_mixin.py b/nemo/collections/common/tokenizers/chat_template_mixin.py new file mode 100644 index 000000000000..83a5e537519c --- /dev/null +++ b/nemo/collections/common/tokenizers/chat_template_mixin.py @@ -0,0 +1,179 @@ +import re +from functools import cache + +TEMPLATE_VAR_VALIDATION_PAT = re.compile(r'^\{_[A-Za-z][A-Za-z0-9_]*_\}$') +TEMPLATE_VAR_SEARCH_PAT = re.compile('({_[^}]+_})') + + +class ChatTemplateMixin: + def apply_chat_template(self, messages): + assert self.chat_template is not None + return tokenize_with_chat_template(self, messages, self.chat_template) + + @property + def has_chat_template(self): + return self.chat_template is not None + + +@cache +def is_template_var(s): + # It should start with {_ and end with _}, be non-empty and not contain { or } within. + return re.match(TEMPLATE_VAR_VALIDATION_PAT, s) + + +def extract_template_parts(template, skip_empty=True): + for part in re.split(TEMPLATE_VAR_SEARCH_PAT, template): + # skip empty parts + if skip_empty and part == '': + continue + yield part + + +def strip_template_wrap(s): + if not is_template_var(s): + return s + # Strip the "{_" prefix and the "_}" suffix + return s[2:-2] + + +def render_chat_turn(message, template): + """Renders a chat turn based on template + + Args: + message (Dict) + e.g. {'role': ['user'], 'content': ['What is your favourite fruit?']}, + template (Str): + "[INST] {_content_} [/INST]", + + Returns: + (str, token_id/None): the template formatted message + e.g. + "[INST] What is your favourite fruit? [/INST]", None + """ + ans = [] + for i, template_part in enumerate(extract_template_parts(template)): + if is_template_var(template_part): + template_part = strip_template_wrap(template_part) + if template_part == 'content': + ans.append(message['content']) + else: + # assert i == len(template_parts) - 1, "unsupported" + yield ''.join(ans), template_part + ans = [] + else: + # Otherwise it is literal string + ans.append(template_part) + yield ''.join(ans), None + + +def encode_string_with_special_token(tokenizer, inputs, special_token): + """ + Tokenizes a string or a list of string into their corresponding token_ids + and appends (at the end) a special_token if present. + + Args: + tokenizer: (SPM) + inputs: (Str, List[Str]) + e.g. "Alex" or ["Alex", "nvidia"] + special_token: (Str): + e.g. "eos" + + Returns: + (list[int]): list of token_ids + e.g. + input="Alex", special_token="eos" + Alex->[3413] + eos->[2] + + Will return the following: + [3413, 2] + """ + ans = [] + if isinstance(inputs, str) and inputs != '': + ans += tokenizer.text_to_ids(inputs) + elif isinstance(inputs, list) and len(inputs) > 0: + ans += tokenizer.text_to_ids(''.join(inputs)) + if special_token is not None: + # TODO(@akoumparouli): limit which attributes user-defined string can query. + assert hasattr(tokenizer, special_token), f"Special_token {special_token} is not part of tokenizer" + ans += [getattr(tokenizer, special_token)] + return ans + + +def tokenize_with_chat_template(tokenizer, messages, template): + assert is_chat_input(messages), "Expected input to be chat-template" + assert len(messages) > 0, "Expected non-empty messages" + assert 'roles' in template, "Expected template to have key `roles`." 
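+ # Flow: render each message with the template for its role, buffering the rendered text;
+ # whenever the role template references a special token (e.g. eos), the buffered text is
+ # tokenized and that special token's id appended, then rendering continues with an empty buffer.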
+ ans = [] + encode = lambda x, y: encode_string_with_special_token(tokenizer, x, y) + if 'prefix' in template: + for part, special_token in render_chat_turn('', template['prefix']): + ans += encode(part, special_token) + buffer = [] + for message in messages: + assert message['role'] in template['roles'], (message['role'], template['roles']) + msg_template = template['roles'][message['role']] + for templated_messages, special_token in render_chat_turn(message, msg_template): + buffer += [templated_messages] + if special_token is not None: + ans += encode(buffer, special_token) + buffer = [] + # handle tail + ans += encode(buffer, None) + assert len(ans) > 0, 'Expected non-empty output' + return ans + + +def extract_turns(messages, axis): + """ + a collated messages can have multiple chat messages in each dict, + this extracts (vertically) one of them, for example: + + messages = [ + {'role': ['user', 'user'], 'content': ['What is your favourite condiment?', 'What is your favourite fruit?']}, + {'role': ['assistant', 'assistant'], 'content': ["Well, I'm quite partial to a ", "good squeeze of fresh lemon"]}, + {'role': ['user', 'user'], 'content': ['Do you have mayonnaise recipes?', 'Do you have tomato salad recipes?']} + ] + ans = extract_turns(messages, axis=1) + + ans = [ + {'role': ['user'], 'content': ['What is your favourite fruit?']}, + {'role': ['assistant'], 'content': ["good squeeze of fresh lemon"]}, + {'role': ['user'], 'content': ['Do you have tomato salad recipes?']} + ] + """ + ans = [] + for turn in messages: + ans.append({k: v[axis] for k, v in turn.items()}) + return ans + + +def explode_chat_template_input(messages): + """ + Example input + [ + {'role': ['user', 'user'], 'content': ['What is your favourite condiment?', 'What is your favourite fruit?']}, + {'role': ['assistant', 'assistant'], 'content': ["Well, I'm quite partial to a ", "good squeeze of fresh lemon"]}, + {'role': ['user', 'user'], 'content': ['Do you have mayonnaise recipes?', 'Do you have tomato salad recipes?']} + ] + + Notice the 2D axis system of the messages variable, one for the list and one for each item in the list (i.e. + the 'content' contains multiple messages). + """ + assert isinstance(messages, list), "Expected messages to be a list" + assert len(messages) > 0, "Expected non empty messages" + assert all(map(lambda x: isinstance(x, dict), messages)), "Expected messages to contain dicts" + assert all( + map(lambda x: 'role' in x and 'content' in x, messages) + ), "Expected messages each dict to contain 'role' and 'content' fields" + n = len(messages[0]['role']) + assert all( + map(lambda x: len(x['role']) == n, messages) + ), "Expected all batch messages to contain equal number of roles in all turns" + for i in range(n): + yield extract_turns(messages, axis=i) + + +def is_chat_input(messages): + # TOOD(@akoumparouli): improve validation. 
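+ # Chat-style input is a non-empty list of message dicts ({'role': ..., 'content': ...});
+ # a plain list of prompt strings takes the regular (non-chat-template) tokenization path.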
+ return isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], dict) diff --git a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py index 4a47f0e49b1e..00893b6f379f 100644 --- a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py +++ b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py @@ -20,13 +20,14 @@ import torch from nemo.collections.common.parts.utils import if_exist +from nemo.collections.common.tokenizers.chat_template_mixin import ChatTemplateMixin from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.utils import logging __all__ = ['SentencePieceTokenizer', 'create_spt_model'] -class SentencePieceTokenizer(TokenizerSpec): +class SentencePieceTokenizer(TokenizerSpec, ChatTemplateMixin): """ Sentencepiecetokenizer https://github.com/google/sentencepiece. @@ -38,8 +39,13 @@ class SentencePieceTokenizer(TokenizerSpec): """ def __init__( - self, model_path: str, special_tokens: Optional[Union[Dict[str, str], List[str]]] = None, legacy: bool = False + self, + model_path: str, + special_tokens: Optional[Union[Dict[str, str], List[str]]] = None, + legacy: bool = False, + chat_template: Optional[Dict] = None, ): + self.chat_template = chat_template if not model_path or not os.path.exists(model_path): raise ValueError(f"model_path: {model_path} is invalid") self.tokenizer = sentencepiece.SentencePieceProcessor() @@ -89,6 +95,14 @@ def text_to_tokens(self, text): return self.tokenizer.encode_as_pieces(text) def text_to_ids(self, text, sample_alpha=None): + if isinstance(text, str): + return self._text_to_ids(text, sample_alpha) + elif isinstance(text, list): + return self.apply_chat_template(text) + else: + raise ValueError(f"Expected either str or list input, but got {type(text)}") + + def _text_to_ids(self, text, sample_alpha=None): if self.legacy: ids = [] idx = 0 diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index ae659e757496..f7b53a95c19a 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -431,6 +431,7 @@ def _build_tokenizer(self): special_tokens=self.cfg.tokenizer.get('special_tokens', None), trust_remote_code=self.cfg.tokenizer.get('trust_remote_code', False), legacy=legacy, + chat_template=getattr(self._cfg.tokenizer, "chat_template", None), ) if self._cfg.tokenizer.get('additional_special_tokens', None) is not None: diff --git a/nemo/collections/nlp/modules/common/text_generation_strategy.py b/nemo/collections/nlp/modules/common/text_generation_strategy.py index e8e2859e439f..238c01695f42 100644 --- a/nemo/collections/nlp/modules/common/text_generation_strategy.py +++ b/nemo/collections/nlp/modules/common/text_generation_strategy.py @@ -21,6 +21,8 @@ import torch from transformers import CLIPImageProcessor + +from nemo.collections.common.tokenizers.chat_template_mixin import explode_chat_template_input, is_chat_input from nemo.collections.nlp.modules.common.lm_utils import pad_batch from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids @@ -94,7 +96,12 @@ def tokenize_batch(self, sentences, max_len, add_BOS): Tuple[torch.Tensor], the tokenized and padded torch tensor and the token context length 
tensor. """ tokenizer = self.model.tokenizer - if add_BOS: + if is_chat_input(sentences): + assert getattr( + tokenizer, 'has_chat_template', False + ), "Got chat-template input but tokenizer does not support chat template formating." + context_tokens = list(map(tokenizer.text_to_ids, explode_chat_template_input(sentences))) + elif add_BOS: context_tokens = [[tokenizer.bos_id] + tokenizer.text_to_ids(s) for s in sentences] elif hasattr(tokenizer.tokenizer, "get_prefix_tokens"): # chatglm: add tokenizer.gmask_id, tokenizer.sop_id diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index 498d9e9a09da..cd02f5409679 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -122,31 +122,26 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para compute_prob_response = get_computeprob_response(tokenizer, response, inputs) return compute_prob_response - if isinstance(inputs, (list, tuple)): - if isinstance(inputs[0], (str, torch.Tensor)): - output = generate( - model, - inputs=inputs, - tokens_to_generate=length_params['max_length'], - all_probs=sampling_params['all_probs'], - compute_logprob=sampling_params['compute_logprob'], - temperature=sampling_params['temperature'], - add_BOS=sampling_params['add_BOS'], - top_k=sampling_params['top_k'], - top_p=sampling_params['top_p'], - greedy=sampling_params['use_greedy'], - repetition_penalty=sampling_params['repetition_penalty'], - end_strings=sampling_params['end_strings'], - min_tokens_to_generate=length_params['min_length'], - **strategy_args, - ) - return output - elif isinstance(inputs[0], dict): - raise NotImplementedError("json object not implemented") - else: - raise NotImplementedError("unknown type is not implemented") - else: - raise NotImplementedError("unknown type is not implemented") + if not isinstance(inputs, (list, tuple)): + raise NotImplementedError(f"unknown type {type(inputs)} is not implemented") + + output = generate( + model, + inputs=inputs, + tokens_to_generate=length_params['max_length'], + all_probs=sampling_params['all_probs'], + compute_logprob=sampling_params['compute_logprob'], + temperature=sampling_params['temperature'], + add_BOS=sampling_params['add_BOS'], + top_k=sampling_params['top_k'], + top_p=sampling_params['top_p'], + greedy=sampling_params['use_greedy'], + repetition_penalty=sampling_params['repetition_penalty'], + end_strings=sampling_params['end_strings'], + min_tokens_to_generate=length_params['min_length'], + **strategy_args, + ) + return output def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_params, inference_config, **strategy_args): diff --git a/nemo/collections/nlp/modules/common/tokenizer_utils.py b/nemo/collections/nlp/modules/common/tokenizer_utils.py index 67c94ae5d608..d3ee69f75b25 100644 --- a/nemo/collections/nlp/modules/common/tokenizer_utils.py +++ b/nemo/collections/nlp/modules/common/tokenizer_utils.py @@ -78,6 +78,7 @@ def get_tokenizer( special_tokens: Optional[Dict[str, str]] = None, use_fast: Optional[bool] = False, bpe_dropout: Optional[float] = 0.0, + chat_template: Optional[Dict] = None, ): """ Args: @@ -91,7 +92,7 @@ def get_tokenizer( use_fast: (only for HuggingFace AutoTokenizer) set to True to use fast HuggingFace tokenizer bpe_dropout: (experimental) BPE dropout tries to corrupt the standard segmentation procedure of BPE to help - model better learn word 
compositionality and become robust to segmentation errors. + model better learn word compositionality and become robust to segmentation errors. It has emperically been shown to improve inference time BLEU scores. """ if special_tokens is None: @@ -116,7 +117,10 @@ def get_tokenizer( if tokenizer_name == 'sentencepiece': logging.info("tokenizer_model: " + str(tokenizer_model)) return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer( - model_path=tokenizer_model, special_tokens=special_tokens, legacy=True + model_path=tokenizer_model, + special_tokens=special_tokens, + legacy=True, + chat_template=chat_template, ) elif tokenizer_name == 'word': return WordTokenizer(vocab_file=vocab_file, **special_tokens_dict) @@ -151,6 +155,7 @@ def get_nmt_tokenizer( legacy: Optional[bool] = False, delimiter: Optional[str] = None, trust_remote_code: Optional[bool] = False, + chat_template: Optional[Dict] = None, ): """ Args: @@ -187,7 +192,9 @@ def get_nmt_tokenizer( elif library == 'sentencepiece': logging.info(f'Getting SentencePiece with model: {tokenizer_model}') return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer( - model_path=tokenizer_model, legacy=legacy + model_path=tokenizer_model, + legacy=legacy, + chat_template=chat_template, ) elif library == 'byte-level': logging.info(f'Using byte-level tokenization') @@ -209,7 +216,9 @@ def get_nmt_tokenizer( logging.info( f'Getting Megatron tokenizer for pretrained model name: {model_name}, custom vocab file: {vocab_file}, and merges file: {merges_file}' ) - return get_tokenizer(tokenizer_name=model_name, vocab_file=vocab_file, merges_file=merges_file) + return get_tokenizer( + tokenizer_name=model_name, vocab_file=vocab_file, merges_file=merges_file, chat_template=chat_template + ) elif library == 'tabular': return TabularTokenizer(vocab_file, delimiter=delimiter) else: From d8624991996295d6ecfe31eff6cc55c30b632585 Mon Sep 17 00:00:00 2001 From: Aditya Vavre Date: Thu, 4 Jul 2024 14:10:51 -0700 Subject: [PATCH 04/13] Jsonl support (#9611) * Adding support to preprocess .jsonl and .jsonl.gz files in input directory Signed-off-by: adityavavre * Adding support to preprocess .jsonl and .jsonl.gz files in input directory Signed-off-by: adityavavre * Apply isort and black reformatting Signed-off-by: adityavavre --------- Signed-off-by: adityavavre Signed-off-by: adityavavre Co-authored-by: adityavavre --- .../preprocess_data_for_megatron.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/scripts/nlp_language_modeling/preprocess_data_for_megatron.py b/scripts/nlp_language_modeling/preprocess_data_for_megatron.py index 945b9e7b68a2..e1f89182279b 100644 --- a/scripts/nlp_language_modeling/preprocess_data_for_megatron.py +++ b/scripts/nlp_language_modeling/preprocess_data_for_megatron.py @@ -104,6 +104,7 @@ except ImportError: nltk_available = False + # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars): @@ -221,10 +222,16 @@ def get_args(): help='What tokenizer library to use.', ) group.add_argument( - '--tokenizer-type', type=str, default=None, help='What type of tokenizer to use.', + '--tokenizer-type', + type=str, + default=None, + help='What type of tokenizer to use.', ) group.add_argument( - '--tokenizer-model', type=str, default=None, help='Path to tokenizer model.', + '--tokenizer-model', + type=str, + default=None, + help='Path to tokenizer 
model.', ) group.add_argument('--vocab-file', type=str, default=None, help='Path to the vocab file') group.add_argument('--files-filter', type=str, default='**/*.json*', help='files filter str') @@ -248,7 +255,7 @@ def get_args(): group.add_argument( '--preproc-folder', action='store_true', - help='If set, will preprocess all .json or .json.gz files into a single .bin and .idx file. Folder path provided via the --input arg', + help='If set, will preprocess all .json or .jsonl or json.gz or .jsonl.gz files into a single .bin and .idx file. Folder path provided via the --input arg', ) group.add_argument('--apply-ftfy', action='store_true', help='If set, will apply ftfy to the input text') args = parser.parse_args() @@ -272,14 +279,18 @@ def main(): args = get_args() startup_start = time.time() if args.preproc_folder: - print('Searching folder for .json or .json.gz files...') + print('Searching folder for .json or .jsonl or json.gz or .jsonl.gz files...') assert os.path.exists(args.input), f'Folder does not exist: {args.input}' json_files = (str(f) for f in pathlib.Path(args.input).glob(args.files_filter)) - json_files = [f for f in json_files if f.endswith('.json') or f.endswith('.json.gz')] + json_files = [ + f + for f in json_files + if f.endswith('.json') or f.endswith('.jsonl') or f.endswith('.json.gz') or f.endswith('.jsonl.gz') + ] if len(json_files) == 0: - raise FileNotFoundError('No .json or .json.gz files found in folder.') + raise FileNotFoundError('No .json or .jsonl or json.gz or .jsonl.gz files found in folder.') else: - print(f'Found {len(json_files)} .json or .json.gz files.') + print(f'Found {len(json_files)} .json or .jsonl or json.gz or .jsonl.gz files.') else: assert os.path.exists(args.input), f'File does not exist: {args.input}' json_files = [args.input] From f89bca0ed5186597a7bc58944a8deb9efdbcc520 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Thu, 4 Jul 2024 21:30:16 -0400 Subject: [PATCH 05/13] [NeMo-UX] Add PEFT (#9490) * initial commit for PEFT in nemo2 * Apply isort and black reformatting Signed-off-by: cuichenx * address comments Signed-off-by: Chen Cui * make import easier Signed-off-by: Chen Cui * Apply isort and black reformatting Signed-off-by: cuichenx * address comments Signed-off-by: Chen Cui * Update nemo/collections/llm/peft/lora.py Signed-off-by: Marc Romeyn * Some small fixes + adding more doc-strings * Apply isort and black reformatting Signed-off-by: marcromeyn * Adding ModelTransform callback * Apply isort and black reformatting Signed-off-by: marcromeyn * Fixing type-hint for model_transform * Apply isort and black reformatting Signed-off-by: marcromeyn * fix import Signed-off-by: Chen Cui * model transform for gemma llama Signed-off-by: Chen Cui * Apply isort and black reformatting Signed-off-by: cuichenx * fix model transform Signed-off-by: Chen Cui * Apply isort and black reformatting Signed-off-by: cuichenx * change lora target default to all linear modules Signed-off-by: Chen Cui * Apply isort and black reformatting Signed-off-by: cuichenx * Small fix in mixtral * Apply isort and black reformatting Signed-off-by: marcromeyn * Integrating PEFT to the public-API + some fixes * Big refactor to allow to load adapter-states * Some fixes to support adapter_path * Apply isort and black reformatting Signed-off-by: marcromeyn * Disabling ckpt reloading when adapter_path is passed * Fix CLI * Apply isort and black reformatting Signed-off-by: marcromeyn * Remove commented-out code * Remove commented-out code * Remove un-used import * Fix callback imports 
* Apply isort and black reformatting Signed-off-by: marcromeyn * Fixing llm.pretrain * Some small fixes * Apply isort and black reformatting Signed-off-by: marcromeyn * Fix missing import + type-hint in finetune * Adding PreemptionCallback + some more tests * Apply isort and black reformatting Signed-off-by: marcromeyn * Clean up imports & clean up llm.api * Apply isort and black reformatting Signed-off-by: marcromeyn * Trying to fix failing tests * Remove __init__.py 2 * Apply isort and black reformatting Signed-off-by: marcromeyn * Fix failing test * Trying to fix last failing test --------- Signed-off-by: cuichenx Signed-off-by: Chen Cui Signed-off-by: Marc Romeyn Signed-off-by: marcromeyn Co-authored-by: cuichenx Co-authored-by: Marc Romeyn Co-authored-by: marcromeyn --- nemo/collections/llm/__init__.py | 6 +- nemo/collections/llm/api.py | 285 ++++++++++++++---- nemo/collections/llm/gpt/model/base.py | 3 + nemo/collections/llm/gpt/model/gemma.py | 4 +- nemo/collections/llm/gpt/model/llama.py | 4 +- nemo/collections/llm/gpt/model/mistral.py | 6 +- nemo/collections/llm/gpt/model/mixtral.py | 9 +- nemo/collections/llm/peft/__init__.py | 4 + nemo/collections/llm/peft/api.py | 11 + nemo/collections/llm/peft/lora.py | 123 ++++++++ .../megatron/adapters/parallel_adapters.py | 11 + nemo/lightning/__init__.py | 2 +- nemo/lightning/_strategy_lib.py | 41 ++- nemo/lightning/fabric/strategies.py | 43 +-- nemo/lightning/io/pl.py | 2 +- nemo/lightning/megatron_parallel.py | 3 +- nemo/lightning/nemo_logger.py | 6 +- nemo/lightning/pytorch/callbacks/__init__.py | 12 +- ...odel_checkpoint.py => model_checkpoint.py} | 7 +- .../pytorch/callbacks/model_transform.py | 98 ++++++ nemo/lightning/pytorch/callbacks/nsys.py | 31 +- nemo/lightning/pytorch/callbacks/peft.py | 261 ++++++++++++++++ .../lightning/pytorch/callbacks/preemption.py | 115 +++++++ nemo/lightning/pytorch/optim/base.py | 3 +- nemo/lightning/pytorch/strategies.py | 62 ++-- nemo/lightning/resume.py | 30 +- setup.py | 5 + tests/lightning/pytorch/callbacks/__init__.py | 0 .../pytorch/callbacks/test_model_transform.py | 48 +++ .../lightning/pytorch/callbacks/test_nsys.py | 195 ++++++++++++ .../lightning/pytorch/callbacks/test_peft.py | 68 +++++ .../pytorch/callbacks/test_preemption.py | 114 +++++++ tests/lightning/test_megatron_parallel.py | 8 +- 33 files changed, 1434 insertions(+), 186 deletions(-) create mode 100644 nemo/collections/llm/peft/__init__.py create mode 100644 nemo/collections/llm/peft/api.py create mode 100644 nemo/collections/llm/peft/lora.py rename nemo/lightning/pytorch/callbacks/{megatron_model_checkpoint.py => model_checkpoint.py} (98%) create mode 100644 nemo/lightning/pytorch/callbacks/model_transform.py create mode 100644 nemo/lightning/pytorch/callbacks/peft.py create mode 100644 nemo/lightning/pytorch/callbacks/preemption.py create mode 100644 tests/lightning/pytorch/callbacks/__init__.py create mode 100644 tests/lightning/pytorch/callbacks/test_model_transform.py create mode 100644 tests/lightning/pytorch/callbacks/test_nsys.py create mode 100644 tests/lightning/pytorch/callbacks/test_peft.py create mode 100644 tests/lightning/pytorch/callbacks/test_preemption.py diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 50c5c53f6533..83c0a3af48c0 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -4,8 +4,8 @@ except ImportError: pass -from nemo.collections.llm import tokenizer -from nemo.collections.llm.api import export_ckpt, import_ckpt, pretrain, train, 
validate +from nemo.collections.llm import peft, tokenizer +from nemo.collections.llm.api import export_ckpt, finetune, import_ckpt, pretrain, train, validate from nemo.collections.llm.gpt.data import ( DollyDataModule, FineTuningDataModule, @@ -98,6 +98,7 @@ "export_ckpt", "pretrain", "validate", + "finetune", "tokenizer", "mock", "squad", @@ -118,4 +119,5 @@ "gemma_7b", "code_gemma_2b", "code_gemma_7b", + "peft", ] diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index 081b0f01b4c7..5c9703497597 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -1,11 +1,17 @@ +from copy import deepcopy from pathlib import Path -from typing import Callable, Optional +from typing import Any, Callable, Optional, Union import pytorch_lightning as pl from typing_extensions import Annotated from nemo.collections.llm.utils import Config, task -from nemo.lightning import AutoResume, MegatronStrategy, NeMoLogger, OptimizerModule, Trainer, io, teardown +from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io +from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform +from nemo.utils import logging + + +TokenizerType = Any @task(namespace="llm") @@ -16,7 +22,8 @@ def train( log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None, resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None, optim: Optional[OptimizerModule] = None, - tokenizer: Optional[str] = None, + tokenizer: Optional[TokenizerType] = None, + model_transform: Optional[Union[PEFT, ModelTransform, Callable]] = None, # TODO: Fix export export: Optional[str] = None, ) -> Path: """ @@ -30,42 +37,38 @@ def train( resume (Optional[Union[AutoResume, Resume]]): Resume training from a checkpoint. optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer from the model will be used. - tokenizer (Optional[str]): Tokenizer setting to be applied. Can be 'data' or 'model'. + tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec. export (Optional[str]): Filename to save the exported checkpoint after training. + model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied. Returns ------- Path: The directory path where training artifacts are saved. - Raises - ------ - ValueError: If the trainer's strategy is not MegatronStrategy. 
- Examples -------- - >>> model = MyModel() - >>> data = MyDataModule() - >>> trainer = Trainer(strategy=MegatronStrategy()) - >>> train(model, data, trainer, tokenizer='data', source='path/to/ckpt.ckpt', export='final.ckpt') + >>> from nemo.collections import llm + >>> from nemo import lightning as nl + >>> model = llm.MistralModel() + >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) + >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed") + >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision) + >>> train(model, data, trainer, tokenizer="data") PosixPath('/path/to/log_dir') """ - _log = log or NeMoLogger() - app_state = _log.setup( - trainer, - resume_if_exists=getattr(resume, "resume_if_exists", False), - task_config=getattr(train, "__io__", None), + app_state = _setup( + model=model, + data=data, + trainer=trainer, + log=log, + resume=resume, + optim=optim, + tokenizer=tokenizer, + model_transform=model_transform, ) - if resume is not None: - resume.setup(model, trainer) - if optim: - optim.connect(model) - if tokenizer: # TODO: Improve this - _use_tokenizer(model, data, tokenizer) trainer.fit(model, data) - _log.teardown() - return app_state.exp_dir @@ -74,41 +77,152 @@ def pretrain( model: pl.LightningModule, data: pl.LightningDataModule, trainer: Trainer, - source: Optional[str] = None, - # export: Optional[str] = None + log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None, + resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None, + optim: Optional[OptimizerModule] = None, ) -> Path: - return train(model=model, data=data, trainer=trainer, tokenizer="data", source=source) + """ + Pretrains a model using the specified data and trainer, with optional logging, resuming, and optimization. + + This function is a wrapper around the `train` function, specifically configured for pretraining tasks. + Note, by default it will use the tokenizer from the model. + + Args: + model (pl.LightningModule): The model to be pretrained. + data (pl.LightningDataModule): The data module containing pretraining data. + trainer (Trainer): The trainer instance configured with a MegatronStrategy. + log (NeMoLogger): A nemologger instance. + resume (Optional[AutoResume]): Resume training from a checkpoint. + optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default + optimizer from the model will be used. + + Returns: + Path: The directory path where pretraining artifacts are saved. 
+ + Examples: + >>> from nemo.collections import llm + >>> from nemo import lightning as nl + >>> model = llm.MistralModel() + >>> data = llm.PretrainingDataModule(paths=[...], seq_length=4096, global_batch_size=16, micro_batch_size=2) + >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed") + >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision) + >>> llm.pretrain(model, data, trainer) + PosixPath('/path/to/log_dir') + """ + return train( + model=model, + data=data, + trainer=trainer, + log=log, + resume=resume, + optim=optim, + tokenizer="data", + ) @task(namespace="llm") -def validate( +def finetune( model: pl.LightningModule, data: pl.LightningDataModule, trainer: Trainer, - tokenizer: Optional[str] = None, - source: Optional[str] = None, - export: Optional[str] = None, + log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None, + resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None, + optim: Optional[OptimizerModule] = None, + peft: Optional[Union[PEFT, ModelTransform, Callable]] = None, ) -> Path: - if not isinstance(trainer.strategy, MegatronStrategy): - raise ValueError("Only MegatronStrategy is supported") + """ + Finetunes a model using the specified data and trainer, with optional logging, resuming, and PEFT. - validate_kwargs = {} - run_dir = Path(trainer.logger.log_dir) - export_dir = run_dir / "export" + Note, by default it will use the tokenizer from the model. - if tokenizer: # TODO: Improve this - _use_tokenizer(model, data, tokenizer) - if source: - _add_ckpt_path(source, model, validate_kwargs) + Args: + model (pl.LightningModule): The model to be finetuned. + data (pl.LightningDataModule): The data module containing finetuning data. + trainer (Trainer): The trainer instance configured with a MegatronStrategy. + log (NeMoLogger): A nemologger instance. + resume (Optional[AutoResume]): Resume training from a checkpoint. + optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default + optimizer from the model will be used. + peft (Optional[PEFT]): A PEFT (Parameter-Efficient Fine-Tuning) configuration to be applied. + + Returns: + Path: The directory path where finetuning artifacts are saved. 
+ + Examples: + >>> from nemo.collections import llm + >>> from nemo import lightning as nl + >>> model = llm.MistralModel() + >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) + >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed") + >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision) + >>> finetune(model, data, trainer, peft=llm.peft.LoRA()]) + PosixPath('/path/to/log_dir') + """ - trainer.validate(model, data, **validate_kwargs) - trainer.save_checkpoint(export_dir) - if export: - teardown(trainer) - del trainer, model, data - export_ckpt(export_dir, export) + return train( + model=model, + data=data, + trainer=trainer, + log=log, + resume=resume, + optim=optim, + tokenizer="model", + model_transform=peft, + ) - return run_dir + +@task(namespace="llm") +def validate( + model: pl.LightningModule, + data: pl.LightningDataModule, + trainer: Trainer, + log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None, + resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional[TokenizerType] = None, + model_transform: Optional[Union[PEFT, ModelTransform, Callable]] = None, +) -> Path: + """ + Validates a model using the specified data and trainer, with optional logging, resuming, and model transformations. + + Args: + model (pl.LightningModule): The model to be validated. + data (pl.LightningDataModule): The data module containing validation data. + trainer (Trainer): The trainer instance configured with a MegatronStrategy. + log (NeMoLogger): A nemologger instance. + resume (Optional[AutoResume]): Resume from a checkpoint for validation. + optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer + from the model will be used. + tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec. + model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied. + + Returns: + Path: The directory path where validation artifacts are saved. 
+ + Examples: + >>> from nemo.collections import llm + >>> from nemo import lightning as nl + >>> model = llm.MistralModel() + >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) + >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed") + >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision) + >>> validate(model, data, trainer, tokenizer="data") + PosixPath('/path/to/log_dir') + """ + app_state = _setup( + model=model, + data=data, + trainer=trainer, + log=log, + resume=resume, + optim=optim, + tokenizer=tokenizer, + model_transform=model_transform, + ) + + trainer.validate(model, data) + + return app_state.exp_dir @task(name="import", namespace="llm") @@ -136,28 +250,67 @@ def export_ckpt( return io.export_ckpt(path, target, output_path, overwrite, load_connector) -def _use_tokenizer(model: pl.LightningModule, data: pl.LightningDataModule, tokenizer: str) -> None: +def _use_tokenizer(model: pl.LightningModule, data: pl.LightningDataModule, tokenizer: TokenizerType) -> None: if tokenizer == "data": - model.tokenizer = data.tokenizer - if hasattr(model, "__io__"): - model.__io__.tokenizer = data.tokenizer + _set_with_io(model, "tokenizer", data.tokenizer) elif tokenizer == "model": - data.tokenizer = model.tokenizer - if hasattr(data, "__io__"): - data.__io__.tokenizer = model.tokenizer + _set_with_io(data, "tokenizer", model.tokenizer) + else: + try: + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + if isinstance(tokenizer, TokenizerSpec): + _set_with_io(model, "tokenizer", tokenizer) + _set_with_io(data, "tokenizer", tokenizer) + else: + raise ValueError(f"Expected TokenizerSpec or 'data' or 'model', got: {tokenizer}") + except ImportError: + raise ValueError("TokenizerSpec is not available") -def _add_ckpt_path(source, model, kwargs) -> None: - if io.is_distributed_ckpt(source): - kwargs["ckpt_path"] = source - else: - kwargs["ckpt_path"] = model.import_ckpt(source) +def _setup( + model: pl.LightningModule, + data: pl.LightningDataModule, + trainer: Trainer, + log: Optional[NeMoLogger], + resume: Optional[AutoResume], + optim: Optional[OptimizerModule], + tokenizer: Optional[TokenizerType], + model_transform: Optional[Union[PEFT, ModelTransform, Callable]], +) -> Any: # Return type is Any because app_state's type is not specified + _log = log or NeMoLogger() + if resume and resume.adapter_path and _log.ckpt: + logging.info("Disabling try_restore_best_ckpt restoration for adapters") + _log.ckpt.try_restore_best_ckpt = False + + app_state = _log.setup( + trainer, + resume_if_exists=getattr(resume, "resume_if_exists", False), + task_config=getattr(train, "__io__", None), + ) + if resume is not None: + resume.setup(model, trainer) + + if optim: + optim.connect(model) + if tokenizer: # TODO: Improve this + _use_tokenizer(model, data, tokenizer) + + if model_transform: + _set_with_io(model, "model_transform", model_transform) + + # Add ModelTransform callback to Trainer if needed + if getattr(model, "model_transform", None): + if not any(isinstance(cb, ModelTransform) for cb in trainer.callbacks): + if isinstance(model_transform, ModelTransform): + trainer.callbacks.append(model_transform) + else: + trainer.callbacks.append(ModelTransform()) + + return app_state -def _save_config_img(*args, **kwargs): - try: - from nemo_sdk.utils import save_config_img - save_config_img(*args, **kwargs) - except ImportError: - pass +def _set_with_io(obj, attr, value): + setattr(obj, 
attr, value) + if hasattr(obj, "__io__") and hasattr(value, "__io__"): + setattr(obj.__io__, attr, deepcopy(value.__io__)) diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index 9b7f4e4ab0c8..28a0eed52a5f 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -6,6 +6,7 @@ import torch.distributed from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.transformer_config import TransformerConfig +from torch import nn from nemo.collections.llm import fn from nemo.lightning import get_vocab_size, io @@ -117,12 +118,14 @@ def __init__( # TODO: Add transformer_layer_spec when we update mcore optim: Optional[OptimizerModule] = None, tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, ): super().__init__() self.config = config self.tokenizer = tokenizer self.optim = optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=True)) self.optim.connect(self) # This will bind the `configure_optimizers` method + self.model_transform = model_transform def configure_model(self) -> None: if not hasattr(self, "module"): diff --git a/nemo/collections/llm/gpt/model/gemma.py b/nemo/collections/llm/gpt/model/gemma.py index 348cad255876..6493bb0dfad7 100644 --- a/nemo/collections/llm/gpt/model/gemma.py +++ b/nemo/collections/llm/gpt/model/gemma.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Annotated, Callable, Optional import torch +from torch import nn from nemo.collections.llm.fn.activation import openai_gelu from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel @@ -68,8 +69,9 @@ def __init__( config: Annotated[Optional[GemmaConfig], Config[GemmaConfig]] = None, optim: Optional[OptimizerModule] = None, tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, ): - super().__init__(config or GemmaConfig(), optim=optim, tokenizer=tokenizer) + super().__init__(config or GemmaConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) @io.model_importer(GemmaModel, "hf") diff --git a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py index 94cbd99acf90..c7add828b7f4 100644 --- a/nemo/collections/llm/gpt/model/llama.py +++ b/nemo/collections/llm/gpt/model/llama.py @@ -4,6 +4,7 @@ import torch import torch.nn.functional as F +from torch import nn from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config @@ -103,8 +104,9 @@ def __init__( config: Annotated[Optional[LlamaConfig], Config[LlamaConfig]] = None, optim: Optional[OptimizerModule] = None, tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, ): - super().__init__(config or LlamaConfig(), optim=optim, tokenizer=tokenizer) + super().__init__(config or LlamaConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) @io.model_importer(LlamaModel, "hf") diff --git a/nemo/collections/llm/gpt/model/mistral.py b/nemo/collections/llm/gpt/model/mistral.py index 274a761fe5b6..d1049cfe77ce 100644 --- a/nemo/collections/llm/gpt/model/mistral.py +++ b/nemo/collections/llm/gpt/model/mistral.py @@ -5,6 +5,7 @@ import pytorch_lightning as pl import torch import torch.nn.functional as F +from torch import nn from typing_extensions import Annotated from nemo.collections.llm.gpt.model.base import GPTConfig, 
GPTModel @@ -46,8 +47,11 @@ def __init__( config: Annotated[Optional[MistralConfig7B], Config[MistralConfig7B]] = None, optim: Optional[OptimizerModule] = None, tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, ): - super().__init__(config or MistralConfig7B(), optim=optim, tokenizer=tokenizer) + super().__init__( + config or MistralConfig7B(), optim=optim, tokenizer=tokenizer, model_transform=model_transform + ) @io.model_importer(MistralModel, "hf") diff --git a/nemo/collections/llm/gpt/model/mixtral.py b/nemo/collections/llm/gpt/model/mixtral.py index 7d757479d27a..af1b73dd9109 100644 --- a/nemo/collections/llm/gpt/model/mixtral.py +++ b/nemo/collections/llm/gpt/model/mixtral.py @@ -4,15 +4,17 @@ import torch import torch.nn.functional as F +from torch import nn from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.lightning import io, teardown from nemo.lightning.pytorch.optim import OptimizerModule if TYPE_CHECKING: - from transformers import MistralConfig, MistralForCausalLM + from transformers import MixtralForCausalLM from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec @dataclass @@ -53,8 +55,11 @@ def __init__( config: Optional[MixtralConfig8x7B] = None, optim: Optional[OptimizerModule] = None, tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, ): - super().__init__(config or MixtralConfig8x7B(), optim=optim, tokenizer=tokenizer) + super().__init__( + config or MixtralConfig8x7B(), optim=optim, tokenizer=tokenizer, model_transform=model_transform + ) @io.model_importer(MixtralModel, ext="hf") diff --git a/nemo/collections/llm/peft/__init__.py b/nemo/collections/llm/peft/__init__.py new file mode 100644 index 000000000000..69855f6f9c53 --- /dev/null +++ b/nemo/collections/llm/peft/__init__.py @@ -0,0 +1,4 @@ +from nemo.collections.llm.peft.api import gpt_lora +from nemo.collections.llm.peft.lora import LoRA + +__all__ = ["LoRA", "gpt_lora"] diff --git a/nemo/collections/llm/peft/api.py b/nemo/collections/llm/peft/api.py new file mode 100644 index 000000000000..dc8fc76c752e --- /dev/null +++ b/nemo/collections/llm/peft/api.py @@ -0,0 +1,11 @@ +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.utils import factory +from nemo.lightning.pytorch.callbacks.peft import PEFT + + +@factory +def gpt_lora() -> PEFT: + return LoRA() + + +__all__ = ["gpt_lora"] diff --git a/nemo/collections/llm/peft/lora.py b/nemo/collections/llm/peft/lora.py new file mode 100644 index 000000000000..913144d1bf5f --- /dev/null +++ b/nemo/collections/llm/peft/lora.py @@ -0,0 +1,123 @@ +from dataclasses import dataclass, field +from typing import List, Literal + +from megatron.core import parallel_state +from torch import nn + +from nemo.lightning.pytorch.callbacks.peft import PEFT, AdapterWrapper +from nemo.utils import logging + + +class AdapterParallelAdd(AdapterWrapper): + """An adapter wrapper that adds the output of the adapter to the output of the wrapped module. + + This class is designed to be used with LoRA (Low-Rank Adaptation) and similar techniques + where the adapter's output is added to the main module's output. It extends the AdapterWrapper + class to provide a specific implementation of the forward method. 
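A minimal single-GPU sketch of the additive wrapping described above (toy class names, no tensor parallelism, and no handling of the (output, bias) tuples returned by Megatron linear layers); the zero-initialized up-projection mirrors the row_init_method="zero" used by the adapter below, so the wrapped layer initially matches the frozen base layer:

    import torch
    from torch import nn

    class ToyLoRAAdapter(nn.Module):
        # Bottleneck adapter: down-project, up-project, scaled by alpha / dim.
        # The up-projection starts at zero, so the adapter contributes nothing
        # until it has been trained.
        def __init__(self, in_features, out_features, dim=8, alpha=16):
            super().__init__()
            self.down = nn.Linear(in_features, dim, bias=False)
            self.up = nn.Linear(dim, out_features, bias=False)
            nn.init.zeros_(self.up.weight)
            self.scale = alpha / dim

        def forward(self, x):
            return self.up(self.down(x)) * self.scale

    class ToyParallelAdd(nn.Module):
        # Same idea as AdapterParallelAdd: return wrapped output + adapter output.
        def __init__(self, to_wrap, adapter):
            super().__init__()
            self.to_wrap, self.adapter = to_wrap, adapter

        def forward(self, x):
            return self.to_wrap(x) + self.adapter(x)

    base = nn.Linear(16, 16)
    wrapped = ToyParallelAdd(base, ToyLoRAAdapter(16, 16))
    x = torch.randn(2, 16)
    assert torch.allclose(wrapped(x), base(x))  # no behaviour change until training starts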
+ """ + + def forward(self, x): + linear_output, bias = self.to_wrap(x) + if isinstance(linear_output, tuple) and len(linear_output) == 2: + linear_output, layernorm_output = linear_output + adapter_output = self.adapter(layernorm_output) + else: + adapter_output = self.adapter(x) + return linear_output + adapter_output, bias + + +@dataclass +class LoRA(PEFT): + """ + Implements the LoRA (Low-Rank Adaptation) module for parameter-efficient fine-tuning. + + LoRA uses a low-rank projection to adapt the weights of a pre-trained model to a new downstream task. + This class facilitates the application of LoRA to specific modules within the model architecture. + + Args: + target_modules (List[str], optional): A list of module names to apply LoRA to. + Defaults to all linear layers ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2']. + - 'linear_qkv': Apply LoRA to the fused linear layer used for query, key, and value projections + in self-attention modules. + - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention modules. + - 'linear_fc1': Apply LoRA to the first fully-connected layer in MLP. + - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP. + dim (int): Dimension of the low-rank projection space. Defaults to 32. + alpha (int): Weighting factor for the low-rank projection. Defaults to 32. + dropout (float): Dropout rate for the low-rank projection. Defaults to 0.0. + dropout_position (Literal['pre', 'post'], optional): Position for applying dropout. + Can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'post'. + + Example: + -------- + >>> from nemo.collections import llm + >>> lora = llm.peft.LoRA(target_modules=['linear_qkv', 'linear_proj'], dim=32) + >>> model = llm.Mistral7BModel(model_transform=lora) + >>> # (set up trainer and data) + >>> trainer.fit(model, data) + + References: + ----------- + Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., Wang, L., & Chen, W. (2021). + LoRA: Low-Rank Adaptation of Large Language Models. arXiv preprint arXiv:2106.09685. + https://arxiv.org/abs/2106.09685 + + ) + """ + + target_modules: List[str] = field( + default_factory=lambda: ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2'] + ) + dim: int = 32 + alpha: int = 32 + dropout: float = 0.0 + dropout_position: Literal['pre', 'post'] = 'post' + + def transform(self, m: nn.Module, name=None, prefix=None): + """ + Applies LoRA to a specific module within the model architecture. + + Args: + m (nn.Module): The module to apply LoRA to. + name (str, optional): Name of the module (if applicable). Defaults to None. + prefix (str, optional): Prefix for the module name (if applicable). Defaults to None. + + Returns: + nn.Module: The modified module with LoRA applied, or the original module if not a target. + """ + from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter + + tp_size = parallel_state.get_tensor_model_parallel_world_size() + if name in self.target_modules: + # m.in_features and m.out_features are divided by tp_size already, + # but in_features and out_features passed to ParallelLinearAdapter are not. 
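A worked example (made-up numbers) of the sizing rule that the comment above refers to and the branch below implements: column-parallel layers shard their output dimension across tensor-parallel ranks, row-parallel layers shard their input dimension, so the adapter has to be constructed with the un-sharded sizes:

    tp_size = 2
    local_in, local_out = 4096, 1536   # hypothetical per-rank shapes of a sharded layer

    # linear_qkv / linear_fc1 (column-parallel): input replicated, output sharded
    col_in, col_out = local_in, local_out * tp_size   # 4096, 3072

    # linear_proj / linear_fc2 (row-parallel): input sharded, output replicated
    row_in, row_out = local_in * tp_size, local_out   # 8192, 1536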
+ if name in ['linear_qkv', 'linear_fc1']: + # Column Parallel Linear + input_is_parallel = False + in_features = m.in_features + out_features = m.out_features * tp_size + else: # name in ['linear_proj', 'linear_fc2'] + # Row Parallel Linear + input_is_parallel = True + in_features = m.in_features * tp_size + out_features = m.out_features + + logging.info(f"Adding lora to: {prefix}.{name}") + adapter = ParallelLinearAdapter( + in_features, + out_features, + self.dim, + activation='identity', + norm_position=None, + norm_type=None, + column_init_method="normal", + row_init_method="zero", + gather_output=False, + input_is_parallel=input_is_parallel, + dropout=self.dropout, + dropout_position=self.dropout_position, + model_parallel_config=getattr(m, "config", None), + alpha=self.alpha, + ) + return AdapterParallelAdd(m, adapter) + return m diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index 21dace008877..9ab1da7136a1 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -24,6 +24,7 @@ import torch.nn as nn import torch.nn.init as init +from megatron.core.dist_checkpointing.mapping import ShardedStateDict from nemo.collections.common.parts.adapter_modules import AdapterModuleUtil from nemo.collections.common.parts.utils import activation_registry from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import fused_bias_gelu @@ -322,6 +323,16 @@ def forward(self, x): return x + def sharded_state_dict( + self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None + ) -> ShardedStateDict: + sharded_state_dict = {} + sharded_state_dict.update(self.linear_in.sharded_state_dict(f"{prefix}linear_in.", sharded_offsets, metadata)) + sharded_state_dict.update( + self.linear_out.sharded_state_dict(f"{prefix}linear_out.", sharded_offsets, metadata) + ) + return sharded_state_dict + class _All2AllHp2Sp(torch.autograd.Function): """ diff --git a/nemo/lightning/__init__.py b/nemo/lightning/__init__.py index d414376d8168..e9674ed1e212 100644 --- a/nemo/lightning/__init__.py +++ b/nemo/lightning/__init__.py @@ -14,7 +14,7 @@ from nemo.lightning.fabric.plugins import FabricMegatronMixedPrecision from nemo.lightning.fabric.strategies import FabricMegatronStrategy from nemo.lightning.nemo_logger import NeMoLogger -from nemo.lightning.pytorch.callbacks.megatron_model_checkpoint import ModelCheckpoint +from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from nemo.lightning.pytorch.optim import LRSchedulerModule, MegatronOptimizerModule, OptimizerModule, lr_scheduler from nemo.lightning.pytorch.plugins import MegatronDataSampler, MegatronMixedPrecision from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py index cb74b42a74c8..11e89a468c76 100644 --- a/nemo/lightning/_strategy_lib.py +++ b/nemo/lightning/_strategy_lib.py @@ -2,7 +2,7 @@ import os from collections import defaultdict from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Generator, Mapping, Optional, Protocol, TypeVar import torch from torch import nn @@ -472,3 +472,42 @@ def get_safe(param_id): 
optim_state_to_sharding_state(optimizer_state_dict["optimizer"], id_to_sharded_param_map) return optimizer_state_dict + + +def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], strict: bool = True) -> None: + from megatron.core import parallel_state + + for index, module in enumerate(megatron_parallel): + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + if "state_dict" in checkpoint: + checkpoint_state_dict = checkpoint["state_dict"][f"model_{index}"] + else: + checkpoint_state_dict = checkpoint[f"model_{index}"] + else: + if "state_dict" in checkpoint: + checkpoint_state_dict = checkpoint["state_dict"] + else: + checkpoint_state_dict = checkpoint + + n_nesting = 0 + mcore_model = megatron_parallel.module + while hasattr(mcore_model, "module"): + mcore_model = mcore_model.module + n_nesting += 1 + + _state_dict = {} + for key, value in checkpoint_state_dict.items(): + # Count the number of "module." at the start of the key + count, _key = 0, key + while _key.startswith("module."): + _key = _key[len("module.") :] + count += 1 + + # Adjust the number of "module." prefixes + if count < n_nesting: + to_add = "module." * (n_nesting - count) + _state_dict[f"{to_add}{key}"] = value + elif count > n_nesting: + to_remove = "module." * (count - n_nesting) + _state_dict[key[len(to_remove) :]] = value + module.load_state_dict(_state_dict, strict=strict) diff --git a/nemo/lightning/fabric/strategies.py b/nemo/lightning/fabric/strategies.py index a53cee1c75e8..a662386a9119 100644 --- a/nemo/lightning/fabric/strategies.py +++ b/nemo/lightning/fabric/strategies.py @@ -296,48 +296,7 @@ def load_checkpoint( def load_module_state_dict( self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True ) -> None: - from megatron.core import parallel_state - - for index, p_module in enumerate(module): - if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: - if "state_dict" in state_dict: - checkpoint_state_dict = state_dict["state_dict"][f"model_{index}"] - else: - checkpoint_state_dict = state_dict[f"model_{index}"] - else: - if "state_dict" in state_dict: - checkpoint_state_dict = state_dict["state_dict"] - else: - checkpoint_state_dict = state_dict - - mcore_model = p_module.module - while hasattr(mcore_model, "module"): - mcore_model = mcore_model.module - - current = module[0] - n_nesting = 0 - while current != mcore_model: - current = current.module - n_nesting += 1 - - _state_dict = {} - for key, value in checkpoint_state_dict.items(): - # Count the number of "module." at the start of the key - count, _key = 0, key - while _key.startswith("module."): - _key = _key[len("module.") :] - count += 1 - - # Adjust the number of "module." prefixes - if count < n_nesting: - to_add = "module." * (n_nesting - count) - _state_dict[f"{to_add}{key}"] = value - elif count > n_nesting: - to_remove = "module." 
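A self-contained sketch of the "module." prefix normalization performed by the new _strategy_lib.load_model_state_dict helper above (hypothetical keys; this toy also passes through keys whose nesting depth already matches):

    def normalize_keys(state_dict, n_nesting):
        # Re-prefix checkpoint keys so their "module." depth matches the nesting of
        # the live module (e.g. extra Float16Module / DDP wrappers around the core model).
        fixed = {}
        for key, value in state_dict.items():
            count, stripped = 0, key
            while stripped.startswith("module."):
                stripped = stripped[len("module."):]
                count += 1
            if count < n_nesting:
                fixed["module." * (n_nesting - count) + key] = value
            elif count > n_nesting:
                fixed[key[len("module." * (count - n_nesting)):]] = value
            else:
                fixed[key] = value
        return fixed

    ckpt = {"module.decoder.weight": 0}
    print(normalize_keys(ckpt, n_nesting=2))  # {'module.module.decoder.weight': 0}
    print(normalize_keys(ckpt, n_nesting=0))  # {'decoder.weight': 0}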
* (count - n_nesting) - _state_dict[key[len(to_remove) :]] = value - checkpoint_state_dict = _state_dict - - p_module.load_state_dict(checkpoint_state_dict, strict=strict) + _strategy_lib.load_model_state_dict(module, state_dict, strict=strict) @contextmanager def megatron_context(self) -> Generator[None, None, None]: diff --git a/nemo/lightning/io/pl.py b/nemo/lightning/io/pl.py index b582e4a6b7dd..51cd639f4dc3 100644 --- a/nemo/lightning/io/pl.py +++ b/nemo/lightning/io/pl.py @@ -46,7 +46,7 @@ def construct_extra(cls, trainer: pl.Trainer) -> Dict[str, Any]: return extra -class MegatronCheckpointIO(CheckpointIO): +class MegatronCheckpointIO(CheckpointIO, IOMixin): """CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints respectively, common for most use cases. diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 919224d5b9f6..386b9d5070f9 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -12,6 +12,7 @@ Iterable, Iterator, List, + Mapping, Optional, Protocol, Sequence, @@ -525,7 +526,7 @@ def sharded_state_dict(self, prefix: str = "") -> Dict[str, Any]: # virtual pipline rank must be set so that GPTModel returns the correct sharded state dict parallel_state.set_virtual_pipeline_model_parallel_rank(index) module_sharded_state_dict = self._module_sharded_state_dict(module) - sharded_state_dict[f"megatron_module_{index}"] = module_sharded_state_dict + sharded_state_dict[f"model_{index}"] = module_sharded_state_dict else: module_sharded_state_dict = self._module_sharded_state_dict(module) sharded_state_dict.update(module_sharded_state_dict) diff --git a/nemo/lightning/nemo_logger.py b/nemo/lightning/nemo_logger.py index efed77663876..5ed783fdbefe 100644 --- a/nemo/lightning/nemo_logger.py +++ b/nemo/lightning/nemo_logger.py @@ -11,13 +11,14 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint from pytorch_lightning.loggers import Logger, TensorBoardLogger, WandbLogger +from nemo.lightning.io.mixin import IOMixin from nemo.lightning.pytorch.callbacks import ModelCheckpoint from nemo.utils import logging from nemo.utils.app_state import AppState @dataclass -class NeMoLogger: +class NeMoLogger(IOMixin): """Logger for NeMo runs. 
Args: @@ -219,6 +220,3 @@ def _setup_files_to_move(self, log_dir, app_state): app_state.files_to_move = files_to_move app_state.files_to_copy = self.files_to_copy - - def teardown(self): - pass diff --git a/nemo/lightning/pytorch/callbacks/__init__.py b/nemo/lightning/pytorch/callbacks/__init__.py index 1525ab21b835..ee0e777d739e 100644 --- a/nemo/lightning/pytorch/callbacks/__init__.py +++ b/nemo/lightning/pytorch/callbacks/__init__.py @@ -1,7 +1,9 @@ -from nemo.lightning.pytorch.callbacks.megatron_model_checkpoint import ModelCheckpoint +from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint +from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform +from nemo.lightning.pytorch.callbacks.nsys import NsysCallback +from nemo.lightning.pytorch.callbacks.peft import PEFT +from nemo.lightning.pytorch.callbacks.preemption import PreemptionCallback from nemo.lightning.pytorch.callbacks.progress import MegatronProgressBar -__all__ = [ - "MegatronProgressBar", - "ModelCheckpoint", -] + +__all__ = ["ModelCheckpoint", "ModelTransform", "PEFT", "NsysCallback", "MegatronProgressBar", "PreemptionCallback"] diff --git a/nemo/lightning/pytorch/callbacks/megatron_model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py similarity index 98% rename from nemo/lightning/pytorch/callbacks/megatron_model_checkpoint.py rename to nemo/lightning/pytorch/callbacks/model_checkpoint.py index 4c0da66828a7..d0a1585f6293 100644 --- a/nemo/lightning/pytorch/callbacks/megatron_model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -51,11 +51,13 @@ def __init__( save_best_model: bool = False, save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation enable_nemo_ckpt_io: bool = True, + try_restore_best_ckpt: bool = True, **kwargs, ): self.save_best_model = save_best_model self.previous_best_path = "" self.enable_nemo_ckpt_io = enable_nemo_ckpt_io + self.try_restore_best_ckpt = try_restore_best_ckpt # Call the parent class constructor with the remaining kwargs. super().__init__( @@ -266,8 +268,9 @@ def on_train_end(self, trainer, pl_module): else: if os.path.isdir(self.best_model_path.split('.ckpt')[0]): self.best_model_path = self.best_model_path.split('.ckpt')[0] - self.best_model_path = trainer.strategy.broadcast(self.best_model_path) - trainer._checkpoint_connector.restore(self.best_model_path) + if self.try_restore_best_ckpt: + self.best_model_path = trainer.strategy.broadcast(self.best_model_path) + trainer._checkpoint_connector.restore(self.best_model_path) def _del_model_without_trainer(self, filepath: str) -> None: from nemo.utils.get_rank import is_global_rank_zero diff --git a/nemo/lightning/pytorch/callbacks/model_transform.py b/nemo/lightning/pytorch/callbacks/model_transform.py new file mode 100644 index 000000000000..68b3db16f473 --- /dev/null +++ b/nemo/lightning/pytorch/callbacks/model_transform.py @@ -0,0 +1,98 @@ +from functools import wraps +from typing import Any, Callable, Optional, TypeVar + +import pytorch_lightning as pl +from torch import nn + +from nemo.lightning.io.mixin import IOMixin +from nemo.utils import logging + + +class ModelTransform(pl.Callback, IOMixin): + """ + A PyTorch Lightning callback that applies a model transformation function at the start of fitting or validation. + + This callback is designed to apply a transformation to the model when fitting or validation begins. 
+ This design allows for loading the original checkpoint first and then applying the transformation, + which is particularly useful for techniques like Parameter-Efficient Fine-Tuning (PEFT). + + The transformation function is expected to be defined on the LightningModule + as an attribute called 'model_transform'. + + Key Features: + - Applies transformation at the start of fit or validation, not during initialization. + - Allows loading of original checkpoints before transformation. + - Supports PEFT and similar techniques that modify model structure. + + Example: + >>> class MyLightningModule(pl.LightningModule): + ... def __init__(self): + ... super().__init__() + ... self.model = SomeModel() + ... self.model_transform = lambda m: SomePEFTMethod()(m) + ... + >>> model = MyLightningModule() + >>> # Load original checkpoint here if needed + >>> model.load_state_dict(torch.load('original_checkpoint.pth')) + >>> trainer = pl.Trainer(callbacks=[ModelTransform()]) + >>> # The model will be transformed when trainer.fit() or trainer.validate() is called + >>> trainer.fit(model) + + Note: + The transformation is applied only once, at the start of fitting or validation, + whichever comes first. This ensures that the model structure is modified before + any forward passes or parameter updates occur, but after the original weights + have been loaded. + """ + + def __init__(self): + super().__init__() + self.model_transform: Optional[Callable[[nn.Module], nn.Module]] = None + + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + logging.info(f"Setting up ModelTransform for stage: {stage}") + + if hasattr(pl_module, 'model_transform'): + logging.info("Found model_transform attribute on pl_module") + self.model_transform = _call_counter(pl_module.model_transform) + pl_module.model_transform = self.model_transform + logging.info(f"Set model_transform to: {self.model_transform}") + else: + logging.info("No model_transform attribute found on pl_module") + + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._maybe_apply_transform(trainer) + + def _maybe_apply_transform(self, trainer): + if self._needs_to_call: + self.model_transform(trainer.model) + + @property + def _needs_to_call(self) -> bool: + return self.model_transform and self.model_transform.__num_calls__ == 0 + + +T = TypeVar('T', bound=Callable[..., Any]) + + +def _call_counter(func: T) -> T: + """ + A decorator that counts the number of times a function is called. + + This decorator wraps a function and adds a '__num_calls__' attribute to it, + which is incremented each time the function is called. + + Args: + func (Callable): The function to be wrapped. + + Returns: + Callable: The wrapped function with a call counter. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + wrapper.__num_calls__ += 1 + return func(*args, **kwargs) + + wrapper.__num_calls__ = 0 + return wrapper # type: ignore diff --git a/nemo/lightning/pytorch/callbacks/nsys.py b/nemo/lightning/pytorch/callbacks/nsys.py index c18722a607b4..d24d7fd974be 100644 --- a/nemo/lightning/pytorch/callbacks/nsys.py +++ b/nemo/lightning/pytorch/callbacks/nsys.py @@ -9,6 +9,26 @@ class NsysCallback(Callback, IOMixin): + """ + A PyTorch Lightning callback for NVIDIA Nsight Systems (Nsys) profiling. + + This callback enables profiling of specific steps during training using NVIDIA Nsys. 
+ It allows for precise control over when profiling starts and ends, which ranks are profiled, + and whether to generate detailed shape information. + + More info about nsys can be found [here](https://developer.nvidia.com/nsight-systems). + + Args: + start_step (int): Global batch to start profiling + end_step (int): Global batch to end profiling + ranks (List[int]): Global rank IDs to profile + gen_shape (bool): Generate model and kernel details including input shapes + + Example: + >>> callback = NsysCallback(start_step=100, end_step=200, ranks=[0, 1], gen_shape=True) + >>> trainer = Trainer(callbacks=[callback]) + """ + def __init__( self, start_step: int, @@ -16,13 +36,6 @@ def __init__( ranks: List[int] = [0], gen_shape: bool = False, ): - """ - Args: - start_step (int): Global batch to start profiling - end_step (int): Global batch to end profiling - ranks (List[int]): Global rank IDs to profile - gen_shape (bool): Generate model and kernel details including input shapes - """ assert type(start_step) == int, f'Nsys start_step must be of type int. Found: {type(start_step)}' self._nsys_profile_start_step = start_step @@ -54,6 +67,8 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx: int) -> Opt torch.cuda.cudart().cudaProfilerStart() if self._nsys_profile_gen_shape: torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__() + else: + torch.autograd.profiler.emit_nvtx().__enter__() def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) -> None: """PyTorch Lightning hook: @@ -63,7 +78,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) device = trainer.strategy.root_device if device.type == 'cuda': - print(f'batch idx: {batch_idx}') if batch_idx == self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks: logging.info("====== End nsys profiling ======") torch.cuda.cudart().cudaProfilerStop() + torch.autograd.profiler.emit_nvtx().__exit__(None, None, None) diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py new file mode 100644 index 000000000000..26325bf549d0 --- /dev/null +++ b/nemo/lightning/pytorch/callbacks/peft.py @@ -0,0 +1,261 @@ +import json +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple + +import pytorch_lightning as pl +import torch.nn as nn +from lightning_fabric.utilities.types import _PATH +from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO +from typing_extensions import override + +from nemo.lightning.io.pl import ckpt_to_dir +from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform +from nemo.utils import logging + +if TYPE_CHECKING: + from megatron.core.dist_checkpointing.mapping import ShardedStateDict + + +_ADAPTER_META_FILENAME = "adapter_metadata.json" + + +class PEFT(ABC, ModelTransform): + """Abstract base class for Parameter-Efficient Fine-Tuning (PEFT) methods. + + This class defines the interface for PEFT methods, which are used to fine-tune + large language models efficiently by modifying only a small subset of the model's + parameters. + + Example: + class MyPEFT(PEFT): + def transform(self, module, name=None, prefix=None): + # Implement the transform logic + pass + + + peft = MyPEFT() + peft_model = LargeLanguageModel(model_transform=peft) + """ + + @abstractmethod + def transform(self, module, name=None, prefix=None): + """Transform a single module according to the PEFT method. 
+ + This method is called for each module in the model during the PEFT application process. + It should be implemented by subclasses to define how individual modules are transformed + for the specific PEFT technique. + + Args: + module (nn.Module): The individual module to be transformed. + name (Optional[str]): The name of the module within the model structure. Defaults to None. + prefix (Optional[str]): A prefix to be added to the module name, typically used for + nested modules. Defaults to None. + + Returns: + nn.Module: The transformed module. This can be the original module with modifications, + a new module replacing the original, or the original module if no + transformation is needed for this specific module. + + Note: + This method is automatically called for each module in the model when the PEFT + instance is applied to the model using the __call__ method. + """ + raise NotImplementedError("The transform method should be implemented by subclasses.") + + def __call__(self, model: nn.Module) -> nn.Module: + """Apply the PEFT method to the entire model. + + This method freezes the model parameters and walks through the model + structure, applying the transform method to each module. + + Args: + model (nn.Module): The model to be fine-tuned. + + Returns: + nn.Module: The transformed model with PEFT applied. + """ + + model.freeze() + model.walk(self.transform) + + return model + + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: + super().setup(trainer, pl_module, stage=stage) + + self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io) + trainer.strategy._checkpoint_io = self.wrapped_io + + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + needs_to_call = self._needs_to_call + self._maybe_apply_transform(trainer) + + # Check if we need to load the adapters + if needs_to_call and self.wrapped_io.adapter_ckpt_path is not None: + logging.info(f"Loading adapters from {self.wrapped_io.adapter_ckpt_path}") + adapter_state = self.wrapped_io.load_checkpoint(self.wrapped_io.adapter_ckpt_path) + trainer.strategy.load_model_state_dict(adapter_state, strict=False) + + def on_load_checkpoint( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any] + ) -> None: + pl_module.strict_loading = False + + +class AdapterWrapper(nn.Module): + """Abstract base class for wrapping modules with adapters in Parameter-Efficient Fine-Tuning (PEFT). + + This class wraps a module and its associated adapter, providing methods for + managing the state dictionaries of both the main module and the adapter. It does not + implement the forward method, which must be implemented by concrete subclasses. + + Attributes: + to_wrap (nn.Module): The main module to be wrapped. + adapter (nn.Module): The adapter module to be applied. + + Note: + This class is abstract and cannot be instantiated directly. Subclasses must + implement the forward method. 
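As a concrete, hypothetical example of the contract described above, a bottleneck-adapter PEFT method could be written against the classes added in this patch roughly as follows. BottleneckPEFT and ResidualAdapter are illustrative names only, and the model it is applied to must provide the freeze()/walk() helpers that PEFT.__call__ relies on:

    from torch import nn
    from nemo.lightning.pytorch.callbacks.peft import PEFT, AdapterWrapper

    class ResidualAdapter(AdapterWrapper):
        def forward(self, x):
            # Plain residual add; the wrapped module here is an ordinary nn.Linear,
            # so there is no (output, bias) tuple to unpack.
            return self.to_wrap(x) + self.adapter(x)

    class BottleneckPEFT(PEFT):
        def __init__(self, targets=("linear_proj",), dim=16):
            super().__init__()
            self.targets, self.dim = targets, dim

        def transform(self, module, name=None, prefix=None):
            # Wrap only the named linear layers; everything else is returned untouched
            # (and stays frozen, because PEFT.__call__ freezes the model first).
            if name in self.targets and isinstance(module, nn.Linear):
                adapter = nn.Sequential(
                    nn.Linear(module.in_features, self.dim, bias=False),
                    nn.Linear(self.dim, module.out_features, bias=False),
                )
                nn.init.zeros_(adapter[1].weight)  # start as a no-op, like LoRA
                return ResidualAdapter(module, adapter)
            return module

Such a method would be passed the same way as LoRA, e.g. finetune(..., peft=BottleneckPEFT()).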
+ + Example: + class AdapterParallelAdd(AdapterWrapper): + def __init__(self, to_wrap, adapter): + super().__init__(to_wrap, adapter) + + def forward(self, x): + return self.to_wrap(x) + self.adapter(x) + + main_module = nn.Linear(100, 100) + adapter = nn.Linear(100, 100) + parallel_adapter = AdapterParallelAdd(main_module, adapter) + """ + + def __init__(self, to_wrap: nn.Module, adapter: nn.Module): + super(AdapterWrapper, self).__init__() + self.to_wrap = to_wrap + self.adapter = adapter + + def state_dict(self, destination=None, prefix='', keep_vars=False): + """Retrieve the state dictionary of the wrapped module and adapter. + + This method overrides the default state_dict behavior to include both + the main module's state and the adapter's state under a special 'adapters' key. + + Args: + destination (Optional[dict]): A dictionary to store the state. If None, a new + dictionary is created. Defaults to None. + prefix (str): A prefix added to parameter and buffer names. Defaults to ''. + keep_vars (bool): If True, returns variables instead of tensor values. + Defaults to False. + + Returns: + dict: The state dictionary containing both the main module and adapter states. + """ + + if destination is None: + destination = {} + + # Get state dict of the main module + main_state_dict = self.to_wrap.state_dict(destination, prefix, keep_vars) + + # Store adapter state dict under the special "adapters" key in the destination dict + adapter_state_dict = self.adapter.state_dict(None, prefix, keep_vars) + destination[f'{prefix}adapters'] = adapter_state_dict + return main_state_dict + + def sharded_state_dict( + self, + prefix: str = '', + sharded_offsets: Tuple[Tuple[int, int, int]] = (), + metadata: Optional[dict] = None, + ) -> "ShardedStateDict": + """Retrieve the sharded state dictionary of the wrapped module and adapter. + + This method is used for distributed checkpointing, combining the sharded states + of both the main module and the adapter. + + Args: + prefix (str): A prefix added to parameter and buffer names. Defaults to ''. + sharded_offsets (Tuple[Tuple[int, int, int]]): Offsets for sharded parameters. + Defaults to an empty tuple. + metadata (Optional[dict]): Additional metadata for the sharded state. + Defaults to None. + + Returns: + ShardedStateDict: The combined sharded state dictionary. + """ + sharded_state_dict = {} + sharded_state_dict.update(self.to_wrap.sharded_state_dict(prefix, sharded_offsets, metadata)) + sharded_state_dict.update(self.adapter.sharded_state_dict(f"{prefix}adapter.", sharded_offsets, metadata)) + return sharded_state_dict + + def load_state_dict(self, state_dict, strict=True): + """Load a state dictionary into the wrapped module and adapter. + + This method overrides the default load_state_dict behavior to handle + loading states for both the main module and the adapter. + + Args: + state_dict (dict): The state dictionary to load. + strict (bool): Whether to strictly enforce that the keys in state_dict + match the keys returned by this module's state_dict() + function. Defaults to True. 
+ """ + # Check if the 'adapters' key is present in the state_dict + if 'adapters' in state_dict: + adapter_state_dict = state_dict.pop('adapters') + else: + adapter_state_dict = {} + + # Load the main module state dict + self.to_wrap.load_state_dict(state_dict, strict) + + # Load the adapter module state dict if present + if adapter_state_dict: + self.adapter.load_state_dict(adapter_state_dict, strict) + + +class WrappedAdapterIO(_WrappingCheckpointIO): + model_ckpt_path: Optional[Path] = None + adapter_ckpt_path: Optional[Path] = None + + @override + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + assert self.checkpoint_io is not None + + key = "sharded_state_dict" if "sharded_state_dict" in checkpoint else "state_dict" + checkpoint[key] = dict(filter(lambda x: ".adapter." in x[0], checkpoint[key].items())) + + self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options=storage_options) + + from nemo.utils.get_rank import is_global_rank_zero + + if is_global_rank_zero(): + metadata = {"model_ckpt_path": str(self.model_ckpt_path)} + adapter_meta_path = ckpt_to_dir(path) / _ADAPTER_META_FILENAME + with open(adapter_meta_path, "w") as f: + json.dump(metadata, f) + + @override + def load_checkpoint( + self, path: _PATH, sharded_state_dict=None, map_location: Optional[Callable] = None + ) -> Dict[str, Any]: + assert self.checkpoint_io is not None + + adapter_meta_path = ckpt_to_dir(path) / _ADAPTER_META_FILENAME + if getattr(path, "adapter_path", None): + self.model_ckpt_path = path + self.adapter_ckpt_path = path.adapter_path + elif adapter_meta_path.exists(): + with open(adapter_meta_path, "r") as f: + metadata = json.load(f) + self.model_ckpt_path = Path(metadata['model_ckpt_path']) + self.adapter_ckpt_path = path + else: + self.model_ckpt_path = path + + # Note: this will include the Trainer-state of the model-checkpoint + model_ckpt = self.checkpoint_io.load_checkpoint(path, sharded_state_dict, map_location) + + return model_ckpt diff --git a/nemo/lightning/pytorch/callbacks/preemption.py b/nemo/lightning/pytorch/callbacks/preemption.py new file mode 100644 index 000000000000..7f1dd94256d2 --- /dev/null +++ b/nemo/lightning/pytorch/callbacks/preemption.py @@ -0,0 +1,115 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import signal +from typing import Optional + +import torch +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.utils import logging + + +class PreemptionCallback(Callback): + """ + PreemptionCallback checks for preemption during training at the end of every step. + Upon preemption, it signals the trainer to stop gracefully. + + Args: + sig (int, optional): The signal to listen for. Defaults to signal.SIGTERM. 
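A small standalone sketch (hypothetical paths and keys) of what WrappedAdapterIO.save_checkpoint above does: keep only the adapter tensors and write a metadata file recording where the frozen base checkpoint lives, so the adapter can later be re-attached to it:

    import json
    from pathlib import Path

    state_dict = {
        "decoder.layers.0.linear_proj.weight": "base weight (dropped)",
        "decoder.layers.0.linear_proj.adapter.linear_in.weight": "lora A (kept)",
        "decoder.layers.0.linear_proj.adapter.linear_out.weight": "lora B (kept)",
    }
    adapter_only = {k: v for k, v in state_dict.items() if ".adapter." in k}

    out_dir = Path("adapter_ckpt")          # stand-in for ckpt_to_dir(path)
    out_dir.mkdir(exist_ok=True)
    metadata = {"model_ckpt_path": "/checkpoints/base_model"}
    (out_dir / "adapter_metadata.json").write_text(json.dumps(metadata))

    assert len(adapter_only) == 2           # only the ".adapter." entries survive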
+ + Example: + >>> from nemo.lightning.pytorch.callbacks import PreemptionCallback + >>> callback = PreemptionCallback() + >>> trainer = Trainer(callbacks=[callback]) + """ + + def __init__(self, sig: Optional[int] = None): + self.sig = sig if sig is not None else signal.SIGTERM + self._interrupted = False + self._handler_context = None + self._preemption_supported = None + + def on_train_start(self, trainer: Trainer, pl_module) -> None: + if self.preemption_supported: + self._handler_context = self._preemption_handler() + self._handler_context.__enter__() + + def on_train_batch_start(self, trainer: Trainer, pl_module, batch, batch_idx: int) -> None: + if not self.preemption_supported: + self._preemption_supported = self._check_preemption_support() + if self.preemption_supported: + self._handler_context = self._preemption_handler() + self._handler_context.__enter__() + + def on_train_end(self, trainer: Trainer, pl_module) -> None: + if self._handler_context: + self._handler_context.__exit__(None, None, None) + + def on_train_batch_end(self, trainer: Trainer, pl_module, outputs, batch, batch_idx: int) -> None: + if self.interrupted: + logging.info("Preemption detected, signaling trainer to stop") + trainer.should_stop = True + + def on_exception(self, trainer: Trainer, pl_module, exception: BaseException) -> None: + if isinstance(exception, PreemptionException): + logging.info("Handling PreemptionException") + trainer.should_stop = True + + @contextlib.contextmanager + def _preemption_handler(self): + if not self.preemption_supported: + logging.warning("Preemption requires torch distributed to be initialized, preemption may be disabled") + yield + return + + original_handler = signal.getsignal(self.sig) + + def master_handler(signum, frame): + logging.info(f"Received signal {signum}, initiating graceful stop") + self._interrupted = True + raise PreemptionException("Preemption signal received") + + def ignoring_handler(signum, frame): + logging.debug(f"Received signal {signum} on non-master rank, ignoring") + + try: + private_rank = torch.distributed.get_rank() + signal.signal(self.sig, master_handler if private_rank == 0 else ignoring_handler) + yield + finally: + signal.signal(self.sig, original_handler) + + @property + def preemption_supported(self) -> bool: + if self._preemption_supported is None: + self._preemption_supported = self._check_preemption_support() + return self._preemption_supported + + def _check_preemption_support(self) -> bool: + return torch.distributed.is_available() and torch.distributed.is_initialized() + + @property + def interrupted(self) -> bool: + if not self.preemption_supported: + return False + interrupted = torch.tensor(self._interrupted, device=torch.cuda.current_device(), dtype=torch.int32) + torch.distributed.broadcast(interrupted, 0) + return bool(interrupted.item()) + + +class PreemptionException(Exception): + """Custom exception for preemption events.""" diff --git a/nemo/lightning/pytorch/optim/base.py b/nemo/lightning/pytorch/optim/base.py index 88a77328ef9b..8e857a156649 100644 --- a/nemo/lightning/pytorch/optim/base.py +++ b/nemo/lightning/pytorch/optim/base.py @@ -1,5 +1,6 @@ import types from abc import ABC, abstractmethod +from copy import deepcopy from typing import List, Optional import pytorch_lightning as L @@ -134,7 +135,7 @@ def custom_configure_optimizers(lightning_module_self, megatron_parallel=None): if hasattr(self, "__io__") and hasattr(model, "__io__"): if hasattr(model.__io__, "optim"): - model.__io__.optim = self.__io__ + 
model.__io__.optim = deepcopy(self.__io__) @abstractmethod def optimizers(self, model) -> List[Optimizer]: diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index 99e7245d60dd..0f6dc89a7076 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -33,7 +33,7 @@ from nemo.lightning import _strategy_lib, io from nemo.lightning.io.pl import MegatronCheckpointIO from nemo.lightning.megatron_parallel import CallbackConnector, MegatronParallel, _ModuleStepFunction -from nemo.lightning.pytorch.callbacks import MegatronProgressBar +from nemo.lightning.pytorch.callbacks import MegatronProgressBar, ModelTransform if TYPE_CHECKING: from nemo.lightning.pytorch.plugins.data_sampler import DataSampler @@ -106,9 +106,9 @@ def __init__( **kwargs, ) -> None: super().__init__( - parallel_devices, - cluster_environment, - checkpoint_io, + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io, find_unused_parameters=find_unused_parameters, **kwargs, ) @@ -193,6 +193,18 @@ def setup(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None: self.setup_megatron_parallel(trainer, setup_optimizers=setup_optimizers) self.setup_precision_plugin() + if getattr(self.lightning_module, "model_transform", None): + # Ensure the ModelTransform callback is pass to the trainer. + # Callback.setup() is called before the current Strategy.setup(), so we can + # only perform a check here; adding the callback here would not be sufficient + if not any(isinstance(cb, ModelTransform) for cb in trainer.callbacks): + raise ValueError( + "You specified a model_transform function in the model, but no" + "ModelTransform callback was found in the trainer. " + "Please initialize the trainer with " + "`trainer = Trainer(..., callbacks=[ModelTransform()])`" + ) + if trainer.num_sanity_val_steps > 1 and self.pipeline_model_parallel_size > 1: # TODO: log here trainer.num_sanity_val_steps = 0 @@ -522,53 +534,21 @@ def remove_checkpoint(self, filepath: Union[str, Path]) -> None: def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: assert self.megatron_parallel is not None - from megatron.core import parallel_state - for index, module in enumerate(self.megatron_parallel): - if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: - checkpoint_state_dict = checkpoint['state_dict'][f'model_{index}'] - else: - checkpoint_state_dict = checkpoint['state_dict'] - - mcore_model = self.lightning_module.module - while hasattr(mcore_model, "module"): - mcore_model = mcore_model.module - - current = self.model[0] - n_nesting = 0 - while current != mcore_model: - current = current.module - n_nesting += 1 - - _state_dict = {} - for key, value in checkpoint_state_dict.items(): - # Count the number of "module." at the start of the key - count, _key = 0, key - while _key.startswith("module."): - _key = _key[len("module.") :] - count += 1 - - # Adjust the number of "module." prefixes - if count < n_nesting: - to_add = "module." * (n_nesting - count) - _state_dict[f"{to_add}{key}"] = value - elif count > n_nesting: - to_remove = "module." 
* (count - n_nesting) - _state_dict[key[len(to_remove) :]] = value - checkpoint_state_dict = _state_dict - - module.load_state_dict(checkpoint_state_dict, strict=strict) + _strategy_lib.load_model_state_dict(self.megatron_parallel, checkpoint, strict=strict) @property @override def checkpoint_io(self) -> CheckpointIO: if self._checkpoint_io is None: self._checkpoint_io = MegatronCheckpointIO() - elif isinstance(self._checkpoint_io, _WrappingCheckpointIO): - self._checkpoint_io.checkpoint_io = MegatronCheckpointIO() return self._checkpoint_io + @checkpoint_io.setter + def checkpoint_io(self, io: CheckpointIO) -> None: + self._checkpoint_io = io + def _get_data_step(self, step_type: str) -> Optional[_ModuleStepFunction]: for fn_name in [f"{step_type}_data_step", "data_step"]: if hasattr(self.lightning_module, fn_name): diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py index f762d345ed3b..fc2e21eb37fd 100644 --- a/nemo/lightning/resume.py +++ b/nemo/lightning/resume.py @@ -1,16 +1,24 @@ -from pathlib import Path +import os +from pathlib import Path, PosixPath, WindowsPath from typing import Optional, Union import lightning_fabric as fl import pytorch_lightning as pl from nemo.lightning import io +from nemo.lightning.io.mixin import IOMixin from nemo.utils import logging from nemo.utils.app_state import AppState from nemo.utils.model_utils import uninject_model_parallel_rank +# Dynamically inherit from the correct Path subclass based on the operating system. +if os.name == 'nt': + BasePath = WindowsPath +else: + BasePath = PosixPath -class Resume: + +class Resume(IOMixin): def nemo_path(self, model) -> Optional[Path]: raise NotImplementedError @@ -34,6 +42,7 @@ def __init__( path: Optional[str] = None, ## old resume_from_checkpoint dirpath: Optional[str] = None, ## optional path to checkpoint directory import_path: Optional[str] = None, ## for importing from hf or other checkpoint formats + adapter_path: Optional[str] = None, resume_if_exists: bool = False, resume_past_end: bool = False, resume_ignore_no_checkpoint: bool = False, @@ -66,6 +75,7 @@ def __init__( self.path = path self.dirpath = dirpath self.import_path = import_path + self.adapter_path = adapter_path self.resume_if_exists = resume_if_exists self.resume_past_end = resume_past_end self.resume_ignore_no_checkpoint = resume_ignore_no_checkpoint @@ -76,7 +86,10 @@ def nemo_path(self, model=None) -> Optional[Path]: if self.import_path: if model is None: raise ValueError("Model is needed to import checkpoint from HF or other non-NeMo checkpoint format.") - return model.import_ckpt(self.import_path) + output = model.import_ckpt(self.import_path) + if self.adapter_path: + return AdapterPath(output, adapter_path=Path(self.adapter_path)) + return output ### refactored from exp_manager checkpoint = None @@ -131,6 +144,17 @@ def nemo_path(self, model=None) -> Optional[Path]: checkpoint = last_checkpoints[0] if checkpoint: + if self.adapter_path: + return AdapterPath(checkpoint, adapter_path=Path(self.adapter_path)) return Path(checkpoint) return None + + +class AdapterPath(BasePath): + adapter_path: Optional[Path] + + def __new__(cls, *args, adapter_path: Optional[Path] = None, **kwargs): + output = super().__new__(cls, *args, **kwargs) + output.adapter_path = adapter_path + return output diff --git a/setup.py b/setup.py index 6c82ef803174..292be13e65df 100644 --- a/setup.py +++ b/setup.py @@ -286,4 +286,9 @@ def finalize_options(self): keywords=__keywords__, # Custom commands. 
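For the resume changes above, a short usage sketch of the new AdapterPath (paths are made up): it behaves like a normal Path to the base checkpoint while also carrying the adapter checkpoint location, which WrappedAdapterIO.load_checkpoint picks up via getattr(path, "adapter_path", None):

    from pathlib import Path
    from nemo.lightning.resume import AdapterPath

    ckpt = AdapterPath("/checkpoints/base_model", adapter_path=Path("/checkpoints/lora_adapter"))
    print(ckpt)               # /checkpoints/base_model
    print(ckpt.adapter_path)  # /checkpoints/lora_adapter

This is also what nemo_path returns when adapter_path is set on the resume object, so downstream code can treat it as the base checkpoint path and still find the adapter.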
cmdclass={'style': StyleCommand}, + entry_points={ + "sdk.factories": [ + "llm = nemo.collections.llm", + ], + }, ) diff --git a/tests/lightning/pytorch/callbacks/__init__.py b/tests/lightning/pytorch/callbacks/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/lightning/pytorch/callbacks/test_model_transform.py b/tests/lightning/pytorch/callbacks/test_model_transform.py new file mode 100644 index 000000000000..9894f7d7bc58 --- /dev/null +++ b/tests/lightning/pytorch/callbacks/test_model_transform.py @@ -0,0 +1,48 @@ +import pytest +import pytorch_lightning as pl +from torch import nn + +from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform + + +class TestModelTransformCallback: + @pytest.fixture + def callback(self): + return ModelTransform() + + @pytest.fixture + def pl_module(self): + return MockLightningModule() + + @pytest.fixture + def trainer(self): + return pl.Trainer() + + def test_setup_stores_transform(self, callback, pl_module, trainer, caplog): + callback.setup(trainer, pl_module, 'fit') + + assert callback.model_transform is not None, "callback.model_transform should be set after setup" + assert hasattr( + callback.model_transform, '__num_calls__' + ), "callback.model_transform should have __num_calls__ attribute" + assert callback.model_transform.__num_calls__ == 0, "callback.model_transform should not have been called yet" + assert pl_module.model_transform == callback.model_transform, "pl_module.model_transform should be updated" + + +class MockModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 10) + + def forward(self, x): + return self.linear(x) + + +class MockLightningModule(pl.LightningModule): + def __init__(self): + super().__init__() + self.model = MockModel() + self.model_transform = lambda m: nn.Sequential(m, nn.ReLU()) + + def forward(self, x): + return self.model(x) diff --git a/tests/lightning/pytorch/callbacks/test_nsys.py b/tests/lightning/pytorch/callbacks/test_nsys.py new file mode 100644 index 000000000000..e8734ad1c1ac --- /dev/null +++ b/tests/lightning/pytorch/callbacks/test_nsys.py @@ -0,0 +1,195 @@ +from unittest.mock import MagicMock, patch + +import pytest +import torch +from nemo.lightning.pytorch.callbacks.nsys import NsysCallback + + +class TestNsysCallback: + @pytest.fixture(autouse=True) + def setup_mocks(self): + self.cuda_mock = patch('torch.cuda') + self.cudart_mock = patch('torch.cuda.cudart') + self.emit_nvtx_mock = patch('torch.autograd.profiler.emit_nvtx') + self.get_rank_mock = patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + + self.cuda_mock.start() + self.cudart_mock.start() + self.emit_nvtx_mock.start() + self.get_rank_mock.start() + + # Mock CUDA availability + torch.cuda.is_available = MagicMock(return_value=True) + torch.cuda.current_device = MagicMock(return_value=0) + + yield + + self.cuda_mock.stop() + self.cudart_mock.stop() + self.emit_nvtx_mock.stop() + self.get_rank_mock.stop() + + @pytest.fixture + def mock_trainer(self): + trainer = MagicMock() + trainer.strategy.root_device.type = 'cuda' + return trainer + + @pytest.fixture + def mock_pl_module(self): + return MagicMock() + + def test_init_valid_params(self): + """Test initialization with valid parameters.""" + callback = NsysCallback(start_step=10, end_step=20, ranks=[0, 1], gen_shape=True) + assert callback._nsys_profile_start_step == 10 + assert callback._nsys_profile_end_step == 20 + assert callback._nsys_profile_ranks == [0, 1] + assert 
callback._nsys_profile_gen_shape == True + + def test_init_invalid_params(self): + """Test initialization with invalid parameters.""" + with pytest.raises(AssertionError): + NsysCallback(start_step='10', end_step=20) + + with pytest.raises(AssertionError): + NsysCallback(start_step=10, end_step='20') + + with pytest.raises(AssertionError): + NsysCallback(start_step=20, end_step=10) + + @patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + @patch('torch.cuda.cudart') + @patch('torch.autograd.profiler.emit_nvtx') + def test_on_train_batch_start_profiling( + self, mock_emit_nvtx, mock_cudart, mock_get_rank, mock_trainer, mock_pl_module + ): + """Test on_train_batch_start when profiling should start.""" + mock_get_rank.return_value = 0 + callback = NsysCallback(start_step=10, end_step=20, ranks=[0], gen_shape=True) + + callback.on_train_batch_start(mock_trainer, mock_pl_module, None, 10) + + mock_cudart().cudaProfilerStart.assert_called_once() + mock_emit_nvtx.assert_called_once_with(record_shapes=True) + + @patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + @patch('torch.cuda.cudart') + def test_on_train_batch_start_no_profiling(self, mock_cudart, mock_get_rank, mock_trainer, mock_pl_module): + """Test on_train_batch_start when profiling should not start.""" + mock_get_rank.return_value = 0 + callback = NsysCallback(start_step=10, end_step=20, ranks=[0]) + + callback.on_train_batch_start(mock_trainer, mock_pl_module, None, 9) + + mock_cudart().cudaProfilerStart.assert_not_called() + + @patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + @patch('torch.cuda.cudart') + @patch('torch.autograd.profiler.emit_nvtx') + def test_on_train_batch_end_profiling( + self, mock_emit_nvtx, mock_cudart, mock_get_rank, mock_trainer, mock_pl_module + ): + """Test on_train_batch_end when profiling should end.""" + mock_get_rank.return_value = 0 + callback = NsysCallback(start_step=10, end_step=20, ranks=[0]) + + callback.on_train_batch_end(mock_trainer, mock_pl_module, None, None, 20) + + mock_cudart().cudaProfilerStop.assert_called_once() + + @patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + @patch('torch.cuda.cudart') + @patch('torch.autograd.profiler.emit_nvtx') + def test_on_train_batch_end_no_profiling( + self, mock_emit_nvtx, mock_cudart, mock_get_rank, mock_trainer, mock_pl_module + ): + """Test on_train_batch_end when profiling should not end.""" + mock_get_rank.return_value = 0 + callback = NsysCallback(start_step=10, end_step=20, ranks=[0]) + + callback.on_train_batch_end(mock_trainer, mock_pl_module, None, None, 19) + + mock_cudart().cudaProfilerStop.assert_not_called() + + def test_non_cuda_device(self, mock_trainer, mock_pl_module): + """Test behavior when the device is not CUDA.""" + mock_trainer.strategy.root_device.type = 'cpu' + callback = NsysCallback(start_step=10, end_step=20, ranks=[0]) + + callback.on_train_batch_start(mock_trainer, mock_pl_module, None, 10) + callback.on_train_batch_end(mock_trainer, mock_pl_module, None, None, 20) + + # No exceptions should be raised, and no profiling calls should be made + + @patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + def test_rank_not_in_profile_ranks(self, mock_get_rank, mock_trainer, mock_pl_module): + """Test behavior when the current rank is not in the profile ranks.""" + mock_get_rank.return_value = 1 + callback = NsysCallback(start_step=10, end_step=20, ranks=[0]) + callback = NsysCallback(start_step=10, end_step=20, ranks=[0]) + + callback.on_train_batch_start(mock_trainer, mock_pl_module, None, 10) + 
callback.on_train_batch_end(mock_trainer, mock_pl_module, None, None, 20) + + # No profiling calls should be made + + @pytest.mark.parametrize( + "start_step,end_step,batch_idx,expected_call", + [ + (10, 20, 9, False), + (10, 20, 10, True), + (10, 20, 15, False), + (10, 20, 20, False), + (10, 20, 21, False), + ], + ) + @patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + @patch('torch.cuda.cudart') + @patch('torch.autograd.profiler.emit_nvtx') + def test_profiling_range( + self, + mock_emit_nvtx, + mock_cudart, + mock_get_rank, + start_step, + end_step, + batch_idx, + expected_call, + mock_trainer, + mock_pl_module, + ): + """Test profiling behavior across different batch indices.""" + mock_get_rank.return_value = 0 + callback = NsysCallback(start_step=start_step, end_step=end_step, ranks=[0]) + + callback.on_train_batch_start(mock_trainer, mock_pl_module, None, batch_idx) + + if expected_call: + mock_cudart().cudaProfilerStart.assert_called_once() + mock_emit_nvtx.assert_called_once() + else: + mock_cudart().cudaProfilerStart.assert_not_called() + mock_emit_nvtx.assert_not_called() + + @patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + @patch('torch.cuda.cudart') + def test_single_profile_range(self, mock_cudart, mock_get_rank, mock_trainer, mock_pl_module): + """Test behavior with a single profile range.""" + mock_get_rank.return_value = 0 + callback = NsysCallback(start_step=10, end_step=40, ranks=[0]) + + # Ensure the device type is 'cuda' + mock_trainer.strategy.root_device.type = 'cuda' + + # Start of range + callback.on_train_batch_start(mock_trainer, mock_pl_module, None, 10) + assert mock_cudart().cudaProfilerStart.call_count == 1, "cudaProfilerStart was not called" + + # Middle of range + callback.on_train_batch_start(mock_trainer, mock_pl_module, None, 25) + assert mock_cudart().cudaProfilerStart.call_count == 1, "cudaProfilerStart was called again" + + # End of range + callback.on_train_batch_end(mock_trainer, mock_pl_module, None, None, 40) + assert mock_cudart().cudaProfilerStop.call_count == 1, "cudaProfilerStop was not called" diff --git a/tests/lightning/pytorch/callbacks/test_peft.py b/tests/lightning/pytorch/callbacks/test_peft.py new file mode 100644 index 000000000000..81dc7f85bc08 --- /dev/null +++ b/tests/lightning/pytorch/callbacks/test_peft.py @@ -0,0 +1,68 @@ +from unittest.mock import MagicMock, patch + +import torch.nn as nn +from nemo.collections.llm import fn +from nemo.lightning.pytorch.callbacks.peft import PEFT, WrappedAdapterIO + + +class TestPEFT: + class DummyPEFT(PEFT): + def transform(self, module, name=None, prefix=None): + return module # No-op transform for testing + + class DummyModel(nn.Module, fn.FNMixin): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 10) + self.conv = nn.Conv2d(3, 3, 3) + + def test_peft_call(self): + model = self.DummyModel() + peft = self.DummyPEFT() + + transformed_model = peft(model) + + assert transformed_model.linear.weight.requires_grad == False + assert transformed_model.conv.weight.requires_grad == False + + def test_peft_setup(self): + peft = self.DummyPEFT() + trainer = MagicMock() + pl_module = MagicMock() + + pl_module.model_transform = peft + peft.setup(trainer, pl_module, "fit") + + assert isinstance(trainer.strategy._checkpoint_io, WrappedAdapterIO) + assert peft.model_transform is not None + assert peft._needs_to_call is True + + @patch('nemo.lightning.pytorch.callbacks.peft.logging') + def test_peft_on_train_epoch_start_with_adapter(self, mock_logging): + peft = 
self.DummyPEFT() + trainer = MagicMock() + pl_module = MagicMock() + pl_module.model_transform = peft + + peft.setup(trainer, pl_module, "fit") + + assert peft.model_transform is not None + assert peft._needs_to_call is True + + peft.wrapped_io = MagicMock() + peft.wrapped_io.adapter_ckpt_path = "dummy_path" + peft.wrapped_io.load_checkpoint.return_value = {"dummy_state": "dummy_value"} + peft.on_train_epoch_start(trainer, pl_module) + + mock_logging.info.assert_called_once_with("Loading adapters from dummy_path") + trainer.strategy.load_model_state_dict.assert_called_once_with({"dummy_state": "dummy_value"}, strict=False) + + def test_peft_on_load_checkpoint(self): + peft = self.DummyPEFT() + trainer = MagicMock() + pl_module = MagicMock() + checkpoint = {} + + peft.on_load_checkpoint(trainer, pl_module, checkpoint) + + assert pl_module.strict_loading == False diff --git a/tests/lightning/pytorch/callbacks/test_preemption.py b/tests/lightning/pytorch/callbacks/test_preemption.py new file mode 100644 index 000000000000..5fcb4a1458ee --- /dev/null +++ b/tests/lightning/pytorch/callbacks/test_preemption.py @@ -0,0 +1,114 @@ +import logging +import signal +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +import torch +from pytorch_lightning import Trainer + +from nemo.lightning.pytorch.callbacks.preemption import PreemptionCallback, PreemptionException + + +class TestPreemptionCallback: + + @pytest.fixture + def callback(self): + return PreemptionCallback() + + @pytest.fixture + def mock_trainer(self): + trainer = MagicMock(spec=Trainer) + trainer.should_stop = False + return trainer + + def test_init(self, callback): + assert callback.sig == signal.SIGTERM + assert not callback._interrupted + assert callback._handler_context is None + + def test_custom_signal(self): + custom_callback = PreemptionCallback(sig=signal.SIGUSR1) + assert custom_callback.sig == signal.SIGUSR1 + + @pytest.mark.parametrize("initially_supported,becomes_supported", [(False, True), (False, False), (True, True)]) + def test_on_train_batch_start_distributed_init( + self, callback, mock_trainer, initially_supported, becomes_supported + ): + with ( + patch.object(PreemptionCallback, '_check_preemption_support') as mock_check, + patch.object(callback, '_preemption_handler') as mock_handler, + ): + + mock_check.side_effect = [initially_supported, becomes_supported] + + callback.on_train_start(mock_trainer, None) + callback.on_train_batch_start(mock_trainer, None, None, 0) + + expected_call_count = 1 if initially_supported else (1 if becomes_supported else 0) + assert mock_handler.call_count == expected_call_count + + if initially_supported: + mock_handler.assert_called_once_with() + elif becomes_supported: + mock_handler.assert_called_once_with() + else: + mock_handler.assert_not_called() + + @pytest.mark.parametrize( + "is_supported,interrupted,expected", + [ + (True, True, True), + (True, False, False), + (False, True, False), + (False, False, False), + ], + ) + def test_interrupted_property(self, callback, is_supported, interrupted, expected): + with ( + patch.object(PreemptionCallback, '_check_preemption_support', return_value=is_supported), + patch('torch.distributed.broadcast'), + patch('torch.tensor', return_value=torch.tensor(interrupted)), + patch('torch.cuda.is_available', return_value=True), + patch('torch.cuda.current_device', return_value=0), + ): + callback._interrupted = interrupted + assert callback.interrupted == expected + + def test_on_train_start(self, callback, mock_trainer): + 
with ( + patch.object(PreemptionCallback, 'preemption_supported', new_callable=PropertyMock) as mock_supported, + patch.object(callback, '_preemption_handler') as mock_handler, + ): + + # Test when preemption is supported + mock_supported.return_value = True + callback.on_train_start(mock_trainer, None) + mock_handler.assert_called_once() + mock_handler.reset_mock() + + # Test when preemption is not supported + mock_supported.return_value = False + callback.on_train_start(mock_trainer, None) + mock_handler.assert_not_called() + + def test_on_train_end(self, callback, mock_trainer): + mock_context = MagicMock() + callback._handler_context = mock_context + callback.on_train_end(mock_trainer, None) + mock_context.__exit__.assert_called_once_with(None, None, None) + + @pytest.mark.parametrize("interrupted", [True, False]) + def test_on_train_batch_end(self, callback, mock_trainer, interrupted): + with patch.object(PreemptionCallback, 'interrupted', new_callable=lambda: property(lambda self: interrupted)): + callback.on_train_batch_end(mock_trainer, None, None, None, 0) + assert mock_trainer.should_stop == interrupted + + def test_on_exception_preemption(self, callback, mock_trainer): + exception = PreemptionException("Test preemption") + callback.on_exception(mock_trainer, None, exception) + assert mock_trainer.should_stop + + def test_on_exception_other(self, callback, mock_trainer): + exception = ValueError("Some other exception") + callback.on_exception(mock_trainer, None, exception) + assert not mock_trainer.should_stop diff --git a/tests/lightning/test_megatron_parallel.py b/tests/lightning/test_megatron_parallel.py index fafd25e49f5a..e504c7eb5c7c 100644 --- a/tests/lightning/test_megatron_parallel.py +++ b/tests/lightning/test_megatron_parallel.py @@ -1,4 +1,5 @@ from collections import defaultdict +from unittest.mock import MagicMock import pytest from megatron.core import parallel_state @@ -123,13 +124,14 @@ def test_add_callbacks(self) -> None: assert callback in callback_connector.callbacks["on_megatron_step_start"] assert callback in callback_connector.callbacks["on_megatron_microbatch_start"] - def test_event(self, mocker) -> None: + def test_event(self) -> None: callback_connector = mp.CallbackConnector() callback = TestCallback() callback_connector.add(callback) - mocker.spy(callback, "on_megatron_step_start") - mocker.spy(callback, "on_megatron_microbatch_start") + # Replace mocker.spy with manual mocking + callback.on_megatron_step_start = MagicMock() + callback.on_megatron_microbatch_start = MagicMock() callback_connector.event("on_megatron_step_start") callback_connector.event("on_megatron_microbatch_start") From 35ce666bbf10eff47fc05e08fafb5fac4a56585a Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Thu, 4 Jul 2024 23:04:32 -0700 Subject: [PATCH 06/13] Akoumparouli/mistral import instruct chat template fix (#9567) * use bf16 by defualt mistral conv Signed-off-by: Alexandros Koumparoulis * add chat template Signed-off-by: Alexandros Koumparoulis * use capitalized role names Signed-off-by: Alexandros Koumparoulis --------- Signed-off-by: Alexandros Koumparoulis Co-authored-by: Marc Romeyn --- .../convert_mistral_7b_hf_to_nemo.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py b/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py index cb11bb5da564..3a72661499bf 100644 --- 
a/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py @@ -54,7 +54,7 @@ def get_args(): help="Path to Huggingface Mistral-7b checkpoints", ) parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.") - parser.add_argument("--precision", type=str, default="32", help="Model precision") + parser.add_argument("--precision", type=str, default="bf16", help="Model precision") args = parser.parse_args() return args @@ -167,7 +167,7 @@ def convert(args): scaler = None if precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32), + init_scale=nemo_config.get('native_amp_init_scale', 2**32), growth_interval=nemo_config.get('native_amp_growth_interval', 1000), hysteresis=nemo_config.get('hysteresis', 2), ) @@ -329,6 +329,22 @@ def convert(args): model = model.to(dtype=dtype) model.cfg.use_cpu_initialization = False + if getattr(tokenizer, 'chat_template', None) is not None: + import hashlib + + assert ( + hashlib.md5(tokenizer.chat_template.encode('utf-8')).hexdigest() == "0b629f783db54e02509999196956ff40" + ), "Got unkown chat template" + from omegaconf import OmegaConf, open_dict + + with open_dict(model.cfg): + model.cfg.tokenizer.chat_template = OmegaConf.create( + { + 'prefix': "{_bos_}", + 'roles': {'User': "[INST] {_content_} [/INST]", 'Assistant': "{_content_}{_eos_}"}, + } + ) + model.save_to(args.output_path) logging.info(f'NeMo model saved to: {args.output_path}') From d481674c988fa089c6b4d8c0133e6a3e79cc2261 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Thu, 4 Jul 2024 23:05:04 -0700 Subject: [PATCH 07/13] Remove .cuda calls, use device isntead (#9602) Signed-off-by: Alexandros Koumparoulis --- nemo/lightning/megatron_parallel.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 386b9d5070f9..71d9c87f2fe0 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -49,7 +49,7 @@ def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT: batch = batch[0] if isinstance(batch, dict): - batch = {k: v.cuda() for k, v in batch.items()} + batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()} return batch @@ -182,7 +182,7 @@ def __init__( for i, model_module in enumerate(_pipeline): if not cpu: - model_module.cuda(torch.cuda.current_device()) + model_module.cuda(torch.cuda.current_device(), non_blocking=True) for param in model_module.parameters(): set_defaults_if_not_set_tensor_model_parallel_attributes(param) @@ -300,7 +300,7 @@ def forward( if forward_only: loss_mean = cast(torch.Tensor, []) else: - loss_mean = torch.tensor(0.0).cuda() + loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) self.callbacks.event("on_megatron_log_step_end", **context) self.callbacks.event("on_megatron_step_end", **context) @@ -1018,7 +1018,7 @@ def forward( loss_sum_and_ub_size_all_gpu = torch.cat( [ loss_sum_for_ub.clone().detach().view(1), - torch.tensor([num_valid_tokens_in_ub]).cuda().clone().detach(), + torch.tensor([num_valid_tokens_in_ub], device=torch.cuda.current_device()).clone().detach(), ] ) torch.distributed.all_reduce(loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group()) @@ -1045,11 +1045,11 @@ def reduce(self, losses_reduced_per_micro_batch) -> 
torch.Tensor: loss_sum = ( torch.vstack(loss_sum_tensors_list).sum(dim=0) if len(loss_sum_tensors_list) > 0 - else torch.tensor([0.0, 0.0]).cuda() + else torch.tensor([0.0, 0.0], device=torch.cuda.current_device()) ) return loss_sum - return torch.tensor(0.0).cuda() + return torch.tensor(0.0, device=torch.cuda.current_device()) def masked_token_loss(tensor: Tensor, mask: Tensor): From 10768ae18dc10499479a532e7ca0a6733b2ce9d3 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Fri, 5 Jul 2024 00:35:26 -0700 Subject: [PATCH 08/13] fix converter defautl args (#9565) * fix converter defautl args Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa --------- Signed-off-by: Alexandros Koumparoulis Signed-off-by: akoumpa Co-authored-by: akoumpa --- .../convert_mixtral_hf_to_nemo.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py b/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py index 8183b0d142c1..1bf23224357f 100644 --- a/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py @@ -50,11 +50,17 @@ def get_args(): parser = ArgumentParser() parser.add_argument( - "--input_name_or_path", type=str, default=None, required=True, help="Path to Huggingface Mixtral checkpoints", + "--input_name_or_path", + type=str, + default=None, + required=True, + help="Path to Huggingface Mixtral checkpoints", ) parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.") - valid_precision_values = [16, '16', 'bf16', '16-mixed', 'bf16-mixed', 32, '32'] - parser.add_argument("--precision", type=str, default="32", choices=valid_precision_values, help="Model precision") + valid_precision_values = [16, '16', 'bf16', '16-mixed', 'bf16-mixed'] + parser.add_argument( + "--precision", type=str, default="bf16", choices=valid_precision_values, help="Model precision" + ) parser.add_argument('--low-ram', action='store_true') parser.add_argument('--tmp-dir', default='/tmp/mixtral_ckpt_parts/') args = parser.parse_args() @@ -185,7 +191,7 @@ def make_trainer(args, nemo_config): scaler = None if precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32), + init_scale=nemo_config.get('native_amp_init_scale', 2**32), growth_interval=nemo_config.get('native_amp_growth_interval', 1000), hysteresis=nemo_config.get('hysteresis', 2), ) From d4a32d0dea3d7201defdad09967b4536fa56e672 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Fri, 5 Jul 2024 01:43:26 -0700 Subject: [PATCH 09/13] mixtral export (#9603) Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/model/mixtral.py | 119 ++++++++++++++++++++++ 1 file changed, 119 insertions(+) diff --git a/nemo/collections/llm/gpt/model/mixtral.py b/nemo/collections/llm/gpt/model/mixtral.py index af1b73dd9109..6256b67515ee 100644 --- a/nemo/collections/llm/gpt/model/mixtral.py +++ b/nemo/collections/llm/gpt/model/mixtral.py @@ -186,3 +186,122 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v): ) def _import_moe_w1_w3(gate_proj, up_proj): return torch.cat((gate_proj, up_proj), axis=0) + + +@io.model_exporter(MixtralModel, "hf") +class HFMixtralExporter(io.ModelConnector[MixtralModel, "MixtralForCausalLM"]): + def init(self) -> "MixtralForCausalLM": + 
from transformers import AutoModelForCausalLM
+
+        return AutoModelForCausalLM.from_config(self.config)
+
+    def apply(self, output_path: Path) -> Path:
+        # TODO: Make it work with lazy init
+        # with torch.device("meta"):
+        #     target = self.init()
+        target = self.init()
+        source, _ = self.nemo_load(str(self))
+        target = self.convert_state(source, target)
+
+        # TODO: Make sure we don't need to do this
+        target = target.cpu()
+        target.save_pretrained(output_path)
+        self.tokenizer.save_pretrained(output_path)
+
+        return output_path
+
+    def convert_state(self, source, target):
+        mapping = {
+            "embedding.word_embeddings.weight": "model.embed_tokens.weight",
+            "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight",
+            "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight",
+            "decoder.layers.*.pre_mlp_layernorm.weight": "model.layers.*.post_attention_layernorm.weight",
+            # MoE
+            "decoder.layers.*.mlp.experts.local_experts.*.linear_fc2.weight": "model.layers.*.block_sparse_moe.experts.*.w2.weight",
+            "decoder.layers.*.mlp.router.weight": "model.layers.*.block_sparse_moe.gate.weight",
+            # lm-head
+            "decoder.final_layernorm.weight": "model.norm.weight",
+            "output_layer.weight": "lm_head.weight",
+        }
+
+        return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_moe_w1_w3])
+
+    @property
+    def tokenizer(self):
+        return io.load_ckpt(str(self)).model.tokenizer.tokenizer
+
+    @property
+    def config(self) -> "MixtralConfig":
+        source: MixtralConfig7B = io.load_ckpt(str(self)).model.config
+
+        from transformers import MixtralConfig as HfMixtralConfig
+
+        return HfMixtralConfig(
+            num_hidden_layers=source.num_layers,
+            hidden_size=source.hidden_size,
+            intermediate_size=source.ffn_hidden_size,
+            max_position_embeddings=source.max_position_embeddings,
+            seq_length=source.max_position_embeddings,
+            # RoPe
+            rope_theta=source.rotary_base,
+            # transformer config
+            num_attention_heads=source.num_attention_heads,
+            num_key_value_heads=source.num_query_groups,
+            num_local_experts=source.num_moe_experts,
+            num_experts_per_tok=source.moe_router_topk,
+            # norm
+            rms_norm_eps=source.layernorm_epsilon,
+            # init
+            initializer_range=source.init_method_std,
+            # vocab
+            vocab_size=self.tokenizer.vocab_size,
+        )
+
+
+@io.state_transform(
+    source_key="decoder.layers.*.self_attention.linear_qkv.weight",
+    target_key=(
+        "model.layers.*.self_attn.q_proj.weight",
+        "model.layers.*.self_attn.k_proj.weight",
+        "model.layers.*.self_attn.v_proj.weight",
+    ),
+)
+def _export_qkv(ctx: io.TransformCTX, linear_qkv):
+    megatron_config = ctx.source.config
+
+    head_num = megatron_config.num_attention_heads
+    num_query_groups = megatron_config.num_query_groups
+    heads_per_group = head_num // num_query_groups
+    hidden_size = megatron_config.hidden_size
+    head_num = megatron_config.num_attention_heads
+    head_size = hidden_size // head_num
+    qkv_total_dim = head_num + 2 * num_query_groups
+
+    linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size])
+    q_slice = torch.cat(
+        [
+            torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
+            for i in range(num_query_groups)
+        ]
+    )
+    k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
+    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
+
+    q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu()
+    k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu()
+    v_proj = 
linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + + return q_proj, k_proj, v_proj + + +@io.state_transform( + source_key="decoder.layers.*.mlp.experts.local_experts.*.linear_fc1.weight", + target_key=( + "model.layers.*.block_sparse_moe.experts.*.w1.weight", + "model.layers.*.block_sparse_moe.experts.*.w3.weight", + ), +) +def _export_moe_w1_w3(linear_fc1): + gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) + + return gate_proj, up_proj From bdb4e89d9ac33d733f8ea7b21552628dda798825 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Fri, 5 Jul 2024 08:11:14 -0700 Subject: [PATCH 10/13] fix: remove non_blocking from PTL's .cuda call (#9618) Signed-off-by: Alexandros Koumparoulis --- nemo/lightning/megatron_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 71d9c87f2fe0..2f2308717004 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -182,7 +182,7 @@ def __init__( for i, model_module in enumerate(_pipeline): if not cpu: - model_module.cuda(torch.cuda.current_device(), non_blocking=True) + model_module.cuda(torch.cuda.current_device()) for param in model_module.parameters(): set_defaults_if_not_set_tensor_model_parallel_attributes(param) From 19b1d75b1819108d58684bcb9996867763684561 Mon Sep 17 00:00:00 2001 From: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com> Date: Fri, 5 Jul 2024 13:00:01 -0500 Subject: [PATCH 11/13] Alit/mamba tmp (#9612) * adding mamba support * fix import mixins * rm convert jamba * Apply isort and black reformatting Signed-off-by: JRD971000 * more cleanups * use GPT text gen * Apply isort and black reformatting Signed-off-by: JRD971000 * fixing gbs in TP convetor * Apply isort and black reformatting Signed-off-by: JRD971000 * add reqs * add tutorial * minor fix to tutorial * moving finetuning files Signed-off-by: arendu * moving finetuning files Signed-off-by: arendu * address comments * Apply isort and black reformatting Signed-off-by: JRD971000 * address comments * Apply isort and black reformatting Signed-off-by: JRD971000 * add mamba_tmp * remove mamba import * Apply isort and black reformatting Signed-off-by: JRD971000 --------- Signed-off-by: JRD971000 Signed-off-by: arendu Co-authored-by: Ali Taghibakhshi Co-authored-by: JRD971000 Co-authored-by: arendu --- .../conf/megatron_mamba_config.yaml | 191 +++++ .../mamba_change_num_partition.py | 696 ++++++++++++++++++ .../megatron_mamba_finetuning_config.yaml | 315 ++++++++ .../conf/megatron_mamba_generate_config.yaml | 298 ++++++++ .../tuning/megatron_mamba_finetuning.py | 60 ++ .../tuning/megatron_mamba_generate.py | 69 ++ .../language_modeling/megatron_mamba_model.py | 91 +++ .../megatron_mamba_sft_model.py | 47 ++ .../common/text_generation_strategy.py | 3 + .../nlp/parts/mixins/nlp_adapter_mixins.py | 8 +- requirements/requirements_nlp.txt | 1 + .../convert_mamba2_pyt_to_nemo.py | 159 ++++ tutorials/llm/mamba/mamba.rst | 301 ++++++++ 13 files changed, 2236 insertions(+), 3 deletions(-) create mode 100644 examples/nlp/language_modeling/conf/megatron_mamba_config.yaml create mode 100644 examples/nlp/language_modeling/mamba_change_num_partition.py create mode 100644 examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml create mode 100644 examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml create mode 100644 
examples/nlp/language_modeling/tuning/megatron_mamba_finetuning.py create mode 100644 examples/nlp/language_modeling/tuning/megatron_mamba_generate.py create mode 100644 nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py create mode 100644 nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py create mode 100644 scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py create mode 100644 tutorials/llm/mamba/mamba.rst diff --git a/examples/nlp/language_modeling/conf/megatron_mamba_config.yaml b/examples/nlp/language_modeling/conf/megatron_mamba_config.yaml new file mode 100644 index 000000000000..f4f37d7c4ce0 --- /dev/null +++ b/examples/nlp/language_modeling/conf/megatron_mamba_config.yaml @@ -0,0 +1,191 @@ +name: megatron_mamba +restore_from_path: null # used when starting from a .nemo file + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice we don't usually train for more than 1 epoch. + max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + benchmark: False + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: megatron_mamba + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 10 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + filename: 'megatron_mamba--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + + +model: + restore_from_path: null + # model parallelism + mcore_gpt: True + micro_batch_size: 1 + global_batch_size: 8 + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + expert_model_parallel_size: 1 # expert model parallelism + hybrid_override_pattern: null + vocab_size: 256000 + # model architecture + encoder_seq_length: 4096 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: 'none' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'kerple' , 'xpos', 'sandwich'] xpos and sandwich are experimental. + num_layers: 56 + gated_linear_unit: False + add_bias_linear: False + num_query_groups: 8 + mamba_ssm_ngroups: 8 + attention_dropout: 0.0 + hidden_dropout: 0.0 + hidden_size: 4096 + ffn_hidden_size: 14336 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 32 + transformer_block_type: pre_ln + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: RMSNorm + layernorm_epsilon: 1e-5 + num_moe_experts: 16 + moe_router_topk: 2 + moe_aux_loss_coeff: 0.001 + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. 
+ pre_process: True # add embedding + post_process: True # add pooler + megatron_legacy: False + persist_layer_norm: True + + tokenizer: + library: 'huggingface' + type: 'EleutherAI/gpt-neox-20b' + model: null + vocab_file: null + merge_file: null + sentencepiece_legacy: False + use_fast: True + + # Distributed checkpoint setup + dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format. + dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU + dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + grad_allreduce_chunk_size_mb: 125 + + + # Fusion + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce. Only used with O2 and no pipeline parallelism.. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism and O2. + bias_activation_fusion: False # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages. + apply_rope_fusion: True # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope + + + # miscellaneous + seed: 1234 + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + ## Activation Checkpointing + # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. + # These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + # 'full' will checkpoint the entire transformer layer. + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_recurrent: False # If set to True, the checkpointing is only done for rglru and conv1d and not for attention and mlp layers + activations_checkpoint_method: null # 'uniform', 'block' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model. + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null + # when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. 
Increase to save more memory. + # when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage. + num_micro_batches_with_partial_activation_checkpoints: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed + # and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is + # set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint + # per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'. + # This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage. + activations_checkpoint_layers_per_pipeline: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later + # pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than + # stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage + # uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints', + # this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path. + sequence_parallel: False + + data: + # Path to data must be specified by the user. + # can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]", + # Or see example below: + # data_prefix: + # - .5 + # - /raid/data/pile/my-gpt3_00_text_document + # - .5 + # - /raid/data/pile/my-gpt3_01_text_document + data_prefix: [1.0, /path/to/data] + index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix + data_impl: mmap + splits_string: 900,50,50 + seq_length: ${model.encoder_seq_length} + skip_warmup: True + num_workers: 0 + dataloader_type: single # cyclic, LDDL + reset_position_ids: False # Reset position ids after end-of-document token + reset_attention_mask: False # Reset attention mask after end-of-document token + eod_mask_loss: False # Mask loss for the end of document tokens + masked_lm_prob: 0.15 # Probability of replacing a token with mask. + short_seq_prob: 0.1 # Probability of producing a short sequence. + ceil_to_power_2: True + get_attention_mask_from_fusion: True + pad_to_max_length: True + + optim: + name: distributed_fused_adam + lr: 2e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 500 + constant_steps: 50000 + min_lr: 2e-5 diff --git a/examples/nlp/language_modeling/mamba_change_num_partition.py b/examples/nlp/language_modeling/mamba_change_num_partition.py new file mode 100644 index 000000000000..bc76b3215a74 --- /dev/null +++ b/examples/nlp/language_modeling/mamba_change_num_partition.py @@ -0,0 +1,696 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import tarfile +import tempfile +from argparse import ArgumentParser + +import torch +from omegaconf import open_dict +from pytorch_lightning import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel +from nemo.collections.nlp.parts.nlp_overrides import ( + NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, + GradScaler, + MegatronHalfPrecisionPlugin, + NLPDDPStrategy, + NLPSaveRestoreConnector, + PipelineMixedPrecisionPlugin, +) +from nemo.utils import logging +from nemo.utils.app_state import AppState + +""" +Usage: + +### Tensor Parallelism conversion ### + +# Megatron Mamba +python /opt/NeMo/examples/nlp/language_modeling/mamba_change_num_partition.py \ + --model_file= \ + --target_file= \ + --tensor_model_parallel_size=1 \ + --target_tensor_model_parallel_size=4 \ + --precision=bf16 \ + --d-model=4096 \ + --mamba-version=2 \ + --mamba2-n-groups=8 \ + --mamba2-head-dim=64 +""" + +tp_split_dim = { + 'word_embeddings.weight': 0, + 'norm.weight': -1, + 'final_norm.weight': -1, + 'output_layer.weight': 0, + # mamba1/2 + 'A_log': 0, + 'D': 0, + 'dt_bias': 0, + 'in_proj.weight': 0, + 'conv1d.weight': 0, + 'conv1d.bias': 0, + 'x_proj.weight': 1, + 'dt_proj.weight': 0, + 'dt_proj.bias': 0, + 'out_proj.weight': 1, + 'mixer.norm.weight': 0, + # mlp + 'linear_fc1.layer_norm_weight': -1, + 'linear_fc1.weight': 0, + 'linear_fc2.weight': 1, + # attention + 'self_attention.linear_proj.weight': 1, + 'self_attention.linear_qkv.layer_norm_weight': -1, + 'self_attention.linear_qkv.weight': 0, +} + + +def get_split_dim(tensor_name): + # norm.weight will match tensor_name of mixer.norm.weight and norm.weight, need to distinguish + if 'norm.weight' in tensor_name: + if 'mixer.norm.weight' in tensor_name: + return tp_split_dim['mixer.norm.weight'] + else: + return tp_split_dim['norm.weight'] + + for key in tp_split_dim.keys(): + if key in tensor_name: + return tp_split_dim[key] + raise Exception("Unknown tensor name {}".format(tensor_name)) + + +def split_tensor_for_tp(params, key, dim, tensor): + + tp_size = params.target_tensor_model_parallel_size + tensor_sliced = [] + if dim == -1: + tensor_sliced = [tensor for i in range(tp_size)] + else: + if 'mixer.in_proj.weight' in key and params.mamba_version == 1: + x, z = torch.split(tensor, [params.mamba_d_inner, params.mamba_d_inner], dim=dim) + x_sliced = torch.chunk(x, tp_size, dim=dim) + z_sliced = torch.chunk(z, tp_size, dim=dim) + for x, z in zip(x_sliced, z_sliced): + tensor_sliced.append(torch.cat((x, z), dim=dim)) + + elif 'mixer.in_proj.weight' in key and params.mamba_version == 2: + x, z, B, C, dt = torch.split( + tensor, + [ + params.mamba_d_inner, + params.mamba_d_inner, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_heads, + ], + dim=dim, + ) + B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-1])) + C = torch.reshape(C, (-1, params.mamba_d_state, 
C.shape[-1])) + + B_sliced = torch.chunk(B, tp_size, dim=dim) + C_sliced = torch.chunk(C, tp_size, dim=dim) + x_sliced = torch.chunk(x, tp_size, dim=dim) + z_sliced = torch.chunk(z, tp_size, dim=dim) + dt_sliced = torch.chunk(dt, tp_size, dim=dim) + + tensor_sliced = [] + for x, z, B, C, dt in zip(x_sliced, z_sliced, B_sliced, C_sliced, dt_sliced): + tensor_sliced.append(torch.cat((x, z, B.flatten(0, 1), C.flatten(0, 1), dt), dim=dim)) + + elif 'mixer.conv1d' in key and params.mamba_version == 2: + x, B, C = torch.split( + tensor, + [ + params.mamba_d_inner, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_groups * params.mamba_d_state, + ], + dim=dim, + ) + if 'weight' in key: + B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-2], B.shape[-1])) + C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-2], C.shape[-1])) + elif 'bias' in key: + B = torch.reshape(B, (-1, params.mamba_d_state)) + C = torch.reshape(C, (-1, params.mamba_d_state)) + else: + raise Exception("Unknown key") + + B_sliced = torch.chunk(B, tp_size, dim=dim) + C_sliced = torch.chunk(C, tp_size, dim=dim) + x_sliced = torch.chunk(x, tp_size, dim=dim) + + tensor_sliced = [] + for x, B, C in zip(x_sliced, B_sliced, C_sliced): + tensor_sliced.append(torch.cat((x, B.flatten(0, 1), C.flatten(0, 1)), dim=dim)) + elif '_extra_state' in key: + pass + else: + tensor_sliced = torch.chunk(tensor, tp_size, dim=dim) + + return tensor_sliced + + +################# +### Utilities ### +################# + + +def force_cpu_model(cfg): + with open_dict(cfg): + # temporarily set to cpu + original_cpu_init = cfg.get('use_cpu_initialization', False) + if 'megatron_amp_O2' in cfg: + amp_o2_key = 'megatron_amp_O2' + original_amp_o2 = cfg.megatron_amp_O2 + elif 'megatron_amp_02' in cfg: + amp_o2_key = 'megatron_amp_02' + original_amp_o2 = cfg.megatron_amp_02 + else: + amp_o2_key, original_amp_o2 = None, None + + # Set new values + cfg.use_cpu_initialization = True + if amp_o2_key is not None: + cfg[amp_o2_key] = False + + # Disable sequence parallelism - Not disabling this gives error when converting the the model to TP=1 + original_sequence_parallel = cfg.get('sequence_parallel', None) + cfg.sequence_parallel = False + + # Setup restore dict + restore_dict = {'use_cpu_initialization': original_cpu_init} # 'megatron_amp_O2': original_amp_o2 + if amp_o2_key is not None: + restore_dict[amp_o2_key] = original_amp_o2 + if original_sequence_parallel is not None: + restore_dict['sequence_parallel'] = original_sequence_parallel + + return cfg, restore_dict + + +def restore_model_config(cfg, original_dict): + with open_dict(cfg): + for key, val in original_dict.items(): + logging.info(f"Restoring model config key ({key}) from {cfg[key]} to original value of {val}") + cfg[key] = val + return cfg + + +def write_tp_pp_split(model, splits, app_state, tp_size, pp_rank, write_path): + """ + Function to write the given TP PP split to NeMo File. + + Save each of the TP ranks in reverse order + This is done so that the last PP rank will save the last TP rank only after all other PP TP ranks are saved + The final rank will then save a new NeMo file with all other ranks inside. + + Args: + model: The model corresponding to the current TP PP split. Contains partial parameters. + splits: Nested List of tensors containing the TP splits of the current model given current PP rank. + Indexed as splits[idx][tp_rank]. + app_state: AppState object. + tp_size: The global tensor-parallel size of the final model. 
+        pp_rank: The local pipeline parallel rank of the final model.
+        write_path: The path to save the NeMo file.
+    """
+    for tp_rank in range(tp_size - 1, -1, -1):
+        app_state.pipeline_model_parallel_rank = pp_rank
+        app_state.tensor_model_parallel_rank = tp_rank
+
+        idx = 0
+        for name, param in model.named_parameters():
+            split_val = splits[idx][tp_rank].clone()
+
+            if param.shape != split_val.shape:
+                raise RuntimeError(
+                    f"Can not handle parameter {name}, required shape: {param.shape}, split shape: {split_val.shape}."
+                )
+
+            param.data = split_val
+            idx += 1
+
+        if write_path is not None:
+            logging.info(f"Writing pp rank {pp_rank} tp rank {tp_rank} to file {write_path}")
+            model.save_to(write_path)
+
+
+##################
+### Converters ###
+##################
+
+
+def split_tp_partition_only(args, model, original_model, tp_size, write_path=None, megatron_legacy=False):
+
+    if tp_size < 1:
+        raise ValueError("TP size must be >= 1.")
+
+    app_state = AppState()
+    app_state.data_parallel_rank = 0
+    app_state.pipeline_model_parallel_size = 1
+    app_state.tensor_model_parallel_size = tp_size
+    app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size
+
+    app_state.pipeline_model_parallel_rank = 0
+    app_state.tensor_model_parallel_rank = tp_size - 1
+
+    idx = 0
+    splits = []
+
+    for ii, (key, original_tensor) in enumerate(original_model.model.state_dict().items()):
+        try:
+            layer_num = int(re.findall(r'\d+', key)[0])
+            new_key = key.replace(str(layer_num), str(layer_num), 1)
+        except:
+            new_key = key
+
+        if '_extra_state' not in new_key:
+            split_dim = get_split_dim(new_key)
+            split = split_tensor_for_tp(args, new_key, split_dim, original_tensor)
+
+            splits.append(split)
+            idx += 1
+
+    # Save each of the TP ranks in reverse order
+    # This is done so that the last PP rank will save the last TP rank only after all other PP TP ranks are saved
+    # The final rank will then save a new NeMo file with all other ranks inside. 
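+    # Note: `splits` is indexed as splits[param_idx][tp_rank]; this TP-only conversion runs with a
+    # single pipeline stage, so all partitions are written below with pp_rank=0.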
+    write_tp_pp_split(model, splits, app_state, tp_size, pp_rank=0, write_path=write_path)
+
+    with tarfile.open(write_path, 'r') as tar:
+        # Extract all contents to the specified path
+        tar.extractall(path=os.path.dirname(write_path))
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument("--model_file", type=str, default=None, required=False, help="Path to source .nemo file")
+    parser.add_argument("--target_file", type=str, required=True, help="Path to write target .nemo file")
+    parser.add_argument(
+        "--tensor_model_parallel_size", type=int, default=-1, required=False, help="TP size of source model"
+    )
+    parser.add_argument("--target_tensor_model_parallel_size", type=int, required=True, help="TP size of target model")
+    parser.add_argument(
+        '--pipeline_model_parallel_size', type=int, default=1, required=False, help='PP size of source model'
+    )
+    parser.add_argument(
+        '--target_pipeline_model_parallel_size', type=int, required=False, default=1, help='PP size of target model'
+    )
+    parser.add_argument(
+        '--target_pipeline_model_parallel_split_rank', type=int, default=0, help='PP rank to split for Enc-Dec models'
+    )
+    parser.add_argument(
+        '--virtual_pipeline_model_parallel_size', type=int, default=None, help='Virtual Pipeline parallelism size'
+    )
+    parser.add_argument(
+        '--ckpt_name', type=str, default=None, help='Checkpoint name to load from for Virtual Parallel'
+    )
+    parser.add_argument(
+        "--model_class",
+        type=str,
+        default="nemo.collections.nlp.models.language_modeling.megatron_mamba_model.MegatronMambaModel",
+        help="NeMo model class. This script should support all NeMo megatron models that use Tensor Parallel",
+    )
+    parser.add_argument("--precision", default=16, help="PyTorch Lightning Trainer precision flag")
+    parser.add_argument('--num_gpu_per_node', default=8, type=int, help='Number of GPUs per node')
+    parser.add_argument(
+        "--megatron_legacy",
+        action="store_true",
+        help="Converter for legacy megatron models that have different q,k,v weight splits",
+    )
+    parser.add_argument(
+        "--tokenizer_model_path",
+        type=str,
+        required=False,
+        default=None,
+        help="Path to the tokenizer model path if your model uses a tokenizer model as an artifact. 
This is needed if your model uses a sentencepiece tokenizer.", + ) + parser.add_argument('--hparams_file', type=str, default=None, help='Path to hparams file from PTL training') + parser.add_argument( + '--tp_conversion_only', default=True, action='store_true', help='Only convert TP model to TP model' + ) + parser.add_argument('--model_extracted_dir', type=str, default=None, help='Path to pre-extracted model directory') + + parser.add_argument('--d-model', type=int, default=4096) + parser.add_argument('--mamba-version', type=int, default=2) + parser.add_argument('--mamba-d-state', type=int, default=128) + parser.add_argument('--mamba2-n-groups', type=int, default=8) + parser.add_argument('--mamba2-head-dim', type=int, default=64) + + args = parser.parse_args() + + args.mamba_d_inner = args.d_model * 2 + args.mamba2_n_heads = args.mamba_d_inner // args.mamba2_head_dim + + precision = args.precision + num_gpu_per_node = int(args.num_gpu_per_node) + if args.precision in ["32", "16"]: + precision = int(float(args.precision)) + + if precision in ["bf16", "bf16-mixed"]: + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + pass + else: + logging.warning("BF16 is not supported on this device. Using FP16 instead.") + precision = precision[2:] + + if precision == 32: + dtype = torch.float32 + elif precision in [16, "16", "16-mixed"]: + dtype = torch.float16 + elif precision in ["bf16", "bf16-mixed"]: + dtype = torch.bfloat16 + else: + dtype = torch.float32 # fallback + + # Built target directory if it does not exist + target_dir = os.path.split(args.target_file)[0] + if not os.path.exists(target_dir): + os.makedirs(target_dir, exist_ok=True) + + tp_size = args.tensor_model_parallel_size + tgt_tp_size = args.target_tensor_model_parallel_size + pp_size = args.pipeline_model_parallel_size + tgt_pp_size = args.target_pipeline_model_parallel_size + pipeline_model_parallel_split_rank = args.target_pipeline_model_parallel_split_rank + vp_size = args.virtual_pipeline_model_parallel_size + if vp_size is None: + vp_size = 1 + + convert_vp = vp_size > 1 + if convert_vp: + from megatron.core import parallel_state + + parallel_state.set_virtual_pipeline_model_parallel_world_size(vp_size) + + hparams_filepath = args.hparams_file + if hparams_filepath is None: + logging.warning( + '\n\n\n!!!!!!!!!\n' + 'You are converting a model with virtual pipeline parallelism enabled, \n' + 'but have not passed `hparams_file` argument. 
\n' + 'This will cause each ckpt file to be temporarily laoded onto GPU memory!\n\n' + 'It is highly recommended to pass `hparams_file` argument to avoid this.\n' + ) + + # Import the class of the model + + if args.model_file is None and args.model_extracted_dir is None: + raise ValueError("Cannot pass model_file and model_extracted_dir as None at the same time.") + + tmp_cfg = MegatronMambaModel.restore_from( + restore_path=args.model_file, + trainer=Trainer(devices=1, strategy=NLPDDPStrategy(), accelerator="cpu", precision=precision), + map_location=torch.device("cpu"), + return_config=True, + ) + plugins = [] + if precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']: + scaler = None + if precision in [16, '16', '16-mixed']: + scaler = GradScaler( + init_scale=tmp_cfg.get('native_amp_init_scale', 2**32), + growth_interval=tmp_cfg.get('native_amp_growth_interval', 1000), + hysteresis=tmp_cfg.get('hysteresis', 2), + ) + # MixedPrecisionPlugin in PTL >= 2.0 requires precision to be 16-mixed or bf16-mixed + plugin_precision = '16-mixed' + else: + plugin_precision = 'bf16-mixed' + + if tmp_cfg.get('megatron_amp_O2', False): + plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)) + else: + plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)) + # Set precision None after precision plugins are created as PTL >= 2.1 does not allow both + # precision plugins and precision to exist + trainer = Trainer(plugins=plugins, devices=1, strategy=NLPDDPStrategy(), accelerator="cpu") + + if tp_size < 0 or pp_size < 0: + logging.info(f"Loading model config from {args.model_file} to get TP and PP size") + model_config_internal = MegatronMambaModel.restore_from( + restore_path=args.model_file, + trainer=trainer, + map_location=torch.device("cpu"), + return_config=True, + ) + + tp_size = model_config_internal.get('tensor_model_parallel_size', 1) + pp_size = model_config_internal.get('pipeline_model_parallel_size', 1) + + # Check if TP conversion only + tp_conversion_only = args.tp_conversion_only + if tp_conversion_only: + logging.info("Converting TP model to TP model only") + + if pp_size > 1: + raise ValueError("Provided `--tp_conversion_only` but `--pipeline_model_parallel_size` > 1") + + if tgt_pp_size > 1: + raise ValueError("Provided `--tp_conversion_only` but `--target_pipeline_model_parallel_size` > 1") + + if pipeline_model_parallel_split_rank > 0: + raise ValueError("Provided `--tp_conversion_only` but `--target_pipeline_model_parallel_split_rank` > 0") + + # Force PP size to 1 + pp_size = 1 + tgt_pp_size = 1 + pipeline_model_parallel_split_rank = 0 + + if vp_size is None or vp_size < 0: + vp_size = 1 + + app_state = AppState() + app_state.data_parallel_rank = 0 + app_state.pipeline_model_parallel_size = pp_size + app_state.tensor_model_parallel_size = tp_size + + if vp_size > 1: + app_state.virtual_pipeline_model_parallel_size = vp_size + app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size + + world_size = pp_size * tp_size # pseudo world size for simulating load of a specific rank on a single gpu + + app_state.tensor_model_parallel_rank = 0 + app_state.pipeline_model_parallel_rank = 0 + + # Extract tokenizer artifact from the model to temp directory + logging.info("Extracting tokenizer artifact from NeMo file...") + temp_dir = tempfile.mkdtemp() + tokenizer_model_path = None + with tarfile.open(args.model_file, "r") as tar: + for member in 
tar.getmembers(): + if '.model' in member.name: + extracted_file = tar.extractfile(member) + extracted_file_path = os.path.join(temp_dir, member.name) + + if tokenizer_model_path is None: + logging.info(f"Found tokenizer. Extracting {member.name} to {extracted_file_path}") + + tokenizer_model_path = extracted_file_path + with open(extracted_file_path, "wb") as f: + f.write(extracted_file.read()) + else: + if args.tokenizer_model_path is None: + logging.warning( + f"\n\nFound multiple tokenizer artifacts in the model file.\n" + f"Using only {tokenizer_model_path}.\n" + f"If this is incorrect, manually pass the correct tokenizer using " + f"`--tokenizer_model_path`.\n\n" + ) + + # If input model has TP > 1 or PP > 1 + # Reconstruct the model to have TP = 1 and PP = 1 + # Note that this is a forward loop that will process PP [0..N] TP [0..M] in sequential order. + + # If input model has TP = 1 and PP = 1 + app_state.model_parallel_size = 1 + + save_restore_connector = NLPSaveRestoreConnector() + + if args.model_extracted_dir is not None: + logging.info(f"Using extracted model directory: {args.model_extracted_dir}") + save_restore_connector.model_extracted_dir = args.model_extracted_dir + + if args.model_file is not None: + model_filepath = args.model_file + else: + model_filepath = args.model_extracted_dir + + tmp_cfg = MegatronMambaModel.restore_from( + restore_path=model_filepath, + trainer=trainer, + map_location=torch.device("cpu"), + save_restore_connector=save_restore_connector, + return_config=True, + ) + + tmp_cfg, restore_dict = force_cpu_model(tmp_cfg) + + model = MegatronMambaModel.restore_from( + restore_path=model_filepath, + trainer=trainer, + map_location=torch.device("cpu"), + save_restore_connector=save_restore_connector, + override_config_path=tmp_cfg, + ) + + original_model = MegatronMambaModel.restore_from( + restore_path=model_filepath, + trainer=trainer, + map_location=torch.device("cpu"), + save_restore_connector=save_restore_connector, + override_config_path=tmp_cfg, + ) + original_model = original_model.to('cpu') + original_model._save_restore_connector = NLPSaveRestoreConnector() + original_model.freeze() + original_model.to(dtype=dtype) + + model.to(dtype=dtype) + + restore_model_config(model.cfg, restore_dict) + + # If target model has TP > 1 or PP > 1 + if tgt_pp_size > 1 or tgt_tp_size > 1: + + # Preserve the TP 1 PP 1 model parameters and names + global_params = [] + global_params.append([p for n, p in model.named_parameters()]) # params + global_params.append([n for n, p in model.named_parameters()]) # names + + logging.debug("Global parameters:") + for idx, (name, p) in enumerate(zip(global_params[1], global_params[0])): + logging.debug(f"{name} - {p.shape}") + + logging.info(f"TP 1 PP 1 Number of Parameters : {len(global_params[0])}") + + world_size = ( + tgt_pp_size * tgt_tp_size + ) # pseudo world size for simulating load of a specific rank on a single gpu + new_global_batch_size = model.cfg.micro_batch_size * world_size + old_global_batch_size = model.cfg.get('global_batch_size', model.cfg.micro_batch_size) + + global_offset = len(global_params[0]) - 1 # -1 cause this indexes the array, range [0, L-1] + logging.info(f"Final layer offset for parameters: {global_offset}") + + for pp_rank in range(tgt_pp_size - 1, -1, -1): # reverse order + + with open_dict(model.cfg): + model.cfg.pipeline_model_parallel_size = tgt_pp_size + model.cfg.tensor_model_parallel_size = tgt_tp_size + + if 'pipeline_model_parallel_split_rank' in model.cfg: + if 
pipeline_model_parallel_split_rank > 0: + model.cfg.pipeline_model_parallel_split_rank = pipeline_model_parallel_split_rank + elif pp_size > 1: + logging.warning( + f"Model config has `pipeline_model_parallel_split_rank` set to " + f"{model.cfg.pipeline_model_parallel_split_rank} and target PP " + f"size is {tgt_pp_size}. " + f"Provided `pipeline_model_parallel_split_rank` is " + f"{pipeline_model_parallel_split_rank}. " + f"Be careful that the model config is correct " + f"if encoder-decoder models are being converted." + ) + + model.cfg.global_batch_size = old_global_batch_size # Used for restoration + + # Override flag that forces Model to use AppState instead of Trainer + # to determine the world size, global and local rank + # Used for simulating load of a specific rank on a single gpu + os.environ[NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE] = "true" + + # Compute the global rank + global_rank = ( + pp_rank * tgt_tp_size + 0 + ) # tp_rank = 0 needed just for modules, all TP will be merged to this PP rank + + # Update AppState + app_state.world_size = world_size + app_state.global_rank = global_rank + app_state.local_rank = global_rank % num_gpu_per_node + app_state.pipeline_model_parallel_size = tgt_pp_size + app_state.tensor_model_parallel_size = tgt_tp_size + app_state.model_parallel_size = ( + app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size + ) + + trainer = Trainer(plugins=plugins, devices=1, strategy=NLPDDPStrategy(), accelerator="cpu") + if args.tokenizer_model_path is not None: + with open_dict(model.cfg): + model.cfg.tokenizer.model = args.tokenizer_model_path + + else: + if tokenizer_model_path is None: + logging.warning("Could not extract tokenizer model file from checkpoint.") + + else: + # Extract tokenizer info + with open_dict(model.cfg): + model.cfg.tokenizer.model = tokenizer_model_path + + model.cfg, restore_dict = force_cpu_model(model.cfg) + + from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR + + _GLOBAL_NUM_MICROBATCHES_CALCULATOR.current_global_batch_size = 1 + _GLOBAL_NUM_MICROBATCHES_CALCULATOR.current_micro_batch_size = 1 + model.cfg.global_batch_size = 1 + model.cfg.micro_batch_size = 1 + + model = MegatronMambaModel(model.cfg, trainer) + model = model.to('cpu') + model._save_restore_connector = NLPSaveRestoreConnector() + model.freeze() + model.to(dtype=dtype) + + restore_model_config(model.cfg, restore_dict) + + # Update global batch size + if old_global_batch_size % new_global_batch_size != 0 or old_global_batch_size < new_global_batch_size: + logging.info( + f"Global batch size {old_global_batch_size} is not divisible by new global batch size {new_global_batch_size}." + f" The model config will be updated with new global batch size {new_global_batch_size}." 
+ ) + with open_dict(model.cfg): + model.cfg.global_batch_size = new_global_batch_size + + logging.info(f"Global rank: {global_rank} Local rank: {app_state.local_rank} World size: {world_size}") + logging.info(f"PP rank: {pp_rank} TP rank: {0}") + logging.info(f"TP 1 PP 1 Number of Layers : {len(global_params[0])}") + logging.info(f"Remaining layer offset for parameters: {global_offset}") + logging.info("\n") + + # Special case for TP conversion only mode + if tp_conversion_only: + logging.info(f"Skipping PP split due to flag `--tp_conversion_only`") + split_tp_partition_only( + args, model, original_model, tgt_tp_size, args.target_file, args.megatron_legacy + ) + break + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml new file mode 100644 index 000000000000..3684b61bb186 --- /dev/null +++ b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml @@ -0,0 +1,315 @@ +name: megatron_mamba +restore_from_path: ${model.restore_from_path} # used when starting from a .nemo file + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 9999 + max_steps: 10000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 1 # frequency with which training steps are logged + val_check_interval: 200 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + limit_val_batches: 1024 + limit_test_batches: 500 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: True + wandb_logger_kwargs: + project: griffin + name: sft-test + resume_if_exists: False + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: True + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + restore_from_path: null + # model parallelism + mcore_gpt: True + micro_batch_size: 1 + global_batch_size: 8 + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + expert_model_parallel_size: 1 # expert model parallelism + + vocab_size: 65536 + # model architecture + encoder_seq_length: 4096 + hybrid_override_pattern: null + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: 'none' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'kerple' , 'xpos', 'sandwich'] xpos and sandwich are experimental. 
+ num_layers: 64 + gated_linear_unit: False + add_bias_linear: False + num_query_groups: 8 + ngroups_mamba: 8 + attention_dropout: 0.0 + hidden_dropout: 0.0 + hidden_size: 4096 + ffn_hidden_size: 14336 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 32 + transformer_block_type: pre_ln + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: RMSNorm + layernorm_epsilon: 1e-5 + num_moe_experts: 16 + moe_router_topk: 2 + moe_aux_loss_coeff: 0.001 + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. + pre_process: True # add embedding + post_process: True # add pooler + megatron_legacy: False + persist_layer_norm: True + + + # mixed-precision + attention_softmax_in_fp32: False + + # Distributed checkpoint setup + dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format. + dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU + dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint + + + tokenizer: + library: 'huggingface' + type: 'EleutherAI/gpt-neox-20b' + model: null + vocab_file: null + merge_file: null + sentencepiece_legacy: False + use_fast: True + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + grad_allreduce_chunk_size_mb: 125 + + # Fusion + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce. Only used with O2 and no pipeline parallelism.. + gradient_accumulation_fusion: True # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism and O2. + bias_activation_fusion: False # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages. + apply_rope_fusion: True # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope + + # miscellaneous + seed: 1234 + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + ## Activation Checkpointing + # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. + # These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). 
+ # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + # 'full' will checkpoint the entire transformer layer. + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model. + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null + # when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory. + # when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage. + num_micro_batches_with_partial_activation_checkpoints: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed + # and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is + # set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint + # per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'. + # This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage. + activations_checkpoint_layers_per_pipeline: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later + # pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than + # stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage + # uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints', + # this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path. + sequence_parallel: False + + peft: + peft_scheme: "lora" # can be either adapter,ia3, lora, or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. 
null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['all'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) + adapter_dim: 32 + alpha: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + ia3_tuning: + layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + + selective_tuning: + tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre + + + data: + train_ds: + # Example of how to specify paths to multiple datasets + # file_names: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} + file_names: null # Path to a list of JSONL files corresponding to the source data. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + memmap_workers: 2 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + # Example of how to specify concat_sampling_probabilities + # concat_sampling_probabilities: + # - 0.5 + # - 0.25 + # - 0.25 + concat_sampling_probabilities: [1.0] # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + label_key: 'output' + add_eos: True + add_sep: False + add_bos: True + truncation_field: "input" # # Can be multiple keys separated with ',' Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "{input} {output}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + ceil_to_power_2: True + get_attention_mask_from_fusion: True + pad_to_max_length: True + validation_ds: + file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. 
+ global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + label_key: ${model.data.train_ds.label_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: ${model.data.train_ds.truncation_field} # Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + ceil_to_power_2: True + get_attention_mask_from_fusion: True + pad_to_max_length: True + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + test_ds: + file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + label_key: ${model.data.train_ds.label_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: ${model.data.train_ds.truncation_field} # Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + ceil_to_power_2: True + get_attention_mask_from_fusion: True + pad_to_max_length: True + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. 
+ num_classes: null + + optim: + name: distributed_fused_adam + lr: 2e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 500 + constant_steps: 50000 + min_lr: 2e-5 diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml new file mode 100644 index 000000000000..2d34aefffc7e --- /dev/null +++ b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml @@ -0,0 +1,298 @@ +name: megatron_mamba +restore_from_path: ${model.restore_from_path} # used when starting from a .nemo file + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice we don't usually train for more than 1 epoch. + max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + benchmark: False + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: megatron_mamba + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 10 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + filename: 'megatron_mamba--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + +model: + restore_from_path: null + # model parallelism + mcore_gpt: True + micro_batch_size: 2 + global_batch_size: 2 + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + expert_model_parallel_size: 1 # expert model parallelism + hybrid_override_pattern: null + vocab_size: 65536 + # model architecture + encoder_seq_length: 4096 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: 'none' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'kerple' , 'xpos', 'sandwich'] xpos and sandwich are experimental. + num_layers: 64 + gated_linear_unit: False + num_query_groups: 8 + ngroups_mamba: 8 + attention_dropout: 0.0 + hidden_dropout: 0.0 + hidden_size: 4096 + ffn_hidden_size: 14336 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 32 + transformer_block_type: pre_ln + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: RMSNorm + layernorm_epsilon: 1e-5 + num_moe_experts: 16 + moe_router_topk: 2 + moe_aux_loss_coeff: 0.001 + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. 
+ pre_process: True # add embedding + post_process: True # add pooler + megatron_legacy: False + persist_layer_norm: True + add_bias_linear: False + + answer_only_loss: True + + tokenizer: + library: 'huggingface' + type: 'EleutherAI/gpt-neox-20b' + model: null + vocab_file: null + merge_file: null + sentencepiece_legacy: False + use_fast: True + + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + grad_allreduce_chunk_size_mb: 125 + + # Fusion + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce. Only used with O2 and no pipeline parallelism.. + gradient_accumulation_fusion: True # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism and O2. + bias_activation_fusion: False # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages. + apply_rope_fusion: True # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope + + + # miscellaneous + seed: 1234 + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + ## Activation Checkpointing + # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. + # These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + # 'full' will checkpoint the entire transformer layer. + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_recurrent: False # If set to True, the checkpointing is only done for rglru and conv1d and not for attention and mlp layers + activations_checkpoint_method: null # 'uniform', 'block' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model. + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null + # when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory. + # when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage. + num_micro_batches_with_partial_activation_checkpoints: null + # This feature is valid only when used with pipeline-model-parallelism. 
+ # When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed + # and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is + # set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint + # per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'. + # This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage. + activations_checkpoint_layers_per_pipeline: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later + # pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than + # stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage + # uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints', + # this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path. + sequence_parallel: False + + peft: + peft_scheme: null # can be either adapter,ia3, lora, or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['all'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) + adapter_dim: 32 + alpha: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + ia3_tuning: + layer_selection: null # selects in which layers to add ia3 adapters. e.g. 
[1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + + selective_tuning: + tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre + + data: + test_ds: + file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: ??? # Names of the corresponding datasets used to log metrics. + global_batch_size: 1 + micro_batch_size: 1 + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: 'input' + label_key: 'output' + add_eos: True + add_sep: False + add_bos: True + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "input" # Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "{input} {output}" + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + ceil_to_power_2: True + get_attention_mask_from_fusion: True + pad_to_max_length: True + + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + +inference: + greedy: True # Whether or not to use sampling ; use greedy decoding otherwise + top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: 1.0 # sampling temperature + all_probs: False # whether return the log prob for all the tokens in vocab + repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. + min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. 
+ compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False + outfile_path: output.txt + compute_attention_mask: True + +# server-related configs +server: False # whether launch the API server +port: 5555 # the port number for the inference server +web_server: False # whether launch the web inference server +share: True # whether create a public URL +username: test # user name for web client +password: test2 # password for web client +web_port: 9889 # the port number of the web server 1058 +chat: False # use the chat interface +chatbot_config: + value: False # whether to inject the value attributes + attributes: + - name: Quality + min: 0 + max: 4 + key: quality + type: int + default: 4 + - name: Toxicity + min: 0 + max: 4 + key: toxcity + type: int + default: 0 + - name: Humor + min: 0 + max: 4 + key: humor + type: int + default: 0 + - name: Creativity + min: 0 + max: 4 + key: creativity + type: int + default: 0 + - name: Violence + min: 0 + max: 4 + key: violence + type: int + default: 0 + - name: Helpfulness + min: 0 + max: 4 + key: helpfulness + type: int + default: 4 + - name: Not_Appropriate + min: 0 + max: 4 + key: not_appropriate + type: int + default: 0 + - name: Language + choices: ['ar', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en', 'eo', 'es', 'eu', 'fa', 'fi', 'fr', 'gl', 'he', 'hu', 'id', 'it', 'ja', 'ko', 'nb', 'nl', 'pl', 'pt', 'ro', 'ru', 'sk', 'sv', 'th', 'tr', 'uk', 'vi', 'zh'] + key: lang + type: list + default: en + + user: User + assistant: Assistant + system: "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n" \ No newline at end of file diff --git a/examples/nlp/language_modeling/tuning/megatron_mamba_finetuning.py b/examples/nlp/language_modeling/tuning/megatron_mamba_finetuning.py new file mode 100644 index 000000000000..0613ef486ec3 --- /dev/null +++ b/examples/nlp/language_modeling/tuning/megatron_mamba_finetuning.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
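+
+# Example usage (illustrative sketch only; the checkpoint and dataset paths below are
+# placeholders, not files provided by this PR). The script is launched with Hydra-style
+# overrides of the fine-tuning config:
+#
+#   python megatron_mamba_finetuning.py \
+#       model.restore_from_path=/path/to/mamba2.nemo \
+#       model.data.train_ds.file_names=[/path/to/train.jsonl] \
+#       model.data.validation_ds.file_names=[/path/to/val.jsonl]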
+
+import torch.multiprocessing as mp
+from omegaconf.omegaconf import OmegaConf
+
+from nemo.collections.nlp.models.language_modeling.megatron_mamba_sft_model import MegatronMambaSFTModel
+from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder
+from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP
+from nemo.core.config import hydra_runner
+from nemo.utils import logging
+from nemo.utils.exp_manager import exp_manager
+
+mp.set_start_method("spawn", force=True)
+
+
+@hydra_runner(config_path="conf", config_name="megatron_mamba_finetuning_config")
+def main(cfg) -> None:
+
+    logging.info("\n\n************** Experiment configuration ***********")
+    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
+
+    precision = cfg.trainer.precision
+    trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer()
+    # Restore the precision value after Trainer is built.
+    cfg.trainer.precision = precision
+    exp_manager(trainer, cfg.exp_manager)
+
+    model_cfg = MegatronMambaSFTModel.merge_cfg_with(cfg.model.restore_from_path, cfg)
+    model = MegatronMambaSFTModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer)
+
+    peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme]
+
+    if cfg.model.peft.restore_from_path is not None:
+        # Initialize PEFT weights from a checkpoint instead of randomly.
+        # This is not the same as resuming training because optimizer states are not restored.
+        logging.info(f"PEFT weights will be loaded from {cfg.model.peft.restore_from_path}")
+        model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls(model_cfg))
+    elif peft_cfg_cls is not None:
+        logging.info("Adding adapter weights to the model for PEFT")
+        model.add_adapter(peft_cfg_cls(model_cfg))
+    else:
+        logging.info(f"Running full finetuning since no peft scheme is given.\n{model.summarize()}")
+
+    trainer.fit(model)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/nlp/language_modeling/tuning/megatron_mamba_generate.py b/examples/nlp/language_modeling/tuning/megatron_mamba_generate.py
new file mode 100644
index 000000000000..6f660d552fc6
--- /dev/null
+++ b/examples/nlp/language_modeling/tuning/megatron_mamba_generate.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
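+
+# Example usage (illustrative sketch only; the checkpoint and dataset paths below are
+# placeholders, not files provided by this PR). The script is launched with Hydra-style
+# overrides of the generate config:
+#
+#   python megatron_mamba_generate.py \
+#       model.restore_from_path=/path/to/mamba2_sft.nemo \
+#       model.peft.restore_from_path=/path/to/lora_adapter.nemo \
+#       model.data.test_ds.file_names=[/path/to/test.jsonl] \
+#       model.data.test_ds.names=[test_set]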
+ + +import os +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf +from nemo.collections.nlp.models.language_modeling.megatron_mamba_sft_model import MegatronMambaSFTModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.model_utils import inject_model_parallel_rank + + +mp.set_start_method("spawn", force=True) + + +@hydra_runner(config_path="conf", config_name="megatron_mamba_generate_config") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + + if cfg.model.peft.restore_from_path: + model_cfg = MegatronMambaSFTModel.merge_inference_cfg(cfg.model.peft.restore_from_path, cfg) + else: + model_cfg = MegatronMambaSFTModel.merge_inference_cfg(cfg.model.restore_from_path, cfg) + + model = MegatronMambaSFTModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) + + if cfg.model.peft.restore_from_path: + model.load_adapters(cfg.model.peft.restore_from_path) + elif cfg.model.peft.restore_from_ckpt.checkpoint_dir and cfg.model.peft.restore_from_ckpt.checkpoint_name: + peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] + checkpoint_path = os.path.join( + cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name + ) + # checkpoint_path is a dir in case of distributed checkpointing + if not os.path.isdir(checkpoint_path): + # legacy checkpoint needs model parallel rank injection + checkpoint_path = inject_model_parallel_rank( + os.path.join( + cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name + ) + ) + model.load_adapters(checkpoint_path, peft_cfgs=peft_cfg_cls(model_cfg)) + else: + raise NotImplementedError("distributed checkpointing of PEFT weights is not supported") + + model.freeze() + logging.info(f"Freezing parameters for PEFT eval:\n{model.summarize()}") + + trainer.test(model) + + +if __name__ == "__main__": + main() diff --git a/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py b/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py new file mode 100644 index 000000000000..fb8a04b947b0 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import torch
+
+# from megatron.core.models.mamba import MambaModel
+# from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
+from omegaconf.dictconfig import DictConfig
+from pytorch_lightning.trainer.trainer import Trainer
+
+from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
+from nemo.utils import logging
+
+
+class MegatronMambaModel(MegatronGPTModel):
+    """
+    Megatron Mamba pretraining.
+    """
+
+    def __init__(self, cfg: DictConfig, trainer: Trainer):
+
+        self.vocab_size = cfg.get('vocab_size', 65536)
+        self.cfg = cfg
+        super().__init__(cfg=cfg, trainer=trainer)
+        logging.warning("Overriding mcore_gpt=True")
+        self.mcore_gpt = True
+
+    def model_provider_func(self, pre_process, post_process):
+
+        self.hybrid_override_pattern = self.cfg.get(
+            'hybrid_override_pattern', "M" * self.transformer_config.num_layers
+        )
+        self.transformer_config.add_bias_linear = self.cfg.get('add_bias_linear', False)
+        self.transformer_config.gated_linear_unit = self.cfg.get('gated_linear_unit', False)
+        self.transformer_config.layernorm_epsilon = self.cfg.get('layernorm_epsilon', 1e-5)
+
+        # TODO @ataghibakhsh: add mamba_ssm_ngroups=self.cfg.get('mamba_ssm_ngroups', 8) once MLM MR merged
+        # TODO @ataghibakhsh: add the following
+        '''MambaModel(
+            config=self.transformer_config,
+            max_sequence_length=self.cfg.get('encoder_seq_length', 4096),
+            vocab_size=self.cfg.get('vocab_size', 65536),
+            mamba_stack_spec=mamba_stack_spec,
+            hybrid_override_pattern=self.hybrid_override_pattern,
+        )'''
+        # after the package mismatch is resolved
+        model = None
+
+        return model
+
+    def forward(self, input_ids, position_ids=None, attention_mask=None, labels=None):
+
+        output_tensor = self.model(
+            input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, labels=labels
+        )
+        return output_tensor
+
+    def build_transformer_config(self):
+        transformer_config = super().build_transformer_config()
+        return transformer_config
+
+    def on_validation_epoch_end(self):
+
+        averaged_loss = torch.tensor(0.0, dtype=torch.float32).cuda()
+        return averaged_loss
+
+    def sharded_state_dict(self, prefix: str = ''):
+        return None
+
+    def _reset_activation_checkpointing_args(self):
+        return
+
+    def _restore_activation_checkpointing_args(self):
+        return
+
+    def _reset_sequence_parallelism_args(self):
+        return
+
+    def _restore_sequence_parallelism_args(self):
+        return
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py
new file mode 100644
index 000000000000..ebcc47004711
--- /dev/null
+++ b/nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
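+
+# Note: MegatronMambaSFTModel combines the GPT SFT training and evaluation logic with the Mamba
+# architecture through multiple inheritance. MegatronGPTSFTModel comes first in the MRO, so SFT
+# data handling and the training loop are inherited from the GPT SFT path, while the underlying
+# model construction comes from MegatronMambaModel.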
+
+from omegaconf import DictConfig
+from omegaconf.dictconfig import DictConfig
+from pytorch_lightning.trainer.trainer import Trainer
+
+from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel
+from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
+from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel
+
+
+__all__ = ['MegatronMambaSFTModel']
+
+
+class MegatronMambaSFTModel(MegatronGPTSFTModel, MegatronMambaModel):
+    """
+    Megatron Mamba Supervised Fine-Tuning
+    """
+
+    def __init__(self, cfg: DictConfig, trainer: Trainer):
+
+        super().__init__(cfg, trainer=trainer)
+        self.mcore_gpt = True
+        self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False)
+
+    def _reset_activation_checkpointing_args(self):
+        pass
+
+    def on_validation_model_zero_grad(self) -> None:
+        """
+        Skip gradient zeroing at the beginning of the validation routine.
+        This is needed when overlapping the AllGather of the updated parameters with the following validation step.
+        """
+        if not self.validation_param_sync_overlap:
+            MegatronBaseModel.on_validation_model_zero_grad(self)
diff --git a/nemo/collections/nlp/modules/common/text_generation_strategy.py b/nemo/collections/nlp/modules/common/text_generation_strategy.py
index 238c01695f42..f51d53ba5944 100644
--- a/nemo/collections/nlp/modules/common/text_generation_strategy.py
+++ b/nemo/collections/nlp/modules/common/text_generation_strategy.py
@@ -988,6 +988,7 @@ def model_inference_strategy_dispatcher(model, **args):
         MegatronGPTPromptLearningModel,
     )
     from nemo.collections.nlp.models.language_modeling.megatron_griffin_model import MegatronGriffinModel
+    from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel
     from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
     from nemo.collections.nlp.models.language_modeling.megatron_retro_model import MegatronRetroModel
     from nemo.collections.nlp.modules.common.retro_inference_strategies import (
@@ -998,6 +999,8 @@ def model_inference_strategy_dispatcher(model, **args):
 
     if isinstance(model, MegatronGriffinModel):
         return GriffinModelTextGenerationStrategy(model)
+    if isinstance(model, MegatronMambaModel):
+        return GPTModelTextGenerationStrategy(model)
     if isinstance(model, MegatronNevaModel):
         return NevaModelTextGenerationStrategy(model)
     if isinstance(model, MegatronGPTPromptLearningModel):
diff --git a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
index 7d294f6085bb..34ca175470ab 100644
--- a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
+++ b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
@@ -17,6 +17,7 @@
 from typing import List, Optional, Union
 
 import torch
+from megatron.core.transformer.identity_op import IdentityOp
 from omegaconf import DictConfig, OmegaConf, open_dict
 
 from nemo.utils.model_utils import inject_model_parallel_rank
@@ -178,9 +179,10 @@ def _check_and_add_peft_cfg(self, peft_cfg):
                 for layer in layers:
                     if layer.layer_number in (layer_selection or list(range(1, self.cfg.num_layers + 1))):
                         for name, module in layer.named_modules():
-                            self._check_and_add_adapter(
-                                name, module, adapter_name, adapter_cfg, name_key_to_mcore_mixins
-                            )
+                            if not isinstance(module, IdentityOp):
+                                self._check_and_add_adapter(
+                                    name, module, adapter_name, adapter_cfg, name_key_to_mcore_mixins
+                                )
             else:
                 # Non
GPT models, as well as GPT+PTuning do not support layer selection if layer_selection is not None: diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt index 494a9ab6d672..d006ccb7ad65 100644 --- a/requirements/requirements_nlp.txt +++ b/requirements/requirements_nlp.txt @@ -10,6 +10,7 @@ gdown h5py ijson jieba +mamba-ssm==1.2.0.post1 markdown2 matplotlib>=3.3.2 #megatron_core>0.6.0 # add back once mcore on pypi is compatible again diff --git a/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py b/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py new file mode 100644 index 000000000000..9a44f9c2c5c4 --- /dev/null +++ b/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py @@ -0,0 +1,159 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from argparse import ArgumentParser +from collections import defaultdict +import torch +from omegaconf.omegaconf import OmegaConf +from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision +from nemo.utils import logging + +''' +Example + +CUDA_VISIBLE_DEVICES="0" python /NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \ + --input_name_or_path \ + --output_path \ + --ngroups_mamba 8 \ + --precision bf16 +''' + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--hparams_file", + type=str, + default=f"{os.path.dirname(__file__)}/../../examples/nlp/language_modeling/conf/megatron_mamba_config.yaml", + required=False, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. 
Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml", + ) + parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.") + parser.add_argument( + "--input_name_or_path", + type=str, + required=True, + ) + parser.add_argument("--ngroups_mamba", type=int, default=8, help="ngroups for Mamba model") + parser.add_argument( + "--precision", type=str, default="bf16", choices=["bf16", "32"], help="Precision for checkpoint weights saved" + ) + args = parser.parse_args() + return args + + +def convert(args): + + checkpoint_weights = torch.load(args.input_name_or_path, map_location='cpu')['model'] + new_state_dict = {} + + if 'backbone' in list(checkpoint_weights.keys())[0]: + + layer_keys = [key for key in checkpoint_weights.keys() if re.match(r'backbone\.layers\.\d+\.', key)] + layer_numbers = set(int(re.search(r'backbone\.layers\.(\d+)\.', key).group(1)) for key in layer_keys) + num_layers = max(layer_numbers) + 1 + + direct_mappings = { + 'model.embedding.word_embeddings.weight': 'backbone.embedding.weight', + 'model.decoder.final_norm.weight': 'backbone.norm_f.weight', + 'model.output_layer.weight': 'lm_head.weight', + } + + for new_key, old_key in direct_mappings.items(): + new_state_dict[new_key] = checkpoint_weights[old_key] + + layer_attributes = [ + 'mixer.A_log', + 'mixer.D', + 'mixer.conv1d.weight', + 'mixer.conv1d.bias', + 'mixer.in_proj.weight', + 'mixer.dt_bias', + 'mixer.out_proj.weight', + 'mixer.norm.weight', + 'norm.weight', + ] + + for i in range(num_layers): + for attr in layer_attributes: + new_key = f'model.decoder.layers.{i}.{attr}' + old_key = f'backbone.layers.{i}.{attr}' + new_state_dict[new_key] = checkpoint_weights[old_key] + + else: + + layer_keys = [key for key in checkpoint_weights.keys() if re.match(r'decoder\.layers\.\d+\.', key)] + layer_numbers = set(int(re.search(r'decoder\.layers\.(\d+)\.', key).group(1)) for key in layer_keys) + num_layers = max(layer_numbers) + 1 + + new_state_dict = {"model." + key: value for key, value in checkpoint_weights.items()} + + layers = defaultdict(list) + + for key in new_state_dict.keys(): + match = re.match(r'model\.decoder\.layers\.(\d+)\.(\w+)', key) + if match: + index, layer_type = match.groups() + layers[index].append(layer_type) + + layer_pattern = '' + for i in range(max(map(int, layers.keys())) + 1): + index_str = str(i) + layer_types = layers.get(index_str, []) + if 'mixer' in layer_types: + layer_pattern += 'M' + elif 'self_attention' in layer_types: + layer_pattern += '*' + elif 'mlp' in layer_types: + layer_pattern += '-' + else: + raise AssertionError("Layer not found. 
Each layer must be eiher MLP, Mamba, or Attention") + + nemo_config = OmegaConf.load(args.hparams_file) + nemo_config.trainer["precision"] = args.precision + nemo_config.model.vocab_size, nemo_config.model.hidden_size = new_state_dict[ + 'model.embedding.word_embeddings.weight' + ].shape + nemo_config.model.num_layers = num_layers + nemo_config.model.hybrid_override_pattern = layer_pattern + nemo_config.model.ngroups_mamba = args.ngroups_mamba + + if "-" in layer_pattern: + nemo_config.model.ffn_hidden_size = new_state_dict[ + f'model.decoder.layers.{layer_pattern.index("-")}.mlp.linear_fc1.weight' + ].shape[0] + else: + nemo_config.model.ffn_hidden_size = nemo_config.model.hidden_size + + nemo_config.model.use_cpu_initialization = True + + logging.info(f"Loading Mamba2 Pytorch checkpoint : `{args.input_name_or_path}`") + + trainer = MegatronLMPPTrainerBuilder(nemo_config).create_trainer() + nemo_model_from_pyt = MegatronMambaModel(nemo_config.model, trainer) + + nemo_model_from_pyt.load_state_dict(new_state_dict, strict=True) + dtype = torch_dtype_from_precision(args.precision) + nemo_model_from_pyt = nemo_model_from_pyt.to(dtype=dtype) + nemo_model_from_pyt.save_to(args.output_path) + logging.info(f'Mamba2 NeMo model saved to: {args.output_path}') + + +if __name__ == '__main__': + args = get_args() + convert(args) diff --git a/tutorials/llm/mamba/mamba.rst b/tutorials/llm/mamba/mamba.rst new file mode 100644 index 000000000000..c09a6ae03087 --- /dev/null +++ b/tutorials/llm/mamba/mamba.rst @@ -0,0 +1,301 @@ +Mamba2 and Mamba2-Transformer Hybrid Models Fine-Tuning +======================================================= + +`State Space Models (SSMs) `__ have recently emerged as a promising alternative to transformers. SSMs offer advantages such as linear time complexity relative to sequence length and a constant cache size for inference. These features enable the processing of longer sequences and higher throughput. Despite these benefits, SSMs alone may fall short compared to transformers on tasks that demand strong copying or in-context learning capabilities. + +To harness the strengths of both approaches, SSM-Hybrid models incorporate MLP, Transformer, and SSM blocks in their architecture. As highlighted in `a study by NVIDIA `__, these hybrid models outperform traditional transformers of the same size by achieving faster inference times due to the inclusion of SSM blocks. Based on experimental results, Mamba2-Hybrid models not only surpass transformer baselines in performance but also benefit from increased computational efficiency. + +The Mamba2 models discussed in the `Transformers are SSMs `__ paper are available in five different sizes: 130 million, 370 million, 780 million, 1.3 billion, and 2.7 billion parameters. The Mamba2-Hybrid models, along with their Mamba2 baseline as released by `NVIDIA `__, are provided in an 8 billion parameter size. + +`Low-Rank Adaptation (LoRA) `__ has emerged as a popular Parameter Efficient Fine-Tuning (PEFT) technique that tunes a very small number of additional parameters as compared to full fine-tuning, thereby reducing the compute required. LoRA tuning can be applied to the linear layers in the Transformer and MLP blocks for the Mamba2-Hybrid models. + +`NVIDIA NeMo +Framework `__ provides tools to perform Fine-tuning on Mamba2 and Mamba2-Hybrid to fit your use case. 
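+
+For reference, LoRA fine-tuning in NeMo is controlled through the ``model.peft`` section of the
+fine-tuning config added in this PR (``megatron_mamba_finetuning_config.yaml``). A minimal sketch
+of the relevant fields, using the defaults from that config, looks as follows:
+
+.. code:: yaml
+
+    model:
+      peft:
+        peft_scheme: "lora"  # selects the PEFT scheme; see the comments in the config for other options
+        lora_tuning:
+          target_modules: ['all']
+          adapter_dim: 32
+          alpha: 32
+          adapter_dropout: 0.0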
+
+Requirements
+-------------
+
+In order to proceed, ensure that you have met the following requirements:
+
+* Full Fine-Tuning System Configuration
+   * Small models (130m, 370m, 780m)
+      * Access to at least 1 NVIDIA GPU with a cumulative memory of at least 40GB, for example: 1 x A6000-40GB.
+
+   * Mid-size models (1.3b, 2.7b)
+      * Access to at least 1 NVIDIA GPU with a cumulative memory of at least 80GB, for example: 1 x H100-80GB or 1 x A100-80GB.
+
+   * Large models (8b)
+      * Access to at least 2 NVIDIA GPUs with a cumulative memory of at least 80GB, for example: 2 x H100-80GB or 2 x A100-80GB.
+
+* LoRA Fine-Tuning (Mamba2-Hybrid only) System Configuration
+   * Access to at least 1 NVIDIA GPU with a cumulative memory of at least 80GB, for example: 1 x H100-80GB or 1 x A100-80GB.
+
+* A Docker-enabled environment, with `NVIDIA Container Runtime `_ installed, which will make the container GPU-aware.
+
+* `Authenticate with NVIDIA NGC `_, and download `NGC CLI Tool `_.
+
+
+Step-by-step Guide for Fine-Tuning
+----------------------------------
+
+Checkpoints from HuggingFace
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Obtain the desired checkpoint from HuggingFace.
+
+* `Repository `__ for the Mamba2 models from the `Transformers are SSMs paper `__.
+* `Repository `__ for the Mamba2 and Mamba2-Hybrid models by `NVIDIA `__.
+
+
+Convert the PyTorch Checkpoint to a NeMo Checkpoint
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+1. Get into the NVIDIA Container
+
+2. Run the conversion script from . For this conversion script, you should provide the PyTorch state dictionary of the model for ``input_name_or_path``, i.e. this argument only accepts a single ``state_dict``.
+
+.. code:: bash
+
+    CUDA_VISIBLE_DEVICES="0" python /NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \
+      --input_name_or_path \
+      --output_path \
+      --ngroups_mamba 8 \
+      --precision bf16
+
+* Note: the ``ngroups_mamba`` parameter should be 1 for the Mamba2 models from the `Transformers are SSMs paper `__ (130m, 370m, 780m, 1.3b, and 2.7b) and 8 for the Mamba2 and Mamba2-Hybrid models by `NVIDIA `__ (both 8b).
+
+Model (Tensor) Parallelism for the 8b Models
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+* Note: Distributed checkpointing for the Mamba2 and Mamba2-Hybrid models will be implemented in the near future. For now, you should use the method below for converting to different Tensor Parallel (TP) sizes.
+
+The HuggingFace checkpoint for the 8b model is for TP of size 1, and so is the ``.nemo`` checkpoint obtained in the previous step. To shard the model weights for a larger TP size, use the script from