Merge pull request #4 from flageval-baai/hzq/fix_intern_2_5
Fix device mismatch issue for InternVL 2.5 78B
zhizhou57 authored Dec 19, 2024
2 parents a2754d8 + e4a0196 commit 74afc01
Showing 1 changed file with 3 additions and 3 deletions.
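
In short: split_model builds an explicit per-module device map for multi-GPU inference, and language_model.model.rotary_emb previously had no entry in it, so the rotary-embedding module could end up on a different GPU than the hidden states it is applied to. Pinning it to GPU 0, next to the input embeddings and final norm, resolves the mismatch; the commit also passes attn_implementation="flash_attention_2" at load time and drops a dummy-image filter in the data loader.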
model_zoo/vlm/intern_vl/model_adapter_v2_5.py (6 changes: 3 additions & 3 deletions)
@@ -15,6 +15,7 @@
 from flagevalmm.server.utils import parse_args


+# modified from https://huggingface.co/OpenGVLab/InternVL2_5-78B
 def split_model(model_name):
     device_map = {}
     world_size = torch.cuda.device_count()
@@ -44,6 +45,7 @@ def split_model(model_name):
     device_map["language_model.model.norm"] = 0
     device_map["language_model.lm_head"] = 0
     device_map[f"language_model.model.layers.{num_layers - 1}"] = 0
+    device_map["language_model.model.rotary_emb"] = 0

     return device_map
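
For context, the full function (collapsed above) spreads decoder layers across GPUs and pins shared modules to GPU 0. Below is a minimal sketch of that pattern; the world_size - 0.5 weighting and the default layer count are assumptions taken from the upstream InternVL2_5-78B model card linked in the diff, not shown in this commit:

import math
import torch

def split_model_sketch(num_layers=80):
    # Sketch only: assumes at least one visible GPU.
    device_map = {}
    world_size = torch.cuda.device_count()
    # GPU 0 also hosts the vision tower, so it is weighted as half a GPU.
    per_gpu = [math.ceil(num_layers / (world_size - 0.5))] * world_size
    per_gpu[0] = math.ceil(per_gpu[0] * 0.5)
    layer = 0
    for gpu_id, n in enumerate(per_gpu):
        for _ in range(n):
            if layer == num_layers:
                break
            device_map[f"language_model.model.layers.{layer}"] = gpu_id
            layer += 1
    # Modules that exchange tensors with every layer live on GPU 0; the
    # rotary embedding entry is the one this commit adds.
    device_map["vision_model"] = 0
    device_map["language_model.model.embed_tokens"] = 0
    device_map["language_model.model.norm"] = 0
    device_map["language_model.lm_head"] = 0
    device_map[f"language_model.model.layers.{num_layers - 1}"] = 0
    device_map["language_model.model.rotary_emb"] = 0
    return device_map

GPU 0 gets roughly half a share of decoder layers because it also carries the vision tower; any module left out of the map, as rotary_emb was before this fix, is free to land on the wrong device.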

@@ -158,8 +160,6 @@ def __getitem__(self, index):
         pixel_values = []
         num_patches_list = []
         for img_path in img_paths:
-            if "dummy" in img_path:
-                continue
             pixel_values.append(load_image(img_path, max_num=12).to(torch.bfloat16))
             num_patches_list.append(pixel_values[-1].size(0))

@@ -216,6 +216,7 @@ def model_init(self, task_info):
             use_flash_attn=True,
             trust_remote_code=True,
             device_map=device_map,
+            attn_implementation="flash_attention_2",
         ).eval()

         model = self.accelerator.prepare_model(model, evaluation_mode=True)
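
For reference, a sketch of how these kwargs compose in a standalone load; the checkpoint name comes from the URL in the comment above, while torch_dtype and low_cpu_mem_usage are assumptions carried over from the upstream model card rather than shown in this diff:

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "OpenGVLab/InternVL2_5-78B",
    torch_dtype=torch.bfloat16,       # assumption: matches the bfloat16 pixel values above
    low_cpu_mem_usage=True,           # assumption from the upstream model card
    use_flash_attn=True,              # InternVL's own remote-code flag
    trust_remote_code=True,
    device_map=split_model("InternVL2_5-78B"),  # the function patched above
    attn_implementation="flash_attention_2",    # the kwarg this commit adds
).eval()

Passing attn_implementation explicitly asks Transformers itself for the FlashAttention-2 kernels, rather than relying solely on InternVL's remote-code use_flash_attn flag.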
@@ -278,6 +279,5 @@ def run_one_task(self, task_name, meta_info):
         server_port=args.server_port,
         timeout=args.timeout,
         extra_cfg=args.cfg,
-        # extra_cfg={"model_path": "/share/projset/models/vlm/InternVL2-Llama3-76B"}
     )
     model_adapter.run()
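
Presumably the adapter is then launched as a standalone process against a running FlagEvalMM task server, along the lines of the following; the flag names are inferred from the args fields above and are not confirmed by this diff:

python model_zoo/vlm/intern_vl/model_adapter_v2_5.py --server_ip localhost --server_port 5000 --cfg <model_config>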
