Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(model): Proxy model support count token #996

Merged
merged 1 commit into from
Dec 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions dbgpt/model/cluster/worker/default_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def generate(self, params: Dict) -> ModelOutput:
return output

def count_token(self, prompt: str) -> int:
return _try_to_count_token(prompt, self.tokenizer)
return _try_to_count_token(prompt, self.tokenizer, self.model)

async def async_count_token(self, prompt: str) -> int:
# TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async
Expand Down Expand Up @@ -454,19 +454,25 @@ def _new_metrics_from_model_output(
return metrics


def _try_to_count_token(prompt: str, tokenizer) -> int:
def _try_to_count_token(prompt: str, tokenizer, model) -> int:
"""Try to count token of prompt

Args:
prompt (str): prompt
tokenizer ([type]): tokenizer
model ([type]): model

Returns:
int: token count, if error return -1

TODO: More implementation
"""
try:
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

if isinstance(model, ProxyModel):
return model.count_token(prompt)
# Only support huggingface model now
return len(tokenizer(prompt).input_ids[0])
except Exception as e:
logger.warning(f"Count token error, detail: {e}, return -1")
Expand Down
4 changes: 2 additions & 2 deletions dbgpt/model/cluster/worker/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def add_worker(
return True
else:
# TODO Update worker
logger.warn(f"Instance {worker_key} exist")
logger.warning(f"Instance {worker_key} exist")
return False

def _remove_worker(self, worker_params: ModelWorkerParameters) -> None:
Expand Down Expand Up @@ -229,7 +229,7 @@ async def model_startup(self, startup_req: WorkerStartupRequest):
)
if not success:
msg = f"Add worker {model_name}@{worker_type}, worker instances is exist"
logger.warn(f"{msg}, worker_params: {worker_params}")
logger.warning(f"{msg}, worker_params: {worker_params}")
self._remove_worker(worker_params)
raise Exception(msg)
supported_types = WorkerType.values()
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/model/proxy/llms/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ def _initialize_openai_v1(params: ProxyModelParameters):


def __convert_2_gpt_messages(messages: List[ModelMessage]):
chat_round = 0
gpt_messages = []
last_usr_message = ""
system_messages = []

# TODO: We can't change message order in low level
for message in messages:
if message.role == ModelMessageRoleType.HUMAN or message.role == "user":
last_usr_message = message.content
Expand Down
27 changes: 27 additions & 0 deletions dbgpt/model/proxy/llms/proxy_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,36 @@
from __future__ import annotations

from typing import Union, List, Optional, TYPE_CHECKING
import logging
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper

if TYPE_CHECKING:
from dbgpt.core.interface.message import ModelMessage, BaseMessage

logger = logging.getLogger(__name__)


class ProxyModel:
def __init__(self, model_params: ProxyModelParameters) -> None:
self._model_params = model_params
self._tokenizer = ProxyTokenizerWrapper()

def get_params(self) -> ProxyModelParameters:
return self._model_params

def count_token(
self,
messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
model_name: Optional[int] = None,
) -> int:
"""Count token of given messages
Args:
messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token
model_name (Optional[int], optional): model name. Defaults to None.
Returns:
int: token count, -1 if failed
"""
return self._tokenizer.count_token(messages, model_name)
9 changes: 6 additions & 3 deletions dbgpt/model/utils/chatgpt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt._private.pydantic import model_to_json
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper

if TYPE_CHECKING:
import httpx
Expand Down Expand Up @@ -152,6 +153,7 @@ def __init__(
self._context_length = context_length
self._client = openai_client
self._openai_kwargs = openai_kwargs or {}
self._tokenizer = ProxyTokenizerWrapper()

@property
def client(self) -> ClientType:
Expand Down Expand Up @@ -238,10 +240,11 @@ async def get_context_length(self) -> int:
async def count_token(self, model: str, prompt: str) -> int:
"""Count the number of tokens in a given prompt.
TODO: Get the real number of tokens from the openai api or tiktoken package
Args:
model (str): The model name.
prompt (str): The prompt.
"""

raise NotImplementedError()
return self._tokenizer.count_token(prompt, model)


class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
Expand Down
80 changes: 80 additions & 0 deletions dbgpt/model/utils/token_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

from typing import Union, List, Optional, TYPE_CHECKING
import logging

if TYPE_CHECKING:
from dbgpt.core.interface.message import ModelMessage, BaseMessage

logger = logging.getLogger(__name__)


class ProxyTokenizerWrapper:
def __init__(self) -> None:
self._support_encoding = True
self._encoding_model = None

def count_token(
self,
messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
model_name: Optional[str] = None,
) -> int:
"""Count token of given messages
Args:
messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token
model_name (Optional[str], optional): model name. Defaults to None.
Returns:
int: token count, -1 if failed
"""
if not self._support_encoding:
logger.warning(
"model does not support encoding model, can't count token, returning -1"
)
return -1
encoding = self._get_or_create_encoding_model(model_name)
cnt = 0
if isinstance(messages, str):
cnt = len(encoding.encode(messages, disallowed_special=()))
elif isinstance(messages, BaseMessage):
cnt = len(encoding.encode(messages.content, disallowed_special=()))
elif isinstance(messages, ModelMessage):
cnt = len(encoding.encode(messages.content, disallowed_special=()))
elif isinstance(messages, list):
for message in messages:
cnt += len(encoding.encode(message.content, disallowed_special=()))
else:
logger.warning(
"unsupported type of messages, can't count token, returning -1"
)
return -1
return cnt

def _get_or_create_encoding_model(self, model_name: Optional[str] = None):
"""Get or create encoding model for given model name
More detail see: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
"""
if self._encoding_model:
return self._encoding_model
try:
import tiktoken

logger.info(
"tiktoken installed, using it to count tokens, tiktoken will download tokenizer from network, "
"also you can download it and put it in the directory of environment variable TIKTOKEN_CACHE_DIR"
)
except ImportError:
self._support_encoding = False
logger.warn("tiktoken not installed, cannot count tokens, returning -1")
return -1
try:
if not model_name:
model_name = "gpt-3.5-turbo"
self._encoding_model = tiktoken.model.encoding_for_model(model_name)
except KeyError:
logger.warning(
f"{model_name}'s tokenizer not found, using cl100k_base encoding."
)
self._encoding_model = tiktoken.get_encoding("cl100k_base")
return self._encoding_model
17 changes: 4 additions & 13 deletions dbgpt/serve/conversation/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

import pytest

from dbgpt.storage.metadata import db
Expand Down Expand Up @@ -39,11 +37,9 @@ def test_table_exist():


def test_entity_create(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
# TODO: implement your test case
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
entity = ServeEntity(**default_entity_dict)
session.add(entity)


def test_entity_unique_key(default_entity_dict):
Expand All @@ -52,10 +48,8 @@ def test_entity_unique_key(default_entity_dict):


def test_entity_get(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
# TODO: implement your test case
pass


def test_entity_update(default_entity_dict):
Expand All @@ -65,10 +59,7 @@ def test_entity_update(default_entity_dict):

def test_entity_delete(default_entity_dict):
# TODO: implement your test case
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
entity.delete()
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity is None
pass


def test_entity_all():
Expand Down
Loading