Mistral + Mixtral Support for NeVa (NVIDIA#9459)
* mistral template support

Signed-off-by: paul-gibbons <[email protected]>

* get_specs neva fix

Signed-off-by: paul-gibbons <[email protected]>

* mistral update

Signed-off-by: paul-gibbons <[email protected]>

* fixed mistral tokenization

Signed-off-by: paul-gibbons <[email protected]>

* text_gen_strategy add mistral support

Signed-off-by: paul-gibbons <[email protected]>

* mistral text_gen fix

Signed-off-by: paul-gibbons <[email protected]>

* Cleaning up neva config

Signed-off-by: paul-gibbons <[email protected]>

* fix llama_2 default text_gen_strategy

Signed-off-by: paul-gibbons <[email protected]>

* Apply isort and black reformatting

Signed-off-by: paul-gibbons <[email protected]>

* fix forward() to account for new embedding optimization in MCore

Signed-off-by: paul-gibbons <[email protected]>

* Apply isort and black reformatting

Signed-off-by: paul-gibbons <[email protected]>

---------

Signed-off-by: paul-gibbons <[email protected]>
Signed-off-by: paul-gibbons <[email protected]>
Co-authored-by: paul-gibbons <[email protected]>
2 people authored and pablo-garay committed Jul 8, 2024
1 parent 080a689 commit 6697bbd
Showing 5 changed files with 109 additions and 16 deletions.
28 changes: 24 additions & 4 deletions nemo/collections/multimodal/data/neva/conversation.py
@@ -43,6 +43,7 @@ class SeparatorStyle(Enum):
PLAIN = auto()
LLAMA_2 = auto()
LLAMA_3 = auto()
MISTRAL = auto()
NVGPT = auto()


@@ -94,11 +95,15 @@ def get_prompt(self):
ret += " "
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
elif self.sep_style == SeparatorStyle.LLAMA_2 or self.sep_style == SeparatorStyle.MISTRAL:
if self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
else:
wrap_sys = lambda msg: f"{msg}" + ("\n" if msg else "")
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
ret = ""

if self.sep_style == SeparatorStyle.MISTRAL:
ret += DEFAULT_BOS_TOKEN
for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
@@ -112,7 +117,10 @@ def get_prompt(self):
message = wrap_inst(message)
ret += self.sep + " " + message
else:
ret += " " + message + " " + self.sep2
if self.sep_style == SeparatorStyle.LLAMA_2:
ret += " " + message + " " + self.sep2
else:
ret += message + self.sep2
else:
ret += ""
ret = ret.lstrip(self.sep)
@@ -449,6 +457,17 @@ def dict(self):
version="v1_mmtag",
)

conv_mistral = Conversation(
system="",
roles=("USER", "ASSISTANT"),
version="mistral",
messages=(),
offset=0,
sep_style=SeparatorStyle.MISTRAL,
sep="",
sep2=DEFAULT_EOS_TOKEN,
)
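
For illustration only (not part of the diff), a minimal sketch of how the new template would be used, assuming the LLaVA-style append_message() helper that this Conversation class provides alongside the copy() and get_prompt() methods seen above. The rendered prompt is roughly "<s>[INST] <user text> [/INST]<assistant reply></s>", with the exact whitespace determined by the sep/sep2 handling in get_prompt().

from nemo.collections.multimodal.data.neva import conversation as conversation_lib

conv = conversation_lib.conv_mistral.copy()
conv.append_message(conv.roles[0], "Describe the <image> in one sentence.")
conv.append_message(conv.roles[1], None)  # leave the assistant turn open for generation
prompt = conv.get_prompt()  # starts with DEFAULT_BOS_TOKEN; user turns are wrapped in [INST] ... [/INST]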

default_conversation = conv_vicuna_v1
conv_templates = {
"default": conv_vicuna_v0,
@@ -466,6 +485,7 @@ def dict(self):
"nvgpt": conv_nvgpt,
"nv_steerlm": conv_nvgpt,
"nv_dpo": conv_nv_dpo,
"mistral": conv_mistral,
}

if __name__ == "__main__":
34 changes: 28 additions & 6 deletions nemo/collections/multimodal/data/neva/neva_dataset.py
@@ -426,6 +426,7 @@ def preprocess_llama_2(
sources: dict,
tokenizer,
cfg,
is_mistral: bool = False,
) -> Dict:
"""
Preprocesses sources for the LLaMA 2 model configuration.
@@ -442,7 +443,10 @@ def preprocess_llama_2(
- Dict: A dictionary containing tokenized and labeled data suitable for the LLaMA 2 model.
This includes tokens, labels, and any special processing as defined in the configuration.
"""
conv = conversation_lib.conv_llava_llama_2.copy()
if is_mistral:
conv = conversation_lib.conv_mistral.copy()
else:
conv = conversation_lib.conv_llava_llama_2.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

# Apply prompt templates
@@ -477,7 +481,10 @@ def preprocess_llama_2(
labels = tokens.clone().detach()

# Mask labels
sep = "[/INST] "
if is_mistral:
sep = "[/INST]"
else:
sep = "[/INST] "
for conversation, target in zip(conversations, labels):
rounds = conversation.split(conv.sep2)
cur_len = 0
@@ -492,18 +499,23 @@ def preprocess_llama_2(
parts[0] += sep

round_len = len(tokenizer.text_to_ids(rou + conv.sep2))
instruction_len = len(tokenizer.text_to_ids(parts[0])) - 2

if is_mistral:
instruction_len = len(tokenizer.text_to_ids(parts[0])) - 1
else:
instruction_len = len(tokenizer.text_to_ids(parts[0])) - 2

if i > 0:
round_len -= 1 # Remove extra token added by sp tokenizer
else:
instruction_len += 1
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

cur_len += round_len
target[cur_len:] = IGNORE_INDEX

# Check if masking working correctly
# print([x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())])
# masking_test =[x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())]
# print(masking_test)

if add_extra_token:
tokens = tokens[:, :-1].contiguous()
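
Side note (illustrative, not part of the diff): the masking loop above reduces to one pattern per round. Everything up to and including the separator is hidden from the loss (for Mistral, sep is "[/INST]" with no trailing space and one fewer token is subtracted), while the assistant reply keeps its labels. A minimal sketch of that step, with IGNORE_INDEX standing in for the sentinel defined in neva_dataset.py:

import torch

IGNORE_INDEX = -1  # assumed placeholder; use the constant defined in neva_dataset.py

def mask_round(labels: torch.Tensor, cur_len: int, instruction_len: int) -> torch.Tensor:
    # Hide the instruction part (system + user + "[/INST]") from the loss; keep the reply visible.
    labels[cur_len : cur_len + instruction_len] = IGNORE_INDEX
    return labels
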
@@ -990,7 +1002,10 @@ def expand2square(pil_img, background_color):
result.paste(pil_img, ((height - width) // 2, 0))
return result

frames = expand2square(frames, tuple(int(x * 255) for x in self.processor.image_mean))
frames = [
expand2square(frame, tuple(int(x * 255) for x in self.processor.image_mean))
for frame in frames
]
frames = self.processor.preprocess(frames, return_tensors='pt')['pixel_values']
else:
frames = self.processor.preprocess(frames, return_tensors='pt')['pixel_values']
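
For context (illustrative, not part of the diff): expand2square() pads a single PIL image onto a square canvas, so for video input, where frames is a list of PIL frames, it must be applied per frame, which is what the list comprehension above does. A rough usage sketch, with the background color assumed to be the processor's image mean scaled to 0-255 (CLIP-style values shown only as an example):

from PIL import Image

frames = [Image.new("RGB", (320, 240)) for _ in range(4)]  # dummy video frames
background = tuple(int(x * 255) for x in (0.48145466, 0.4578275, 0.40821073))  # assumed CLIP image_mean
frames = [expand2square(frame, background) for frame in frames]  # pad each frame to a square
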
@@ -1057,6 +1072,13 @@ def expand2square(pil_img, background_color):
self.tokenizer,
self.multimodal_cfg,
)
elif self.conv_template == "mistral":
data_dict = preprocess_llama_2(
sources,
self.tokenizer,
self.multimodal_cfg,
is_mistral=True,
)
elif self.conv_template == "plain":
data_dict = preprocess_plain(
sources,
38 changes: 33 additions & 5 deletions nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
@@ -75,7 +75,7 @@
HAVE_APEX = False

try:
from megatron.core import InferenceParams, dist_checkpointing, parallel_state
from megatron.core import InferenceParams, dist_checkpointing, parallel_state, tensor_parallel
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -154,10 +154,34 @@ def set_media(self, media):
self.media = media

def forward(self, input_ids, **kwargs):
media = self.media # avoid change the signature of embedding forward function
media = self.media # avoid changing the signature of embedding forward function

# TODO: Refactor replace_media_embedding to account for MCore's embedding communication optimization
# https://github.com/NVIDIA/Megatron-LM/commit/ee423e7 changes the way we handle embeddings with sequence parallelism
# When using reduce_scatter_embeddings, word_embedding_tensor is now in the following shape: [sequence/tp, batch_size, hidden_size]
# replace_media_embedding currently expects [batch_size, sequence, hidden_size]

# Check if reduce_scatter_embeddings is enabled in the embedding forward function
apply_reduce_scatter = getattr(self, 'reduce_scatter_embeddings', False)

# Set reduce_scatter_embeddings to False to keep the words_embeddings
# tensor dimensions the same for replace_media_embedding
if apply_reduce_scatter:
self.reduce_scatter_embeddings = False

words_embeddings = super().forward(input_ids, **kwargs)
words_embeddings = self.replace_media_embeddings(input_ids, words_embeddings, media)

return self.replace_media_embeddings(input_ids, words_embeddings, media)
# Scatter embeddings back to each TP rank if reduce_scatter_embeddings is enabled
if apply_reduce_scatter:
words_embeddings = self._apply_reduce_scatter(words_embeddings)
self.reduce_scatter_embeddings = True

return words_embeddings

def _apply_reduce_scatter(self, embeddings):
embeddings = embeddings.transpose(0, 1).contiguous()
return tensor_parallel.mappings.scatter_to_sequence_parallel_region(embeddings)
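
To make the shape handling above concrete, a small toy example (illustrative, not part of the diff), assuming a tensor-parallel size that divides the sequence length: replace_media_embeddings works on [batch, sequence, hidden], while the sequence-parallel scatter expects [sequence, batch, hidden] and leaves each rank holding [sequence/tp, batch, hidden].

import torch

batch, seq, hidden, tp = 2, 8, 16, 2
embeddings = torch.randn(batch, seq, hidden)           # layout replace_media_embeddings expects
embeddings = embeddings.transpose(0, 1).contiguous()   # -> [seq, batch, hidden], as in _apply_reduce_scatter
local_shard = embeddings.chunk(tp, dim=0)[0]           # what a single TP rank holds after the scatter
assert local_shard.shape == (seq // tp, batch, hidden)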

def encode_vision_x(self, vision_x: torch.Tensor):
"""
Expand Down Expand Up @@ -193,7 +217,6 @@ def encode_vision_x(self, vision_x: torch.Tensor):
def replace_media_embeddings(self, input_ids, inputs_embeds, media):
if media is None:
return inputs_embeds

batch_size, sequence_length, hidden_size = inputs_embeds.shape

# calculate media features without gradients
@@ -550,7 +573,12 @@ def dummy():
media_end_id=media_end_id,
mcore_gpt=self.mcore_gpt,
config=self.transformer_config,
transformer_layer_spec=get_specs(self.spec_name),
transformer_layer_spec=get_specs(
self.spec_name,
self.transformer_config.num_moe_experts,
self.transformer_config.moe_grouped_gemm,
self.transformer_engine,
),
vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size),
max_sequence_length=self.cfg.get('encoder_seq_length', 512),
pre_process=pre_process,
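
Aside (illustrative, not taken from the diff): forwarding num_moe_experts and moe_grouped_gemm into get_specs is what lets this code path build a Mixtral-style mixture-of-experts NeVa model. Only the field names below come from the call above; the values are assumptions for a Mixtral-8x7B-like run.

transformer_config.num_moe_experts = 8      # hypothetical: eight routed experts per MoE layer
transformer_config.moe_grouped_gemm = True  # hypothetical: use the grouped-GEMM expert path
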
4 changes: 3 additions & 1 deletion nemo/collections/multimodal/parts/utils.py
@@ -135,8 +135,10 @@ def load_nemo_model_weights(nemo_path, sharded_state_dict=None):

# distributed checkpointing
if state_dict is None and sharded_state_dict is not None:

is_dist_ckpt = True
checkpoint = dict(state_dict=sharded_state_dict)

tmp_model_weights_ckpt = os.path.join(tmpdir, save_restore_connector.model_weights_ckpt)
tmp_model_weights_dir = os.path.splitext(tmp_model_weights_ckpt)[0]
assert os.path.isdir(tmp_model_weights_dir), f'Expected {tmp_model_weights_dir} to be a directory.'
@@ -501,7 +503,7 @@ def expand2square(pil_img, background_color):
result.paste(pil_img, ((height - width) // 2, 0))
return result

frames = expand2square(frames, tuple(int(x * 255) for x in processor.image_mean))
frames = [expand2square(frame, tuple(int(x * 255) for x in processor.image_mean)) for frame in frames]
frames = processor.preprocess(frames, return_tensors='pt')['pixel_values']
else:
frames = processor.preprocess(frames, return_tensors='pt')['pixel_values']
21 changes: 21 additions & 0 deletions nemo/collections/nlp/modules/common/text_generation_strategy.py
@@ -508,6 +508,27 @@ def neva_process_prompts(prompt, tokenizer, multimodal_cfg, num_media_latents, c
copy.deepcopy(list_data_dict), multimodal_cfg, num_media_latents
) # HARDCODED FOR NOW
data_dict = preprocess_llama_3(sources, tokenizer, multimodal_cfg)
elif multimodal_cfg["conv_template"] == "mistral":
record = {
'conversations': [
{
'from': 'human',
'value': prompt,
},
{
'from': 'gpt',
'value': '',
},
],
}
for turn in record['conversations']:
if turn.get('value') is not None:
turn['value'] = re.sub('<image>', f'{DEFAULT_IMAGE_TOKEN}\n', turn['value'])
list_data_dict.append(record)
sources = preprocess_multimodal(
copy.deepcopy(list_data_dict), multimodal_cfg, num_media_latents
) # HARDCODED FOR NOW
data_dict = preprocess_llama_2(sources, tokenizer, multimodal_cfg, is_mistral=True)
elif multimodal_cfg["conv_template"] == "v1":
record = {
'conversations': [

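
Closing illustration (not part of the diff): the new "mistral" branch in text_generation_strategy.py is selected purely by the conversation-template setting, and the record it builds mirrors the existing templates. A minimal sketch, with field values chosen only for illustration:

multimodal_cfg = {"conv_template": "mistral"}  # plus the usual NeVa multimodal/data settings
record = {
    'conversations': [
        {'from': 'human', 'value': 'Describe this <image>.'},
        {'from': 'gpt', 'value': ''},
    ],
}
# The branch normalizes the '<image>' placeholder via DEFAULT_IMAGE_TOKEN, then calls
# preprocess_multimodal(...) followed by preprocess_llama_2(..., is_mistral=True).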