-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(model): Proxy model support count token
- Loading branch information
Showing
16 changed files
with
365 additions
and
247 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,36 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Union, List, Optional, TYPE_CHECKING | ||
import logging | ||
from dbgpt.model.parameter import ProxyModelParameters | ||
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper | ||
|
||
if TYPE_CHECKING: | ||
from dbgpt.core.interface.message import ModelMessage, BaseMessage | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ProxyModel: | ||
def __init__(self, model_params: ProxyModelParameters) -> None: | ||
self._model_params = model_params | ||
self._tokenizer = ProxyTokenizerWrapper() | ||
|
||
def get_params(self) -> ProxyModelParameters: | ||
return self._model_params | ||
|
||
def count_token( | ||
self, | ||
messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]], | ||
model_name: Optional[int] = None, | ||
) -> int: | ||
"""Count token of given messages | ||
Args: | ||
messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token | ||
model_name (Optional[int], optional): model name. Defaults to None. | ||
Returns: | ||
int: token count, -1 if failed | ||
""" | ||
return self._tokenizer.count_token(messages, model_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Union, List, Optional, TYPE_CHECKING | ||
import logging | ||
|
||
if TYPE_CHECKING: | ||
from dbgpt.core.interface.message import ModelMessage, BaseMessage | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ProxyTokenizerWrapper: | ||
def __init__(self) -> None: | ||
self._support_encoding = True | ||
self._encoding_model = None | ||
|
||
def count_token( | ||
self, | ||
messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]], | ||
model_name: Optional[str] = None, | ||
) -> int: | ||
"""Count token of given messages | ||
Args: | ||
messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token | ||
model_name (Optional[str], optional): model name. Defaults to None. | ||
Returns: | ||
int: token count, -1 if failed | ||
""" | ||
if not self._support_encoding: | ||
logger.warning( | ||
"model does not support encoding model, can't count token, returning -1" | ||
) | ||
return -1 | ||
encoding = self._get_or_create_encoding_model(model_name) | ||
cnt = 0 | ||
if isinstance(messages, str): | ||
cnt = len(encoding.encode(messages, disallowed_special=())) | ||
elif isinstance(messages, BaseMessage): | ||
cnt = len(encoding.encode(messages.content, disallowed_special=())) | ||
elif isinstance(messages, ModelMessage): | ||
cnt = len(encoding.encode(messages.content, disallowed_special=())) | ||
elif isinstance(messages, list): | ||
for message in messages: | ||
cnt += len(encoding.encode(message.content, disallowed_special=())) | ||
else: | ||
logger.warning( | ||
"unsupported type of messages, can't count token, returning -1" | ||
) | ||
return -1 | ||
return cnt | ||
|
||
def _get_or_create_encoding_model(self, model_name: Optional[str] = None): | ||
"""Get or create encoding model for given model name | ||
More detail see: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | ||
""" | ||
if self._encoding_model: | ||
return self._encoding_model | ||
try: | ||
import tiktoken | ||
|
||
logger.info( | ||
"tiktoken installed, using it to count tokens, tiktoken will download tokenizer from network, " | ||
"also you can download it and put it in the directory of environment variable TIKTOKEN_CACHE_DIR" | ||
) | ||
except ImportError: | ||
self._support_encoding = False | ||
logger.warn("tiktoken not installed, cannot count tokens, returning -1") | ||
return -1 | ||
try: | ||
if not model_name: | ||
model_name = "gpt-3.5-turbo" | ||
self._encoding_model = tiktoken.model.encoding_for_model(model_name) | ||
except KeyError: | ||
logger.warning( | ||
f"{model_name}'s tokenizer not found, using cl100k_base encoding." | ||
) | ||
self._encoding_model = tiktoken.get_encoding("cl100k_base") | ||
return self._encoding_model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.