
Commit

FEAT: add vllm restart check and support internvl multi-image chat (#…
amumu96 authored Sep 30, 2024
1 parent 4c8aae1 commit 00a9ee1
Showing 5 changed files with 91 additions and 42 deletions.
36 changes: 12 additions & 24 deletions xinference/model/llm/llm_family.json
@@ -6483,8 +6483,7 @@
"8-bit",
"none"
],
- "model_id": "OpenGVLab/InternVL2-1B",
- "model_revision": "a9fc14aea824b6ea1d44f8778cad6b35512c4ce1"
+ "model_id": "OpenGVLab/InternVL2-1B"
},
{
"model_format": "pytorch",
@@ -6494,17 +6493,15 @@
"8-bit",
"none"
],
- "model_id": "OpenGVLab/InternVL2-2B",
- "model_revision": "422ad7c6335917bfb514958233955512338485a6"
+ "model_id": "OpenGVLab/InternVL2-2B"
},
{
"model_format": "awq",
"model_size_in_billions": 2,
"quantizations": [
"Int4"
],
- "model_id": "OpenGVLab/InternVL2-2B-AWQ",
- "model_revision": "701bc3fc098a8a3b686b3b4135cfb77202be89e0"
+ "model_id": "OpenGVLab/InternVL2-2B-AWQ"
},
{
"model_format": "pytorch",
@@ -6514,8 +6511,7 @@
"8-bit",
"none"
],
- "model_id": "OpenGVLab/InternVL2-4B",
- "model_revision": "b50544dafada6c41e80bfde2f57cc9b0140fc21c"
+ "model_id": "OpenGVLab/InternVL2-4B"
},
{
"model_format": "pytorch",
@@ -6525,17 +6521,15 @@
"8-bit",
"none"
],
- "model_id": "OpenGVLab/InternVL2-8B",
- "model_revision": "3bfd3664dea4f3da628785f5125d30f889701253"
+ "model_id": "OpenGVLab/InternVL2-8B"
},
{
"model_format": "awq",
"model_size_in_billions": 8,
"quantizations": [
"Int4"
],
- "model_id": "OpenGVLab/InternVL2-8B-AWQ",
- "model_revision": "9f1a4756b7ae18eb26d8a22b618dfc283e8193b3"
+ "model_id": "OpenGVLab/InternVL2-8B-AWQ"
},
{
"model_format": "pytorch",
@@ -6545,17 +6539,15 @@
"8-bit",
"none"
],
- "model_id": "OpenGVLab/InternVL2-26B",
- "model_revision": "b9f3c7e6d575b0115e076a3ffc46fd20b7586899"
+ "model_id": "OpenGVLab/InternVL2-26B"
},
{
"model_format": "awq",
"model_size_in_billions": 26,
"quantizations": [
"Int4"
],
- "model_id": "OpenGVLab/InternVL2-26B-AWQ",
- "model_revision": "469e0019ffd251e22ff6501a5c2321964e86ef0d"
+ "model_id": "OpenGVLab/InternVL2-26B-AWQ"
},
{
"model_format": "pytorch",
@@ -6565,17 +6557,15 @@
"8-bit",
"none"
],
- "model_id": "OpenGVLab/InternVL2-40B",
- "model_revision": "725a12063bb855c966e30a0617d0ccd9e870d772"
+ "model_id": "OpenGVLab/InternVL2-40B"
},
{
"model_format": "awq",
"model_size_in_billions": 40,
"quantizations": [
"Int4"
],
- "model_id": "OpenGVLab/InternVL2-40B-AWQ",
- "model_revision": "d92e140f6dfe8ea9679924c6a31898f42c4e1846"
+ "model_id": "OpenGVLab/InternVL2-40B-AWQ"
},
{
"model_format": "pytorch",
@@ -6585,17 +6575,15 @@
"8-bit",
"none"
],
- "model_id": "OpenGVLab/InternVL2-Llama3-76B",
- "model_revision": "cf7914905f78e9e3560ddbd6f5dfc39becac494f"
+ "model_id": "OpenGVLab/InternVL2-Llama3-76B"
},
{
"model_format": "awq",
"model_size_in_billions": 76,
"quantizations": [
"Int4"
],
- "model_id": "OpenGVLab/InternVL2-Llama3-76B-AWQ",
- "model_revision": "1bc796bf80f2ebc7d6a14c15f55217a4600d50a4"
+ "model_id": "OpenGVLab/InternVL2-Llama3-76B-AWQ"
}
],
"chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
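For reference, the chat_template field above is a Jinja2 template that turns an OpenAI-style message list into the raw prompt string. The sketch below renders it with the jinja2 package (an assumption for illustration; the template string is copied verbatim from the entry above):

from jinja2 import Template

chat_template = (
    "{% for message in messages %}"
    "{% if loop.first and messages[0]['role'] != 'system' %}"
    "{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}"
    "{% endif %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)

messages = [{"role": "user", "content": "Describe this image."}]
print(Template(chat_template).render(messages=messages, add_generation_prompt=True))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Describe this image.<|im_end|>
# <|im_start|>assistant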
12 changes: 2 additions & 10 deletions xinference/model/llm/llm_family_modelscope.json
@@ -4334,16 +4334,8 @@
}
],
"chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
- "stop_token_ids": [
- 151643,
- 151644,
- 151645
- ],
- "stop": [
- "<|endoftext|>",
- "<|im_start|>",
- "<|im_end|>"
- ]
+ "stop_token_ids": [],
+ "stop": []
},
{
"version": 1,
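The emptied "stop" and "stop_token_ids" fields above are the family-level defaults for extra stopping criteria; with both lists empty, stopping is left to the model's own EOS handling. As a rough illustration of where such values end up when the vLLM engine is used, they correspond to the stop and stop_token_ids arguments of vLLM's SamplingParams (sketch only; the exact plumbing is not shown in this diff):

from vllm import SamplingParams  # assumes vllm is installed

# No extra stop strings or stop token ids beyond the model's own EOS token.
sampling_params = SamplingParams(max_tokens=512, stop=[], stop_token_ids=[])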
17 changes: 14 additions & 3 deletions xinference/model/llm/utils.py
@@ -159,14 +159,25 @@ def get_specific_prompt(model_family: str, messages: List[Dict]):
for image_url in image_urls:
fut = executor.submit(_decode_image, image_url)
image_futures.append(fut)
- images = [fut.result() for fut in image_futures]
+ images.extend([fut.result() for fut in image_futures])
if len(image_futures) == 0:
ret += role + "\n" + text + intra_message_sep + "\n"
else:
+ placeholders = "\n".join(
+ f"Image-{i+1}: <image>\n"
+ for i in range(
+ len(images) - len(image_futures), len(images)
+ )
+ )
ret += (
- role + "\n" + f"<image>\n{text}" + intra_message_sep + "\n"
+ role
+ + "\n"
+ + f"{placeholders}\n{text}"
+ + intra_message_sep
+ + "\n"
)

+ if len(images) == 1:
+ ret = ret.replace("Image-1: <image>\n", "<image>\n")
return ret, images
else:
raise ValueError(f"Invalid model family: {model_family}")
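The hunk above numbers image placeholders per message instead of assuming a single image. A standalone sketch of that numbering (the helper name and signature are illustrative; the real get_specific_prompt also handles roles, separators, and image decoding, and applies the single-image fallback once over the whole prompt):

def internvl_image_placeholders(text: str, num_new_images: int, total_images: int) -> str:
    # One "Image-N: <image>" tag per image attached to this message.
    if num_new_images == 0:
        return text
    placeholders = "\n".join(
        f"Image-{i + 1}: <image>\n"
        for i in range(total_images - num_new_images, total_images)
    )
    prompt = f"{placeholders}\n{text}"
    # A single-image conversation falls back to the bare "<image>" tag InternVL2 expects.
    if total_images == 1:
        prompt = prompt.replace("Image-1: <image>\n", "<image>\n")
    return prompt

print(internvl_image_placeholders("Compare the two images.", 2, 2))
# Image-1: <image>
#
# Image-2: <image>
#
# Compare the two images.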
26 changes: 21 additions & 5 deletions xinference/model/llm/vllm/core.py
@@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
+ import json
import logging
import multiprocessing
import os
@@ -47,6 +48,7 @@
ChatModelMixin,
generate_completion_chunk,
)
+ from .utils import vllm_check

logger = logging.getLogger(__name__)

@@ -65,6 +67,7 @@ class VLLMModelConfig(TypedDict, total=False):
max_num_seqs: int
quantization: Optional[str]
max_model_len: Optional[int]
+ limit_mm_per_prompt: Optional[Dict[str, int]]


class VLLMGenerateConfig(TypedDict, total=False):
@@ -90,9 +93,7 @@ class VLLMGenerateConfig(TypedDict, total=False):
except ImportError:
VLLM_INSTALLED = False

- VLLM_SUPPORTED_VISION_MODEL_LIST: List[str] = [
- "internvl2",
- ]
+ VLLM_SUPPORTED_VISION_MODEL_LIST: List[str] = []
VLLM_SUPPORTED_MODELS = [
"llama-2",
"llama-3",
@@ -171,6 +172,9 @@ class VLLMGenerateConfig(TypedDict, total=False):
VLLM_SUPPORTED_MODELS.append("llama-3.1")
VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.1-instruct")

+ if VLLM_INSTALLED and vllm.__version__ >= "0.6.1":
+ VLLM_SUPPORTED_VISION_MODEL_LIST.append("internvl2")


class VLLMModel(LLM):
def __init__(
@@ -305,6 +309,11 @@ def _sanitize_model_config(
model_config.setdefault("max_num_seqs", 256)
model_config.setdefault("quantization", None)
model_config.setdefault("max_model_len", None)
+ model_config["limit_mm_per_prompt"] = (
+ json.loads(model_config.get("limit_mm_per_prompt")) # type: ignore
+ if model_config.get("limit_mm_per_prompt")
+ else None
+ )

return model_config

@@ -434,6 +443,7 @@ def _convert_request_output_to_completion(
usage=usage,
)

+ @vllm_check
async def async_generate(
self,
prompt: Union[str, Dict[str, Any]],
@@ -665,6 +675,7 @@ async def _async_to_tool_completion_chunks(
yield self._to_chat_completion_chunk(chunk)
i += 1

+ @vllm_check
async def async_chat(
self,
messages: List[Dict],
@@ -741,25 +752,30 @@ def _sanitize_chat_config(
)
return generate_config

+ @vllm_check
async def async_chat(
self,
messages: List[Dict],
generate_config: Optional[Dict] = None,
request_id: Optional[str] = None,
) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
- # only support single image, waiting vllm support multi images
model_family = self.model_family.model_family or self.model_family.model_name
prompt, images = self.get_specific_prompt(model_family, messages)

if len(images) == 0:
inputs = {
"prompt": prompt,
}
- else:
+ elif len(images) == 1:
inputs = {
"prompt": prompt,
"multi_modal_data": {"image": images[-1]}, # type: ignore
}
+ else:
+ inputs = {
+ "prompt": prompt,
+ "multi_modal_data": {"image": images}, # type: ignore
+ }
generate_config = self._sanitize_chat_config(generate_config)

stream = generate_config.get("stream", None)
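The new limit_mm_per_prompt option arrives as a JSON string and is decoded in _sanitize_model_config before it reaches vLLM, which takes a mapping such as {"image": 4} capping how many images one prompt may carry. A minimal sketch of that round trip (the parsing helper is illustrative, not part of the diff):

import json
from typing import Dict, Optional

def parse_limit_mm_per_prompt(raw: Optional[str]) -> Optional[Dict[str, int]]:
    # Mirrors the json.loads branch added to _sanitize_model_config above.
    return json.loads(raw) if raw else None

# e.g. a caller supplies limit_mm_per_prompt='{"image": 4}' when launching the model
limit = parse_limit_mm_per_prompt('{"image": 4}')
assert limit == {"image": 4}  # forwarded to vLLM as its limit_mm_per_prompt engine argument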
42 changes: 42 additions & 0 deletions xinference/model/llm/vllm/utils.py
@@ -0,0 +1,42 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
import os

logger = logging.getLogger(__name__)


def vllm_check(fn):
    """Wrap an async model method so that a dead vLLM engine terminates the process,
    allowing xinference to auto-recover the model."""
    try:
        from vllm.engine.async_llm_engine import AsyncEngineDeadError
    except ImportError:
        # vLLM is not installed (or too old); return the method unchanged.
        return fn

    @functools.wraps(fn)
    async def _async_wrapper(self, *args, **kwargs):
        logger.info("vllm_check")
        try:
            return await fn(self, *args, **kwargs)
        except AsyncEngineDeadError:
            logger.info(
                "Detected that the vLLM engine is not healthy; preparing to quit the process"
            )
            try:
                self.stop()
            except Exception:
                # ignore errors raised while stopping
                pass
            # Just kill the process and let xinference auto-recover the model
            os._exit(1)

    return _async_wrapper
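Below is a minimal illustration of how vllm_check is meant to be used, mirroring the @vllm_check decorations added in vllm/core.py; DummyModel and its method body are purely illustrative:

import asyncio

class DummyModel:
    def stop(self):
        print("stopping engine ...")

    @vllm_check
    async def async_generate(self, prompt: str) -> str:
        # The real VLLMModel awaits vLLM's AsyncLLMEngine here. If the engine has died,
        # vLLM raises AsyncEngineDeadError and the wrapper above stops the model and
        # exits the process so that xinference can relaunch the replica.
        return f"completion for: {prompt!r}"

print(asyncio.run(DummyModel().async_generate("hello")))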
