From 7ed7a02faa386c611c8e702072311aa3bf178775 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Fri, 18 Aug 2023 20:46:01 +0800 Subject: [PATCH] FEAT: support generate/chat/create_embedding/register/unregister/registrations method in cmdline (#363) Co-authored-by: UranusSeven <109661872+UranusSeven@users.noreply.github.com> --- README.md | 2 +- README_zh_CN.md | 12 +- doc/source/models/custom.rst | 134 ++++++---- xinference/deploy/cmdline.py | 353 +++++++++++++++++++++---- xinference/deploy/test/test_cmdline.py | 229 ++++++++++++++++ 5 files changed, 621 insertions(+), 109 deletions(-) create mode 100644 xinference/deploy/test/test_cmdline.py diff --git a/README.md b/README.md index 5662b9a4d1..332b01eedf 100644 --- a/README.md +++ b/README.md @@ -226,5 +226,5 @@ For in-depth details on the built-in models, please refer to [built-in models](h - Xinference will download models automatically for you, and by default the models will be saved under `${USER}/.xinference/cache`. -## Custom models \[Experimental\] +## Custom models Please refer to [custom models](https://inference.readthedocs.io/en/latest/models/custom.html). diff --git a/README_zh_CN.md b/README_zh_CN.md index c209e3279c..c245fec1a0 100644 --- a/README_zh_CN.md +++ b/README_zh_CN.md @@ -134,10 +134,10 @@ model = client.get_model(model_uid) chat_history = [] prompt = "What is the largest animal?" model.chat( - prompt, - chat_history, - generate_config={"max_tokens": 1024} - ) + prompt, + chat_history, + generate_config={"max_tokens": 1024} +) ``` 返回值: @@ -206,5 +206,5 @@ $ xinference list --all **注意**: - Xinference 会自动为你下载模型,默认的模型存放路径为 `${USER}/.xinference/cache`。 -## 自定义模型\[Experimental\] -请参考 [自定义模型](https://inference.readthedocs.io/en/latest/models/custom.html). +## 自定义模型 +请参考 [自定义模型](https://inference.readthedocs.io/en/latest/models/custom.html)。 diff --git a/doc/source/models/custom.rst b/doc/source/models/custom.rst index a636a24573..2a9850d827 100644 --- a/doc/source/models/custom.rst +++ b/doc/source/models/custom.rst @@ -1,126 +1,127 @@ .. _models_custom: -============================ -Custom Models (Experimental) -============================ - -Custom models are currently an experimental feature and are expected to be officially released in -version v0.2.0. +============= +Custom Models +============= +Xinference provides a flexible and comprehensive way to integrate, manage, and utilize custom models. Define a custom model ~~~~~~~~~~~~~~~~~~~~~ Define a custom model based on the following template: -.. code-block:: python +.. code-block:: json - custom_model = { + { "version": 1, - # model name. must start with a letter or a - # digit, and can only contain letters, digits, - # underscores, or dashes. "model_name": "custom-llama-2", - # supported languages "model_lang": [ "en" ], - # model abilities. could be "embed", "generate" - # and "chat". "model_ability": [ "generate" ], - # model specifications. "model_specs": [ { - # model format. "model_format": "pytorch", "model_size_in_billions": 7, - # quantizations. "quantizations": [ "4-bit", "8-bit", "none" ], - # hugging face model ID. "model_id": "meta-llama/Llama-2-7b", - # when model_uri is present, xinference will load the model from the given RUI. "model_uri": "file:///path/to/llama-2-7b" }, { - # model format. - "model_format": "pytorch", - "model_size_in_billions": 13, - # quantizations. - "quantizations": [ - "4-bit", - "8-bit", - "none" - ], - # hugging face model ID. - "model_id": "meta-llama/Llama-2-13b" - }, - { - # model format. 
"model_format": "ggmlv3", - # quantizations. "model_size_in_billions": 7, "quantizations": [ "q4_0", "q8_0" - ] - # hugging face model ID. + ], "model_id": "TheBloke/Llama-2-7B-GGML", - # an f-string that takes a quantization. "model_file_name_template": "llama-2-7b.ggmlv3.{quantization}.bin" } ], - # prompt style, required by chat models. - # for more details, see: xinference/model/llm/tests/test_utils.py - "prompt_style": None } * model_name: A string defining the name of the model. The name must start with a letter or a digit and can only contain letters, digits, underscores, or dashes. * model_lang: A list of strings representing the supported languages for the model. Example: ["en"], which means that the model supports English. * model_ability: A list of strings defining the abilities of the model. It could include options like "embed", "generate", and "chat". In this case, the model has the ability to "generate". * model_specs: An array of objects defining the specifications of the model. These include: - * model_format: A string that defines the model format, could be "pytorch" or "ggmlv3". - * model_size_in_billions: An integer defining the size of the model in billions of parameters. - * quantizations: A list of strings defining the available quantizations for the model. For PyTorch models, it could be "4-bit", "8-bit", or "none". For ggmlv3 models, the quantizations should correspond to values that work with the ``model_file_name_template``. - * model_id: A string representing the model ID, possibly referring to an identifier used by Hugging Face. - * model_uri: A string representing the URI where the model can be loaded from, such as "file:///path/to/llama-2-7b". If model URI is absent, Xinference will try to download the model from Hugging Face with the model ID. - * model_file_name_template: Required by ggml models. An f-string template used for defining the model file name based on the quantization. + * model_format: A string that defines the model format, could be "pytorch" or "ggmlv3". + * model_size_in_billions: An integer defining the size of the model in billions of parameters. + * quantizations: A list of strings defining the available quantizations for the model. For PyTorch models, it could be "4-bit", "8-bit", or "none". For ggmlv3 models, the quantizations should correspond to values that work with the ``model_file_name_template``. + * model_id: A string representing the model ID, possibly referring to an identifier used by Hugging Face. + * model_uri: A string representing the URI where the model can be loaded from, such as "file:///path/to/llama-2-7b". If model URI is absent, Xinference will try to download the model from Hugging Face with the model ID. + * model_file_name_template: Required by ggml models. An f-string template used for defining the model file name based on the quantization. * prompt_style: An optional field that could be required by chat models to define the style of prompts. The given example has this set to None, but additional details could be found in a referenced file xinference/model/llm/tests/test_utils.py. -Register the Custom Model -~~~~~~~~~~~~~~~~~~~~~~~~~ +Register a Custom Model +~~~~~~~~~~~~~~~~~~~~~~~ + +Register a custom model programmatically: .. 
code-block:: python import json from xinference.client import Client + with open('model.json') as fd: + model = fd.read() + # replace with real xinference endpoint - endpoint = "http://localhost:9997" + endpoint = 'http://localhost:9997' client = Client(endpoint) - client.register_model(model_type="LLM", model=json.dumps(custom_model), persist=False) + client.register_model(model_type="LLM", model=model, persist=False) +Or via CLI: +.. code-block:: bash -Load the Custom Model -~~~~~~~~~~~~~~~~~~~~~ + xinference register --model-type LLM --file model.json --persist + +List the Built-in and Custom Models +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +List built-in and custom models programmatically: .. code-block:: python - uid = client.launch_model(model_name='custom-llama-2') + registrations = client.list_model_registrations(model_type="LLM") + +Or via CLI: + +.. code-block:: bash + + xinference registrations --model-type LLM -Run the Custom Model -~~~~~~~~~~~~~~~~~~~~ +Launch the Custom Model +~~~~~~~~~~~~~~~~~~~~~~~ + +Launch the custom model programmatically: + +.. code-block:: python + + uid = client.launch_model(model_name='custom-llama-2', model_format='pytorch') + +Or via CLI: + +.. code-block:: bash + + xinference launch --model-name custom-llama-2 --model-format pytorch + +Interact with the Custom Model +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Invoke the model programmatically: .. code-block:: python model = client.get_model(model_uid=uid) - model.generate("What is the largest animal in the world?") + model.generate('What is the largest animal in the world?') Result: @@ -145,3 +146,24 @@ Result: "total_tokens":33 } } + +Or via CLI, replace ``${UID}`` with real model UID: + +.. code-block:: bash + + xinference generate --model-uid ${UID} + +Unregister the Custom Model +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Unregister the custom model programmatically: + +.. code-block:: python + + model = client.unregister_model(model_type='LLM', model_name='custom-llama-2') + +Or via CLI: + +.. code-block:: bash + + xinference unregister --model-type LLM --model-name custom-llama-2 diff --git a/xinference/deploy/cmdline.py b/xinference/deploy/cmdline.py index f8261dd1b6..6b67c67c61 100644 --- a/xinference/deploy/cmdline.py +++ b/xinference/deploy/cmdline.py @@ -11,23 +11,40 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import asyncio import configparser import logging import os import sys -from typing import Optional +from typing import List, Optional import click from xoscar.utils import get_next_port from .. import __version__ -from ..client import RESTfulClient +from ..client import ( + Client, + RESTfulChatglmCppChatModelHandle, + RESTfulChatModelHandle, + RESTfulClient, + RESTfulGenerateModelHandle, +) from ..constants import ( XINFERENCE_DEFAULT_DISTRIBUTED_HOST, XINFERENCE_DEFAULT_ENDPOINT_PORT, XINFERENCE_DEFAULT_LOCAL_HOST, XINFERENCE_ENV_ENDPOINT, ) +from ..isolation import Isolation +from ..types import ChatCompletionMessage + +try: + # provide elaborate line editing and history features. 
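+    # note: importing readline transparently upgrades the input() calls used by the
+    # generate/chat prompt loops below with these capabilities.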
+ # https://docs.python.org/3/library/functions.html#input + import readline # noqa: F401 +except ImportError: + pass def get_config_string(log_level: str) -> str: @@ -146,6 +163,92 @@ def worker(log_level: str, endpoint: Optional[str], host: str): ) +@cli.command("register") +@click.option( + "--endpoint", + "-e", + type=str, +) +@click.option("--model-type", "-t", default="LLM", type=str) +@click.option("--file", "-f", type=str) +@click.option("--persist", "-p", is_flag=True) +def register_model( + endpoint: Optional[str], + model_type: str, + file: str, + persist: bool, +): + endpoint = get_endpoint(endpoint) + with open(file) as fd: + model = fd.read() + + client = RESTfulClient(base_url=endpoint) + client.register_model( + model_type=model_type, + model=model, + persist=persist, + ) + + +@cli.command("unregister") +@click.option( + "--endpoint", + "-e", + type=str, +) +@click.option("--model-type", "-t", default="LLM", type=str) +@click.option("--model-name", "-n", type=str) +def unregister_model( + endpoint: Optional[str], + model_type: str, + model_name: str, +): + endpoint = get_endpoint(endpoint) + + client = RESTfulClient(base_url=endpoint) + client.unregister_model( + model_type=model_type, + model_name=model_name, + ) + + +@cli.command("registrations") +@click.option( + "--endpoint", + "-e", + type=str, +) +@click.option("--model-type", "-t", default="LLM", type=str) +def list_model_registrations( + endpoint: Optional[str], + model_type: str, +): + from tabulate import tabulate + + endpoint = get_endpoint(endpoint) + + client = RESTfulClient(base_url=endpoint) + registrations = client.list_model_registrations(model_type=model_type) + + table = [] + for registration in registrations: + model_name = registration["model_name"] + model_family = client.get_model_registration(model_type, model_name) + table.append( + [ + model_type, + model_family["model_name"], + model_family["model_lang"], + model_family["model_ability"], + registration["is_builtin"], + ] + ) + print( + tabulate(table, headers=["Type", "Name", "Language", "Ability", "Is-built-in"]), + file=sys.stderr, + ) + + @cli.command("launch") @click.option( "--endpoint", @@ -182,56 +285,39 @@ def model_launch( "-e", type=str, ) -@click.option("--all", is_flag=True) -def model_list(endpoint: Optional[str], all: bool): +def model_list(endpoint: Optional[str]): from tabulate import tabulate - # TODO: get from the supervisor - from ..model.llm import BUILTIN_LLM_FAMILIES - endpoint = get_endpoint(endpoint) + client = RESTfulClient(base_url=endpoint) table = [] - if all: - for model_family in BUILTIN_LLM_FAMILIES: - table.append( - [ - model_family.model_name, - model_family.model_lang, - model_family.model_ability, - ] - ) - - print( - tabulate(table, headers=["Name", "Language", "Ability"]), - file=sys.stderr, - ) - else: - client = RESTfulClient(base_url=endpoint) - models = client.list_models() - for model_uid, model_spec in models.items(): - table.append( - [ - model_uid, - model_spec["model_name"], - model_spec["model_format"], - model_spec["model_size_in_billions"], - model_spec["quantization"], - ] - ) - print( - tabulate( - table, - headers=[ - "ModelUid", - "Name", - "Format", - "Size (in billions)", - "Quantization", - ], - ), - file=sys.stderr, + models = client.list_models() + for model_uid, model_spec in models.items(): + table.append( + [ + model_uid, + model_spec["model_type"], + model_spec["model_name"], + model_spec["model_format"], + model_spec["model_size_in_billions"], + model_spec["quantization"], + ] ) + 
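+    # render the collected rows with tabulate and write the table to stderr.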
print( + tabulate( + table, + headers=[ + "UID", + "Type", + "Name", + "Format", + "Size (in billions)", + "Quantization", + ], + ), + file=sys.stderr, + ) @cli.command("terminate") @@ -251,5 +337,180 @@ def model_terminate( client.terminate_model(model_uid=model_uid) +@cli.command("generate") +@click.option( + "--endpoint", + "-e", + type=str, +) +@click.option("--model-uid", type=str) +@click.option("--max_tokens", default=256, type=int) +@click.option("--stream", default=True, type=bool) +def model_generate( + endpoint: Optional[str], + model_uid: str, + max_tokens: int, + stream: bool, +): + endpoint = get_endpoint(endpoint) + if stream: + # TODO: when stream=True, RestfulClient cannot generate words one by one. + # So use Client in temporary. The implementation needs to be changed to + # RestfulClient in the future. + async def generate_internal(): + while True: + # the prompt will be written to stdout. + # https://docs.python.org/3.10/library/functions.html#input + prompt = input("Prompt: ") + if prompt == "": + break + print(f"Completion: {prompt}", end="", file=sys.stdout) + async for chunk in model.generate( + prompt=prompt, + generate_config={"stream": stream, "max_tokens": max_tokens}, + ): + choice = chunk["choices"][0] + if "text" not in choice: + continue + else: + print(choice["text"], end="", flush=True, file=sys.stdout) + print("\n", file=sys.stdout) + + client = Client(endpoint=endpoint) + model = client.get_model(model_uid=model_uid) + + loop = asyncio.get_event_loop() + coro = generate_internal() + + if loop.is_running(): + isolation = Isolation(asyncio.new_event_loop(), threaded=True) + isolation.start() + isolation.call(coro) + else: + task = loop.create_task(coro) + try: + loop.run_until_complete(task) + except KeyboardInterrupt: + task.cancel() + loop.run_until_complete(task) + # avoid displaying exception-unhandled warnings + task.exception() + else: + restful_client = RESTfulClient(base_url=endpoint) + restful_model = restful_client.get_model(model_uid=model_uid) + if not isinstance( + restful_model, (RESTfulChatModelHandle, RESTfulGenerateModelHandle) + ): + raise ValueError(f"model {model_uid} has no generate method") + + while True: + prompt = input("User: ") + if prompt == "": + break + print(f"Assistant: {prompt}", end="", file=sys.stdout) + response = restful_model.generate( + prompt=prompt, + generate_config={"stream": stream, "max_tokens": max_tokens}, + ) + if not isinstance(response, dict): + raise ValueError("generate result is not valid") + print(f"{response['choices'][0]['text']}\n", file=sys.stdout) + + +@cli.command("chat") +@click.option( + "--endpoint", + "-e", + type=str, +) +@click.option("--model-uid", type=str) +@click.option("--max_tokens", default=256, type=int) +@click.option("--stream", default=True, type=bool) +def model_chat( + endpoint: Optional[str], + model_uid: str, + max_tokens: int, + stream: bool, +): + # TODO: chat model roles may not be user and assistant. + endpoint = get_endpoint(endpoint) + chat_history: "List[ChatCompletionMessage]" = [] + if stream: + # TODO: when stream=True, RestfulClient cannot generate words one by one. + # So use Client in temporary. The implementation needs to be changed to + # RestfulClient in the future. + async def chat_internal(): + while True: + # the prompt will be written to stdout. 
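+                # an empty line exits the interactive chat session.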
+ # https://docs.python.org/3.10/library/functions.html#input + prompt = input("User: ") + if prompt == "": + break + chat_history.append(ChatCompletionMessage(role="user", content=prompt)) + print("Assistant: ", end="", file=sys.stdout) + response_content = "" + async for chunk in model.chat( + prompt=prompt, + chat_history=chat_history, + generate_config={"stream": stream, "max_tokens": max_tokens}, + ): + delta = chunk["choices"][0]["delta"] + if "content" not in delta: + continue + else: + response_content += delta["content"] + print(delta["content"], end="", flush=True, file=sys.stdout) + print("\n", file=sys.stdout) + chat_history.append( + ChatCompletionMessage(role="assistant", content=response_content) + ) + + client = Client(endpoint=endpoint) + model = client.get_model(model_uid=model_uid) + + loop = asyncio.get_event_loop() + coro = chat_internal() + + if loop.is_running(): + isolation = Isolation(asyncio.new_event_loop(), threaded=True) + isolation.start() + isolation.call(coro) + else: + task = loop.create_task(coro) + try: + loop.run_until_complete(task) + except KeyboardInterrupt: + task.cancel() + loop.run_until_complete(task) + # avoid displaying exception-unhandled warnings + task.exception() + else: + restful_client = RESTfulClient(base_url=endpoint) + restful_model = restful_client.get_model(model_uid=model_uid) + if not isinstance( + restful_model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle) + ): + raise ValueError(f"model {model_uid} has no chat method") + + while True: + prompt = input("User: ") + if prompt == "": + break + chat_history.append(ChatCompletionMessage(role="user", content=prompt)) + print("Assistant: ", end="", file=sys.stdout) + response = restful_model.chat( + prompt=prompt, + chat_history=chat_history, + generate_config={"stream": stream, "max_tokens": max_tokens}, + ) + if not isinstance(response, dict): + raise ValueError("chat result is not valid") + response_content = response["choices"][0]["message"]["content"] + print(f"{response_content}\n", file=sys.stdout) + chat_history.append( + ChatCompletionMessage(role="assistant", content=response_content) + ) + + if __name__ == "__main__": cli() diff --git a/xinference/deploy/test/test_cmdline.py b/xinference/deploy/test/test_cmdline.py new file mode 100644 index 0000000000..8fb9d381c1 --- /dev/null +++ b/xinference/deploy/test/test_cmdline.py @@ -0,0 +1,229 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
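+
+# End-to-end tests for the new CLI subcommands: test_cmdline drives
+# list/generate/chat/terminate against a launched model, and
+# test_cmdline_of_custom_model drives register/registrations/unregister
+# with a custom model definition. The setup fixture provides the endpoint.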
+ +import os +import tempfile + +import pytest +from click.testing import CliRunner + +from ...client import Client +from ..cmdline import ( + list_model_registrations, + model_chat, + model_generate, + model_list, + model_terminate, + register_model, + unregister_model, +) + + +@pytest.mark.parametrize("stream", [True, False]) +def test_cmdline(setup, stream): + endpoint, _ = setup + runner = CliRunner() + + # launch model + """ + result = runner.invoke( + model_launch, + [ + "--endpoint", + endpoint, + "--model-name", + "orca", + "--size-in-billions", + 3, + "--model-format", + "ggmlv3", + "--quantization", + "q4_0", + ], + ) + assert result.exit_code == 0 + assert "Model uid: " in result.stdout + + model_uid = result.stdout.split("Model uid: ")[1].strip() + """ + # if use `model_launch` command to launch model, CI will fail. + # So use client to launch model in temporary + client = Client(endpoint) + model_uid = client.launch_model( + model_name="orca", model_size_in_billions=3, quantization="q4_0" + ) + assert len(model_uid) != 0 + + # list model + result = runner.invoke( + model_list, + [ + "--endpoint", + endpoint, + ], + ) + assert result.exit_code == 0 + assert model_uid in result.stdout + + # model generate + result = runner.invoke( + model_generate, + [ + "--endpoint", + endpoint, + "--model-uid", + model_uid, + "--stream", + stream, + ], + input="Once upon a time, there was a very old computer.\n\n", + ) + assert result.exit_code == 0 + assert len(result.stdout) != 0 + print(result.stdout) + + # model chat + result = runner.invoke( + model_chat, + [ + "--endpoint", + endpoint, + "--model-uid", + model_uid, + "--stream", + stream, + ], + input="Write a poem.\n\n", + ) + assert result.exit_code == 0 + assert len(result.stdout) != 0 + print(result.stdout) + + # terminate model + result = runner.invoke( + model_terminate, + [ + "--endpoint", + endpoint, + "--model-uid", + model_uid, + ], + ) + assert result.exit_code == 0 + + # list model again + result = runner.invoke( + model_list, + [ + "--endpoint", + endpoint, + ], + ) + assert result.exit_code == 0 + assert model_uid not in result.stdout + + +def test_cmdline_of_custom_model(setup): + endpoint, _ = setup + runner = CliRunner() + + # register custom model + custom_model_desc = """{ + "version": 1, + "model_name": "custom_model", + "model_lang": [ + "en", "zh" + ], + "model_ability": [ + "embed", + "chat" + ], + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": 7, + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "ziqingyang/chinese-alpaca-2-7b" + } + ], + "prompt_style": { + "style_name": "ADD_COLON_SINGLE", + "system_prompt": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.", + "roles": [ + "Instruction", + "Response" + ], + "intra_message_sep": "\\n\\n### " + } +}""" + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + temp_filename = temp_file.name + temp_file.write(custom_model_desc.encode("utf-8")) + result = runner.invoke( + register_model, + [ + "--endpoint", + endpoint, + "--model-type", + "LLM", + "--file", + temp_filename, + ], + ) + assert result.exit_code == 0 + os.unlink(temp_filename) + + # list model registrations + result = runner.invoke( + list_model_registrations, + [ + "--endpoint", + endpoint, + "--model-type", + "LLM", + ], + ) + assert result.exit_code == 0 + assert "custom_model" in result.stdout + + # unregister custom model + result = runner.invoke( + unregister_model, + [ + "--endpoint", + endpoint, + "--model-type", + "LLM", + "--model-name", + "custom_model", + ], + ) + assert result.exit_code == 0 + + # list model registrations again + result = runner.invoke( + list_model_registrations, + [ + "--endpoint", + endpoint, + "--model-type", + "LLM", + ], + ) + assert result.exit_code == 0 + assert "custom_model" not in result.stdout