From 6697bbdbdad31d9bab0ccea71768a3bb4d97c077 Mon Sep 17 00:00:00 2001 From: paul-gibbons <87940629+paul-gibbons@users.noreply.github.com> Date: Mon, 8 Jul 2024 09:37:46 -0700 Subject: [PATCH] Mistral + Mixtral Support for NeVa (#9459) * mistral template support Signed-off-by: paul-gibbons * get_specs neva fix Signed-off-by: paul-gibbons * mistral update Signed-off-by: paul-gibbons * fixed mistral tokenization Signed-off-by: paul-gibbons * text_gen_strategy add mistral support Signed-off-by: paul-gibbons * mistral text_gen fix Signed-off-by: paul-gibbons * Cleaning up neva config Signed-off-by: paul-gibbons * fix llama_2 default text_gen_strategy Signed-off-by: paul-gibbons * Apply isort and black reformatting Signed-off-by: paul-gibbons * fix forward() to account for new embedding optimization in MCore Signed-off-by: paul-gibbons * Apply isort and black reformatting Signed-off-by: paul-gibbons --------- Signed-off-by: paul-gibbons Signed-off-by: paul-gibbons Co-authored-by: paul-gibbons --- .../multimodal/data/neva/conversation.py | 28 ++++++++++++-- .../multimodal/data/neva/neva_dataset.py | 34 ++++++++++++++--- .../models/multimodal_llm/neva/neva_model.py | 38 ++++++++++++++++--- nemo/collections/multimodal/parts/utils.py | 4 +- .../common/text_generation_strategy.py | 21 ++++++++++ 5 files changed, 109 insertions(+), 16 deletions(-) diff --git a/nemo/collections/multimodal/data/neva/conversation.py b/nemo/collections/multimodal/data/neva/conversation.py index 43b1977aa993a..10a6c9e7283dc 100644 --- a/nemo/collections/multimodal/data/neva/conversation.py +++ b/nemo/collections/multimodal/data/neva/conversation.py @@ -43,6 +43,7 @@ class SeparatorStyle(Enum): PLAIN = auto() LLAMA_2 = auto() LLAMA_3 = auto() + MISTRAL = auto() NVGPT = auto() @@ -94,11 +95,15 @@ def get_prompt(self): ret += " " else: ret += role + ":" - elif self.sep_style == SeparatorStyle.LLAMA_2: - wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" + elif self.sep_style == SeparatorStyle.LLAMA_2 or self.sep_style == SeparatorStyle.MISTRAL: + if self.sep_style == SeparatorStyle.LLAMA_2: + wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" + else: + wrap_sys = lambda msg: f"{msg}" + ("\n" if msg else "") wrap_inst = lambda msg: f"[INST] {msg} [/INST]" ret = "" - + if self.sep_style == SeparatorStyle.MISTRAL: + ret += DEFAULT_BOS_TOKEN for i, (role, message) in enumerate(messages): if i == 0: assert message, "first message should not be none" @@ -112,7 +117,10 @@ def get_prompt(self): message = wrap_inst(message) ret += self.sep + " " + message else: - ret += " " + message + " " + self.sep2 + if self.sep_style == SeparatorStyle.LLAMA_2: + ret += " " + message + " " + self.sep2 + else: + ret += message + self.sep2 else: ret += "" ret = ret.lstrip(self.sep) @@ -449,6 +457,17 @@ def dict(self): version="v1_mmtag", ) +conv_mistral = Conversation( + system="", + roles=("USER", "ASSISTANT"), + version="mistral", + messages=(), + offset=0, + sep_style=SeparatorStyle.MISTRAL, + sep="", + sep2=DEFAULT_EOS_TOKEN, +) + default_conversation = conv_vicuna_v1 conv_templates = { "default": conv_vicuna_v0, @@ -466,6 +485,7 @@ def dict(self): "nvgpt": conv_nvgpt, "nv_steerlm": conv_nvgpt, "nv_dpo": conv_nv_dpo, + "mistral": conv_mistral, } if __name__ == "__main__": diff --git a/nemo/collections/multimodal/data/neva/neva_dataset.py b/nemo/collections/multimodal/data/neva/neva_dataset.py index 86d45ded54cfd..7eef677e13a8b 100644 --- a/nemo/collections/multimodal/data/neva/neva_dataset.py +++ b/nemo/collections/multimodal/data/neva/neva_dataset.py @@
-426,6 +426,7 @@ def preprocess_llama_2( sources: dict, tokenizer, cfg, + is_mistral: bool = False, ) -> Dict: """ Preprocesses sources for the LLaMA 2 model configuration. @@ -442,7 +443,10 @@ def preprocess_llama_2( - Dict: A dictionary containing tokenized and labeled data suitable for the LLaMA 2 model. This includes tokens, labels, and any special processing as defined in the configuration. """ - conv = conversation_lib.conv_llava_llama_2.copy() + if is_mistral: + conv = conversation_lib.conv_mistral.copy() + else: + conv = conversation_lib.conv_llava_llama_2.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates @@ -477,7 +481,10 @@ def preprocess_llama_2( labels = tokens.clone().detach() # Mask labels - sep = "[/INST] " + if is_mistral: + sep = "[/INST]" + else: + sep = "[/INST] " for conversation, target in zip(conversations, labels): rounds = conversation.split(conv.sep2) cur_len = 0 @@ -492,18 +499,23 @@ def preprocess_llama_2( parts[0] += sep round_len = len(tokenizer.text_to_ids(rou + conv.sep2)) - instruction_len = len(tokenizer.text_to_ids(parts[0])) - 2 + + if is_mistral: + instruction_len = len(tokenizer.text_to_ids(parts[0])) - 1 + else: + instruction_len = len(tokenizer.text_to_ids(parts[0])) - 2 + if i > 0: round_len -= 1 # Remove extra token added by sp tokenizer else: instruction_len += 1 target[cur_len : cur_len + instruction_len] = IGNORE_INDEX - cur_len += round_len target[cur_len:] = IGNORE_INDEX # Check if masking working correctly - # print([x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())]) + # masking_test =[x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())] + # print(masking_test) if add_extra_token: tokens = tokens[:, :-1].contiguous() @@ -990,7 +1002,10 @@ def expand2square(pil_img, background_color): result.paste(pil_img, ((height - width) // 2, 0)) return result - frames = expand2square(frames, tuple(int(x * 255) for x in self.processor.image_mean)) + frames = [ + expand2square(frame, tuple(int(x * 255) for x in self.processor.image_mean)) + for frame in frames + ] frames = self.processor.preprocess(frames, return_tensors='pt')['pixel_values'] else: frames = self.processor.preprocess(frames, return_tensors='pt')['pixel_values'] @@ -1057,6 +1072,13 @@ def expand2square(pil_img, background_color): self.tokenizer, self.multimodal_cfg, ) + elif self.conv_template == "mistral": + data_dict = preprocess_llama_2( + sources, + self.tokenizer, + self.multimodal_cfg, + is_mistral=True, + ) elif self.conv_template == "plain": data_dict = preprocess_plain( sources, diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py index cce40da457253..376237e89ecc6 100644 --- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py +++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py @@ -75,7 +75,7 @@ HAVE_APEX = False try: - from megatron.core import InferenceParams, dist_checkpointing, parallel_state + from megatron.core import InferenceParams, dist_checkpointing, parallel_state, tensor_parallel from megatron.core.models.gpt import GPTModel as MCoreGPTModel from megatron.core.pipeline_parallel.schedules import get_forward_backward_func from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint @@ -154,10 +154,34 @@ def set_media(self, media): self.media = media def forward(self, input_ids, **kwargs): - media = self.media # avoid change the signature of 
embedding forward function + media = self.media # avoid changing the signature of embedding forward function + + # TODO: Refactor replace_media_embedding to account for MCore's embedding communication optimization + # https://github.com/NVIDIA/Megatron-LM/commit/ee423e7 changes the way we handle embeddings with sequence parallelism + # When using reduce_scatter_embeddings, word_embedding_tensor is now in the following shape: [sequence/tp, batch_size, hidden_size] + # replace_media_embedding currently expects [batch_size, sequence, hidden_size] + + # Check if reduce_scatter_embeddings is enabled in the embedding forward function + apply_reduce_scatter = getattr(self, 'reduce_scatter_embeddings', False) + + # Set reduce_scatter_embeddings to false to keep words_embedding's + # tensor dimesion the same for replace_media_embedding + if apply_reduce_scatter: + self.reduce_scatter_embeddings = False + words_embeddings = super().forward(input_ids, **kwargs) + words_embeddings = self.replace_media_embeddings(input_ids, words_embeddings, media) - return self.replace_media_embeddings(input_ids, words_embeddings, media) + # Scatter embeddings back to each TP rank if reduce_scatter_embeddings is enabled + if apply_reduce_scatter: + words_embeddings = self._apply_reduce_scatter(words_embeddings) + self.reduce_scatter_embeddings = True + + return words_embeddings + + def _apply_reduce_scatter(self, embeddings): + embeddings = embeddings.transpose(0, 1).contiguous() + return tensor_parallel.mappings.scatter_to_sequence_parallel_region(embeddings) def encode_vision_x(self, vision_x: torch.Tensor): """ @@ -193,7 +217,6 @@ def encode_vision_x(self, vision_x: torch.Tensor): def replace_media_embeddings(self, input_ids, inputs_embeds, media): if media is None: return inputs_embeds - batch_size, sequence_length, hidden_size = inputs_embeds.shape # calculate media features without gradients @@ -550,7 +573,12 @@ def dummy(): media_end_id=media_end_id, mcore_gpt=self.mcore_gpt, config=self.transformer_config, - transformer_layer_spec=get_specs(self.spec_name), + transformer_layer_spec=get_specs( + self.spec_name, + self.transformer_config.num_moe_experts, + self.transformer_config.moe_grouped_gemm, + self.transformer_engine, + ), vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size), max_sequence_length=self.cfg.get('encoder_seq_length', 512), pre_process=pre_process, diff --git a/nemo/collections/multimodal/parts/utils.py b/nemo/collections/multimodal/parts/utils.py index b6dee33d24f3a..7eb72b38d0f01 100644 --- a/nemo/collections/multimodal/parts/utils.py +++ b/nemo/collections/multimodal/parts/utils.py @@ -135,8 +135,10 @@ def load_nemo_model_weights(nemo_path, sharded_state_dict=None): # distributed checkpointing if state_dict is None and sharded_state_dict is not None: + is_dist_ckpt = True checkpoint = dict(state_dict=sharded_state_dict) + tmp_model_weights_ckpt = os.path.join(tmpdir, save_restore_connector.model_weights_ckpt) tmp_model_weights_dir = os.path.splitext(tmp_model_weights_ckpt)[0] assert os.path.isdir(tmp_model_weights_dir), f'Expected {tmp_model_weights_dir} to be a directory.' 
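For readers following the forward() change in neva_model.py above, here is a minimal sketch (plain PyTorch, not NeMo code) of why the reduce_scatter_embeddings flag is toggled: replace_media_embeddings needs the word embeddings in [batch, sequence, hidden] layout, while the fused reduce-scatter path hands downstream layers [sequence/tp, batch, hidden]. The scatter helper below is a single-process stand-in, introduced only for illustration, for MCore's tensor_parallel.mappings.scatter_to_sequence_parallel_region.

```python
import torch


def fake_scatter_to_sequence_parallel_region(x, tp_size=2, rank=0):
    # Single-process stand-in: the real MCore mapping is a collective that splits
    # dim 0 (the sequence dimension) across tensor-parallel ranks.
    return x.chunk(tp_size, dim=0)[rank]


batch, seq, hidden = 2, 8, 16
words_embeddings = torch.randn(batch, seq, hidden)  # layout replace_media_embeddings expects
media_features = torch.randn(batch, 3, hidden)      # pretend 3 media "tokens" per sample

# Splice media features into the text embeddings while the tensor is still [b, s, h].
words_embeddings[:, 1:4, :] = media_features

# Afterwards, mimic what reduce_scatter_embeddings would have produced:
# [b, s, h] -> [s, b, h] -> [s/tp, b, h]
scattered = fake_scatter_to_sequence_parallel_region(words_embeddings.transpose(0, 1).contiguous())
print(scattered.shape)  # torch.Size([4, 2, 16])
```

On a real tensor-parallel group the scatter is a collective, which is why the patch re-enables the flag only after the media replacement, keeping the optimized path for the rest of the model.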
@@ -501,7 +503,7 @@ def expand2square(pil_img, background_color): result.paste(pil_img, ((height - width) // 2, 0)) return result - frames = expand2square(frames, tuple(int(x * 255) for x in processor.image_mean)) + frames = [expand2square(frame, tuple(int(x * 255) for x in processor.image_mean)) for frame in frames] frames = processor.preprocess(frames, return_tensors='pt')['pixel_values'] else: frames = processor.preprocess(frames, return_tensors='pt')['pixel_values'] diff --git a/nemo/collections/nlp/modules/common/text_generation_strategy.py b/nemo/collections/nlp/modules/common/text_generation_strategy.py index f51d53ba59440..8f8fe313a5e3d 100644 --- a/nemo/collections/nlp/modules/common/text_generation_strategy.py +++ b/nemo/collections/nlp/modules/common/text_generation_strategy.py @@ -508,6 +508,27 @@ def neva_process_prompts(prompt, tokenizer, multimodal_cfg, num_media_latents, c copy.deepcopy(list_data_dict), multimodal_cfg, num_media_latents ) # HARDCODED FOR NOW data_dict = preprocess_llama_3(sources, tokenizer, multimodal_cfg) + elif multimodal_cfg["conv_template"] == "mistral": + record = { + 'conversations': [ + { + 'from': 'human', + 'value': prompt, + }, + { + 'from': 'gpt', + 'value': '', + }, + ], + } + for turn in record['conversations']: + if turn.get('value') is not None: + turn['value'] = re.sub('<image>', f'{DEFAULT_IMAGE_TOKEN}\n', turn['value']) + list_data_dict.append(record) + sources = preprocess_multimodal( + copy.deepcopy(list_data_dict), multimodal_cfg, num_media_latents + ) # HARDCODED FOR NOW + data_dict = preprocess_llama_2(sources, tokenizer, multimodal_cfg, is_mistral=True) elif multimodal_cfg["conv_template"] == "v1": record = { 'conversations': [
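To see the new template end to end, here is a hedged usage sketch of the "mistral" conversation template registered in conversation.py above. It assumes DEFAULT_BOS_TOKEN and DEFAULT_EOS_TOKEN resolve to "<s>" and "</s>" as defined in that module, and the rendered string in the comment is approximate rather than a tokenizer-level guarantee.

```python
from nemo.collections.multimodal.data.neva import conversation as conversation_lib

# Build a one-turn NeVa exchange with the Mistral template added in this patch.
conv = conversation_lib.conv_templates["mistral"].copy()
conv.append_message(conv.roles[0], "<image>\nWhat is shown in this image?")
conv.append_message(conv.roles[1], "A cat sitting on a laptop.")

# SeparatorStyle.MISTRAL prepends the BOS token, wraps the user turn in [INST] ... [/INST],
# and closes the assistant turn with sep2 (EOS), giving roughly:
#   <s> [INST] <image>\nWhat is shown in this image? [/INST]A cat sitting on a laptop.</s>
print(conv.get_prompt())
```

At inference time, neva_process_prompts routes the "mistral" conv_template through preprocess_llama_2 with is_mistral=True, so tokenization and label masking follow the same [INST] boundaries shown in the comment.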