From 3022f9af7bca3feee88657f5bc0b91a917d25c02 Mon Sep 17 00:00:00 2001
From: Ma Zerun
Date: Fri, 22 Dec 2023 16:28:20 +0800
Subject: [PATCH] [Feature] Support LLaVA 1.5 (#1853)

* Support LLaVA 1.5

* Fix lint
---
 configs/llava/README.md                       |  30 +--
 configs/llava/llava-7b-v1.5_caption.py        |  76 +++++++
 configs/llava/llava-7b-v1.5_vqa.py            |  76 +++++++
 configs/llava/llava-7b-v1_caption.py          |  21 +-
 configs/llava/metafile.yml                    |  28 ++-
 mmpretrain/models/multimodal/llava/llava.py   |  35 ++--
 mmpretrain/models/multimodal/llava/modules.py | 196 +++++++++---------
 tools/model_converters/llava-delta2mmpre.py   |  39 ++--
 8 files changed, 332 insertions(+), 169 deletions(-)
 create mode 100644 configs/llava/llava-7b-v1.5_caption.py
 create mode 100644 configs/llava/llava-7b-v1.5_vqa.py

diff --git a/configs/llava/README.md b/configs/llava/README.md
index 7aaf57d7d13..581abfe5a66 100644
--- a/configs/llava/README.md
+++ b/configs/llava/README.md
@@ -16,46 +16,28 @@ Instruction tuning large language models (LLMs) using machine-generated instruct
 
-**Prepare the checkpoint**
-
-According to the license of LLaMA, we cannot provide the merged checkpoint directly. Please use the below
-script to download and get the merged the checkpoint.
-
-```shell
-python tools/model_converters/llava-delta2mmpre.py huggyllama/llama-7b liuhaotian/LLaVA-Lightning-7B-delta-v1-1 ./LLaVA-Lightning-7B-delta-v1-1.pth
-```
-
 **Use the model**
 
 ```python
 import torch
 from mmpretrain import get_model, inference_model
 
-model = get_model('llava-7b-v1_caption', pretrained='MERGED_CHECKPOINT_PATH', device='cuda')
-out = inference_model(model, 'demo/cat-dog.png')
+out = inference_model('llava-7b-v1_caption', 'demo/cat-dog.png', device='cuda')
 print(out)
 # {'pred_caption': 'In the image, there are two cats sitting on a blanket.'}
 ```
 
-**Test Command**
-
-Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
-
-Test:
-
-```shell
-python tools/test.py configs/llava/llava-7b-v1_caption.py MERGED_CHECKPOINT_PATH
-```
-
 ## Models and results
 
 ### Image Caption on COCO
 
-| Model                 | Params (M) |  BLEU-4  |  CIDER   |              Config              |        Download        |
-| :-------------------- | :--------: | :------: | :------: | :------------------------------: | :--------------------: |
-| `llava-7b-v1_caption` | 7045.82    | Upcoming | Upcoming | [config](llava-7b-v1_caption.py) | See the above tutorial |
+| Model                   | Params (M) |               Config               |                                                    Download                                                     |
+| :---------------------- | :--------: | :--------------------------------: | :--------------------------------------------------------------------------------------------------------------: |
+| `llava-7b-v1_caption`   | 7045.82    |  [config](llava-7b-v1_caption.py)  |  [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth)   |
+| `llava-7b-v1.5_caption` | 7062.90    | [config](llava-7b-v1.5_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth)  |
+| `llava-7b-v1.5_vqa`     | 7062.90    |   [config](llava-7b-v1.5_vqa.py)   | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth)  |
 
 ## Citation
 
diff --git a/configs/llava/llava-7b-v1.5_caption.py b/configs/llava/llava-7b-v1.5_caption.py
new file mode 100644
index 00000000000..371c9b5f617
--- /dev/null
+++ b/configs/llava/llava-7b-v1.5_caption.py
@@ -0,0 +1,76 @@
+_base_ = '../_base_/default_runtime.py'
+
+meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."  # noqa: E501
+image_size = 336
+prompt_tmpl = f'''{meta_prompt} User: <image>
+Describe the image in detail. ASSISTANT:'''
+
+# model settings
+model = dict(
+    type='Llava',
+    tokenizer=dict(
+        type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
+    vision_encoder=dict(
+        type='VisionTransformer',
+        arch='l',
+        patch_size=14,
+        img_size=image_size,
+        pre_norm=True,
+        norm_cfg=dict(type='LN', eps=1e-5),
+        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
+        final_norm=False,
+        out_type='raw',
+        pretrained='https://download.openmmlab.com/mmclassification/v0/clip/'
+        'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth',
+    ),
+    mm_hidden_size=1024,
+    use_im_patch=False,
+    use_im_start_end=False,
+    mm_proj_depth=2,
+    lang_encoder=dict(
+        type='AutoModelForCausalLM',
+        name_or_path='huggyllama/llama-7b',
+    ),
+    task='caption',
+    prompt_tmpl=prompt_tmpl,
+    generation_cfg=dict(num_beams=3, max_new_tokens=50, length_penalty=-1.0),
+)
+
+# data settings
+data_preprocessor = dict(
+    type='MultiModalDataPreprocessor',
+    mean=[122.770938, 116.7460125, 104.09373615],
+    std=[68.5005327, 66.6321579, 70.32316305],
+    to_rgb=True,
+)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='Resize',
+        scale=(image_size, image_size),
+        interpolation='bicubic',
+        backend='pillow'),
+    dict(type='PackInputs', meta_keys=['image_id']),
+]
+
+test_dataloader = dict(
+    batch_size=8,
+    num_workers=5,
+    dataset=dict(
+        type='COCOCaption',
+        data_root='data/coco',
+        ann_file='annotations/coco_karpathy_val.json',
+        pipeline=test_pipeline,
+    ),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    persistent_workers=True,
+)
+
+test_evaluator = dict(
+    type='COCOCaption',
+    ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
+)
+
+# schedule settings
+test_cfg = dict()
diff --git a/configs/llava/llava-7b-v1.5_vqa.py b/configs/llava/llava-7b-v1.5_vqa.py
new file mode 100644
index 00000000000..5cb9812cd98
--- /dev/null
+++ b/configs/llava/llava-7b-v1.5_vqa.py
@@ -0,0 +1,76 @@
+_base_ = '../_base_/default_runtime.py'
+
+meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."  # noqa: E501
+image_size = 336
+prompt_tmpl = f'''{meta_prompt} User: <image>
+{{question}} ASSISTANT:'''
+
+# model settings
+model = dict(
+    type='Llava',
+    tokenizer=dict(
+        type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
+    vision_encoder=dict(
+        type='VisionTransformer',
+        arch='l',
+        patch_size=14,
+        img_size=image_size,
+        pre_norm=True,
+        norm_cfg=dict(type='LN', eps=1e-5),
+        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
+        final_norm=False,
+        out_type='raw',
+        pretrained='https://download.openmmlab.com/mmclassification/v0/clip/'
+        'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth',
+    ),
+    mm_hidden_size=1024,
+    use_im_patch=False,
+    use_im_start_end=False,
+    mm_proj_depth=2,
+    lang_encoder=dict(
+        type='AutoModelForCausalLM',
+        name_or_path='huggyllama/llama-7b',
+    ),
+    task='vqa',
+    prompt_tmpl=prompt_tmpl,
+    generation_cfg=dict(max_new_tokens=100),
+)
+
+# data settings
+data_preprocessor = dict(
+    type='MultiModalDataPreprocessor',
+    mean=[122.770938, 116.7460125, 104.09373615],
+    std=[68.5005327, 66.6321579, 70.32316305],
+    to_rgb=True,
+)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='Resize',
+        scale=(image_size, image_size),
+        interpolation='bicubic',
+        backend='pillow'),
+    dict(type='PackInputs', meta_keys=['image_id', 'question']),
+]
+
+test_dataloader = dict(
+    batch_size=8,
+    num_workers=5,
+    dataset=dict(
+        type='COCOCaption',
+        data_root='data/coco',
+        ann_file='annotations/coco_karpathy_val.json',
+        pipeline=test_pipeline,
+    ),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    persistent_workers=True,
+)
+
+test_evaluator = dict(
+    type='COCOCaption',
+    ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
+)
+
+# schedule settings
+test_cfg = dict()
diff --git a/configs/llava/llava-7b-v1_caption.py b/configs/llava/llava-7b-v1_caption.py
index f7558bedd2b..92e2d1fb65a 100644
--- a/configs/llava/llava-7b-v1_caption.py
+++ b/configs/llava/llava-7b-v1_caption.py
@@ -1,16 +1,9 @@
 _base_ = '../_base_/default_runtime.py'
 
 meta_prompt = 'You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.Follow the instructions carefully and explain your answers in detail.'  # noqa: E501
-im_patch_token = '<im_patch>'
-patch_size = 14
 image_size = 224
-num_patches = (image_size // patch_size)**2
-caption_prompt = ' '.join([
-    meta_prompt,
-    'User: a photo of\n',
-    im_patch_token * num_patches,
-    'ASSISTANT:',
-])
+prompt_tmpl = f'''{meta_prompt} User: <im_start><image><im_end>
+Describe the image in detail. ASSISTANT:'''
 
 # model settings
 model = dict(
@@ -22,6 +15,7 @@
         type='VisionTransformer',
         arch='l',
         patch_size=14,
+        img_size=image_size,
         pre_norm=True,
         norm_cfg=dict(type='LN', eps=1e-5),
         layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
@@ -32,15 +26,16 @@
         'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
     ),
     mm_hidden_size=1024,
-    use_im_start_end=False,
-    use_mm_proj=True,
+    use_im_patch=False,
+    use_im_start_end=True,
+    mm_proj_depth=1,
     lang_encoder=dict(
         type='AutoModelForCausalLM',
        name_or_path='huggyllama/llama-7b',
     ),
     task='caption',
-    prompt_tmpl=caption_prompt,
-    generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0),
+    prompt_tmpl=prompt_tmpl,
+    generation_cfg=dict(max_new_tokens=50),
 )
 
 # data settings
diff --git a/configs/llava/metafile.yml b/configs/llava/metafile.yml
index 2b3cfc4dbae..406a214c33a 100644
--- a/configs/llava/metafile.yml
+++ b/configs/llava/metafile.yml
@@ -21,5 +21,31 @@ Models:
         Metrics:
           BLEU-4: null
           CIDER: null
-    Weights: null
+    Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth
     Config: configs/llava/llava-7b-v1_caption.py
+  - Name: llava-7b-v1.5_caption
+    Metadata:
+      FLOPs: null
+      Parameters: 7062900736
+    In Collection: LLaVA
+    Results:
+      - Task: Image Caption
+        Dataset: COCO
+        Metrics:
+          BLEU-4: null
+          CIDER: null
+    Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth
+    Config: configs/llava/llava-7b-v1.5_caption.py
+  - Name: llava-7b-v1.5_vqa
+    Metadata:
+      FLOPs: null
+      Parameters: 7062900736
+    In Collection: LLaVA
+    Results:
+      - Task: Visual Question Answering
+        Dataset: COCO
+        Metrics:
+          BLEU-4: null
+          CIDER: null
+    Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth
+    Config: configs/llava/llava-7b-v1.5_vqa.py
diff --git a/mmpretrain/models/multimodal/llava/llava.py b/mmpretrain/models/multimodal/llava/llava.py
index 103d81296f0..f829b092146 100644
--- a/mmpretrain/models/multimodal/llava/llava.py
+++ b/mmpretrain/models/multimodal/llava/llava.py
@@ -24,8 +24,8 @@ class Llava(BaseModel):
         use_im_start_end (bool): Whether to use the im_start and im_end tokens
         mm_vision_select_layer (int): The index from vision encoder output.
             Defaults to -1.
-        use_mm_proj (bool): Whether to enable multi-modal projection.
-            Defaults to True.
+        mm_proj_depth (int): The number of linear layers for multi-modal
+            projection. Defaults to 1.
         load_lang_pretrained (bool): Whether to load the pretrained model of
             language encoder. Defaults to False.
         generation_cfg (dict): The extra generation config, accept the keyword
@@ -51,9 +51,10 @@ def __init__(self,
                  mm_hidden_size: int,
                  prompt_tmpl: str,
                  task: str = 'caption',
+                 use_im_patch: bool = True,
                  use_im_start_end: bool = False,
                  mm_vision_select_layer: int = -1,
-                 use_mm_proj: bool = True,
+                 mm_proj_depth: int = 1,
                  generation_cfg: dict = dict(),
                  load_lang_pretrained: bool = False,
                  data_preprocessor: Optional[dict] = None,
@@ -75,7 +76,9 @@ def __init__(self,
         # init tokenizer
         self.tokenizer = TOKENIZER.build(tokenizer)
         # add Llava special tokens to the tokenizer
-        self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True)
+        if use_im_patch:
+            self.tokenizer.add_tokens([self.im_patch_token],
+                                      special_tokens=True)
         if use_im_start_end:
             self.tokenizer.add_tokens([self.im_start_token, self.im_end_token],
                                       special_tokens=True)
@@ -108,14 +111,12 @@ def __init__(self,
             vision_encoder=vision_encoder,
             lang_encoder=lang_encoder,
             mm_hidden_size=mm_hidden_size,
-            use_mm_proj=use_mm_proj,
+            mm_proj_depth=mm_proj_depth,
             use_im_start_end=use_im_start_end,
             im_start_token=self.tokenizer.convert_tokens_to_ids(
                 self.im_start_token),
             im_end_token=self.tokenizer.convert_tokens_to_ids(
                 self.im_end_token),
-            im_patch_token=self.tokenizer.convert_tokens_to_ids(
-                self.im_patch_token),
             mm_vision_select_layer=mm_vision_select_layer)
 
         self.generation_cfg = generation_cfg
@@ -207,16 +208,24 @@ def preprocess_text(self, data_samples: List[DataSample],
         Returns:
             List[DataSample]: Return list of data samples.
         """
-        prompts = []
+        tokens = []
         for sample in data_samples:
-            final_prompt = self.prompt_tmpl.format(**sample.to_dict())
-            prompts.append(final_prompt)
+            prompt = self.prompt_tmpl.format(**sample.to_dict())
+            input_ids = []
+            while '<image>' in prompt:
+                prefix, _, prompt = prompt.partition('<image>')
+                input_ids.extend(
+                    self.tokenizer(prefix, add_special_tokens=False).input_ids)
+                input_ids.append(-200)
+            if prompt:
+                input_ids.extend(
+                    self.tokenizer(prompt, add_special_tokens=False).input_ids)
+            tokens.append(dict(input_ids=input_ids))
 
         self.tokenizer.padding_side = 'left'
-        input_text = self.tokenizer(
-            prompts,
+        input_text = self.tokenizer.pad(
+            tokens,
             padding='longest',
-            truncation=True,
             return_tensors='pt',
             max_length=2000,
         ).to(device)
diff --git a/mmpretrain/models/multimodal/llava/modules.py b/mmpretrain/models/multimodal/llava/modules.py
index afa6eefadcb..fa3c6bbbcc0 100644
--- a/mmpretrain/models/multimodal/llava/modules.py
+++ b/mmpretrain/models/multimodal/llava/modules.py
@@ -31,10 +31,10 @@ def __init__(self,
                  lang_encoder,
                  mm_hidden_size,
                  use_im_start_end=True,
-                 use_mm_proj=True,
+                 mm_proj_depth=1,
                  im_start_token: Optional[int] = None,
                  im_end_token: Optional[int] = None,
-                 im_patch_token: Optional[int] = None,
+                 im_token_index: int = -200,
                  mm_vision_select_layer: int = -1):
         super().__init__(lang_encoder.config)
         self.vision_tower = vision_encoder
@@ -43,16 +43,26 @@ def __init__(self,
         self.use_im_start_end = use_im_start_end
         self.im_start_token = im_start_token
         self.im_end_token = im_end_token
-        self.im_patch_token = im_patch_token
         self.mm_hidden_size = mm_hidden_size
         self.mm_vision_select_layer = mm_vision_select_layer
+        self.im_token_index = im_token_index
         self.lang_hidden_size = lang_encoder.config.hidden_size
 
-        if use_mm_proj and not hasattr(lang_encoder.model, 'mm_projector'):
+        if mm_proj_depth == 1:
+            # Llava V1
             mm_projector = nn.Linear(self.mm_hidden_size,
                                      self.lang_hidden_size)
             self.lang_encoder.model.add_module('mm_projector', mm_projector)
-        elif not use_mm_proj:
+        elif mm_proj_depth > 1:
+            # Llava V1.5
+            modules = [nn.Linear(self.mm_hidden_size, self.lang_hidden_size)]
+            for _ in range(1, mm_proj_depth):
+                modules.append(nn.GELU())
+                modules.append(
+                    nn.Linear(self.lang_hidden_size, self.lang_hidden_size))
+            mm_projector = nn.Sequential(*modules)
+            self.lang_encoder.model.add_module('mm_projector', mm_projector)
+        elif mm_proj_depth == 0:
             self.lang_encoder.model.add_module('mm_projector', nn.Identity())
 
         self.post_init()
@@ -80,16 +90,12 @@ def forward(
             return_dict
             if return_dict is not None else self.config.use_return_dict)
 
-        # decoder outputs consists of
-        # (dec_features, layer_state, dec_hidden, dec_attn)
-        if inputs_embeds is None:
-            inputs_embeds = self.lang_encoder.model.embed_tokens(input_ids)
-
-        inputs_embeds = self.forward_vision_tower(input_ids, inputs_embeds,
-                                                  images)
+        (input_ids, attention_mask, past_key_values, inputs_embeds,
+         labels) = self.forward_vision_tower(input_ids, attention_mask,
+                                             past_key_values, labels, images)
 
         return self.lang_encoder(
-            input_ids=None,
+            input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
@@ -127,106 +133,93 @@ def prepare_inputs_for_generation(self,
     def forward_vision_tower(
         self,
         input_ids: torch.LongTensor,
-        inputs_embeds: torch.FloatTensor,
-        images: Union[torch.FloatTensor, list, None] = None,
+        attention_mask: torch.LongTensor,
+        past_key_values: torch.FloatTensor,
+        labels: torch.LongTensor,
+        images: Union[torch.FloatTensor, None] = None,
     ):
-        if self.use_im_start_end:
-            assert self.im_start_token is not None
-            assert self.im_end_token is not None
-        if images is not None:
-            assert self.im_patch_token is not None
-
-        if self.vision_tower is None or images is None or (
-                input_ids.shape[1] == 1 and not self.training):
-            return inputs_embeds
+        if self.vision_tower is None or images is None or input_ids.shape[
+                1] == 1:
+            if (past_key_values is not None and self.vision_tower is not None
+                    and images is not None and input_ids.shape[1] == 1):
+                attention_mask = torch.ones(
+                    (attention_mask.shape[0],
+                     past_key_values[-1][-1].shape[-2] + 1),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device)
+            return input_ids, attention_mask, past_key_values, None, labels
 
         with torch.no_grad():
-            if isinstance(images, (list, tuple)):
-                # variable length images
-                image_features = []
-                for image in images:
-                    feats = self.vision_tower(image.unsqueeze(0))
-                    image_feature = feats[self.mm_vision_select_layer][:, 1:]
-                    image_features.append(image_feature)
-            else:
-                feats = self.vision_tower(images)
-                image_features = feats[self.mm_vision_select_layer][:, 1:]
-
-        mm_projector = self.lang_encoder.model.mm_projector
-        if isinstance(images, (list, tuple)):
-            image_features = [
-                mm_projector(image_feature)[0]
-                for image_feature in image_features
-            ]
-        else:
-            image_features = mm_projector(image_features)
+            # TODO: support variable number of images (single now)
+            feats = self.vision_tower(images)
+            image_features = feats[-1][:, 1:]
 
-        dummy_image_features = torch.zeros(
-            256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
-        dummy_image_features = mm_projector(dummy_image_features)
+        image_features = self.lang_encoder.model.mm_projector(image_features)
 
         new_input_embeds = []
-        cur_image_idx = 0
-        for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
-            if (cur_input_ids != self.im_patch_token).all():
+        new_labels = [] if labels is not None else None
+        new_attn_mask = [] if attention_mask is not None else None
+        for batch_idx, cur_input_ids in enumerate(input_ids):
+            cur_img = image_features[batch_idx]
+
+            if (cur_input_ids != self.im_token_index).all():
                 # multimodal LLM, but the current sample is not multimodal
-                cur_input_embeds = cur_input_embeds + (
-                    0. * dummy_image_features).sum()
-                new_input_embeds.append(cur_input_embeds)
-                cur_image_idx += 1
+                new_input_embeds.append(self.embed_tokens(cur_input_ids))
+                if labels is not None:
+                    new_labels.append(labels[batch_idx])
+                if attention_mask is not None:
+                    new_attn_mask.append(attention_mask[batch_idx])
                 continue
+
+            img_idx = torch.where(cur_input_ids == self.im_token_index)[0][0]
             if self.use_im_start_end:
-                cur_image_features = image_features[cur_image_idx]
-                num_patches = cur_image_features.shape[0]
-                if (cur_input_ids == self.im_start_token).sum() != (
-                        cur_input_ids == self.im_end_token).sum():
-                    raise ValueError('The number of image start tokens and '
-                                     'image end tokens should be the same.')
-                image_start_tokens = torch.where(
-                    cur_input_ids == self.im_start_token)[0]
-                for image_start_token_pos in image_start_tokens:
-                    cur_image_features = image_features[cur_image_idx].to(
-                        device=cur_input_embeds.device)
-                    num_patches = cur_image_features.shape[0]
-                    if cur_input_ids[image_start_token_pos + num_patches +
-                                     1] != self.im_end_token:
-                        raise ValueError('The image end token should follow '
-                                         'the image start token.')
-                    cur_new_input_embeds = torch.cat(
-                        (cur_input_embeds[:image_start_token_pos + 1],
-                         cur_image_features,
-                         cur_input_embeds[image_start_token_pos + num_patches +
-                                          1:]),
-                        dim=0)
-                    cur_image_idx += 1
-                new_input_embeds.append(cur_new_input_embeds)
+                cur_new_input_embeds = torch.cat(
+                    [
+                        self.embed_tokens(cur_input_ids[:img_idx - 1]),
+                        self.embed_tokens(cur_input_ids[img_idx - 1:img_idx]),
+                        cur_img,
+                        self.embed_tokens(
+                            cur_input_ids[img_idx + 1:img_idx + 2]),
+                        self.embed_tokens(cur_input_ids[img_idx + 2:]),
+                    ],
+                    dim=0,
+                )
             else:
-                cur_image_features = image_features[cur_image_idx]
-                num_patches = cur_image_features.shape[0]
-                if (cur_input_ids == self.im_patch_token).sum() != num_patches:
-                    print(f'Debug: num_patches: {num_patches}')
-                    raise ValueError(
-                        'The number of image patch tokens should '
-                        'be the same as the number of image patches.')
-                masked_indices = torch.where(
-                    cur_input_ids == self.im_patch_token)[0]
-                mask_index_start = masked_indices[0]
-                if (masked_indices != torch.arange(
-                        mask_index_start,
-                        mask_index_start + num_patches,
-                        device=masked_indices.device,
-                        dtype=masked_indices.dtype)).any():
-                    raise ValueError(
-                        'The image patch tokens should be consecutive.')
                 cur_new_input_embeds = torch.cat(
-                    (cur_input_embeds[:mask_index_start], cur_image_features,
-                     cur_input_embeds[mask_index_start + num_patches:]),
-                    dim=0)
-                new_input_embeds.append(cur_new_input_embeds)
-                cur_image_idx += 1
+                    [
+                        self.embed_tokens(cur_input_ids[:img_idx]),
+                        cur_img,
+                        self.embed_tokens(cur_input_ids[img_idx + 1:]),
+                    ],
+                    dim=0,
+                )
+            new_input_embeds.append(cur_new_input_embeds)
+
+            if labels is not None:
+                cur_new_labels = torch.cat([
+                    labels[batch_idx, :img_idx],
+                    labels.new_full((cur_img.size(0), ), -100),
+                    labels[batch_idx, img_idx + 1:],
+                ],
+                                           dim=0)
+                new_labels.append(cur_new_labels)
+
+            if attention_mask is not None:
+                cur_attn_mask = torch.cat([
+                    attention_mask[batch_idx, :img_idx],
+                    attention_mask.new_full((cur_img.size(0), ), True),
+                    attention_mask[batch_idx, img_idx + 1:],
+                ],
+                                          dim=0)
+                new_attn_mask.append(cur_attn_mask)
+
         inputs_embeds = torch.stack(new_input_embeds, dim=0)
+        if labels is not None:
+            labels = torch.stack(new_labels, dim=0)
+        if attention_mask is not None:
+            attention_mask = torch.stack(new_attn_mask, dim=0)
 
-        return inputs_embeds
+        return None, attention_mask, past_key_values, inputs_embeds, labels
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
@@ -236,3 +229,6 @@ def _reorder_cache(past_key_values, beam_idx):
                     past_state.index_select(0, beam_idx)
                     for past_state in layer_past), )
         return reordered_past
+
+    def embed_tokens(self, input_ids):
+        return self.lang_encoder.model.embed_tokens(input_ids)
diff --git a/tools/model_converters/llava-delta2mmpre.py b/tools/model_converters/llava-delta2mmpre.py
index bc51b19d9f9..104ed07d477 100644
--- a/tools/model_converters/llava-delta2mmpre.py
+++ b/tools/model_converters/llava-delta2mmpre.py
@@ -9,23 +9,21 @@
 from transformers.modeling_utils import load_state_dict
 
 prog_description = """\
-Merge Llava delta weights and original weights,
-and save as MMPreTrain checkpoint.
+Convert Llava weights to an MMPreTrain checkpoint, optionally merging deltas.
 """
 
 
 def parse_args():
     parser = argparse.ArgumentParser(description=prog_description)
-    parser.add_argument(
-        'src_path', type=str, help='The original checkpoint dir')
-    parser.add_argument(
-        'delta_path', type=str, help='The delta checkpoint dir')
-    parser.add_argument('dst_path', type=str, help='The saved checkpoint path')
+    parser.add_argument('src', type=str, help='The original checkpoint dir')
+    parser.add_argument('dst', type=str, help='The saved checkpoint path')
+    parser.add_argument('--delta', type=str, help='The delta checkpoint dir')
     args = parser.parse_args()
     return args
 
 
 def load_checkpoint(path: Path):
+    path = Path(path)
     if path.is_file():
         return torch.load(path)
 
@@ -41,19 +39,23 @@ def main():
     args = parse_args()
 
-    if Path(args.src_path).exists():
-        src_path = Path(args.src_path)
+    if Path(args.src).exists():
+        src_path = args.src
     else:
-        src_path = Path(snapshot_download(args.src_path))
+        src_path = snapshot_download(
+            args.src, allow_patterns='pytorch_model*.bin')
     src_state_dict = load_checkpoint(src_path)
 
-    if Path(args.delta_path).exists():
-        delta_path = Path(args.delta_path)
+    if args.delta is None:
+        delta_state_dict = {}
+    elif Path(args.delta).exists():
+        delta_state_dict = load_checkpoint(args.delta)
     else:
-        delta_path = Path(snapshot_download(args.delta_path))
-    delta_state_dict = load_checkpoint(delta_path)
+        delta_path = snapshot_download(
+            args.delta, allow_patterns='pytorch_model*.bin')
+        delta_state_dict = load_checkpoint(delta_path)
 
-    merged_state_dict = OrderedDict()
+    new_state_dict = OrderedDict()
     for k, v in src_state_dict.items():
         if k in delta_state_dict:
             delta_v = delta_state_dict.pop(k)
@@ -63,12 +65,13 @@ def main():
                 v = delta_v
             else:
                 v += delta_v
-        merged_state_dict['model.lang_encoder.' + k] = v
+        if 'rotary_emb.inv_freq' not in k:
+            new_state_dict['model.lang_encoder.' + k] = v
 
     for k, v in delta_state_dict.items():
-        merged_state_dict['model.lang_encoder.' + k] = v
+        new_state_dict['model.lang_encoder.' + k] = v
 
-    torch.save(merged_state_dict, args.dst_path)
+    torch.save(new_state_dict, args.dst)
 
     print('Done!!')
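
For reference, this patch changes the converter's command line: `src` and `dst` are now positional arguments and the delta checkpoint becomes an optional `--delta` flag. The commands below are a usage sketch inferred only from the argparse definition above, not an official documented workflow; the Hugging Face repo IDs are reused from names appearing elsewhere in this patch, and treating `liuhaotian/llava-v1.5-7b` as a directly convertible source is an assumption.

```shell
# LLaVA v1: merge delta weights into the base LLaMA weights (the old README workflow)
python tools/model_converters/llava-delta2mmpre.py huggyllama/llama-7b \
    ./llava-7b-v1.pth --delta liuhaotian/LLaVA-Lightning-7B-delta-v1-1

# LLaVA v1.5: weights are already merged, so --delta is omitted (assumed source repo)
python tools/model_converters/llava-delta2mmpre.py liuhaotian/llava-v1.5-7b \
    ./llava-7b-v1.5.pth
```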