Skip to content

Commit

Permalink
Add .token_count() for estimating input tokens (#23)
Browse files Browse the repository at this point in the history
* Improvements to token usage reporting

* Update changelog

* Clean up docstring

* Make token_usage() a method not a property

Just in case we want parameters

* Fix imports

* Rollback breaking changes

* Cleanup

* Doc improvements

* Add .token_count_async(); require the whole data_model

* Slightly more accurate/conservative token count for OpenAI

* Add tests

* Add note

* Tweak changelog

* Tweak docstring
  • Loading branch information
cpsievert authored Dec 19, 2024
1 parent e033684 commit 200e26c
Show file tree
Hide file tree
Showing 8 changed files with 279 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### New features

* `Chat`'s `.tokens()` method gains a `values` argument. Set it to `"discrete"` to get a result that can be summed to determine the token cost of submitting the current turns. The default (`"cumulative"`), remains the same (the result can be summed to determine the overall token cost of the conversation).
* `Chat` gains a `.token_count()` method to help estimate token cost of new input. (#23)

### Bug fixes

* `ChatOllama` no longer fails when an `OPENAI_API_KEY` environment variable is not set.
* `ChatOpenAI` now correctly includes the relevant `detail` on `ContentImageRemote()` input.
* `ChatGoogle` now correctly logs its `token_usage()`. (#23)


## [0.2.0] - 2024-12-11
Expand Down
55 changes: 54 additions & 1 deletion chatlas/_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ._provider import Provider
from ._tokens import tokens_log
from ._tools import Tool, basemodel_to_param_schema
from ._turn import Turn, normalize_turns
from ._turn import Turn, normalize_turns, user_turn

if TYPE_CHECKING:
from anthropic.types import (
Expand Down Expand Up @@ -380,6 +380,59 @@ async def stream_turn_async(self, completion, has_data_model, stream) -> Turn:
def value_turn(self, completion, has_data_model) -> Turn:
return self._as_turn(completion, has_data_model)

def token_count(
    self,
    *args: Content | str,
    tools: dict[str, Tool],
    data_model: Optional[type[BaseModel]],
) -> int:
    """Count input tokens for *args* via Anthropic's count_tokens endpoint."""
    count_kwargs = self._token_count_args(*args, tools=tools, data_model=data_model)
    result = self._client.messages.count_tokens(**count_kwargs)
    return result.input_tokens

async def token_count_async(
    self,
    *args: Content | str,
    tools: dict[str, Tool],
    data_model: Optional[type[BaseModel]],
) -> int:
    """Async version of `token_count()`; uses the async Anthropic client."""
    count_kwargs = self._token_count_args(*args, tools=tools, data_model=data_model)
    result = await self._async_client.messages.count_tokens(**count_kwargs)
    return result.input_tokens

def _token_count_args(
    self,
    *args: Content | str,
    tools: dict[str, Tool],
    data_model: Optional[type[BaseModel]],
) -> dict[str, Any]:
    """Derive kwargs for the count_tokens endpoint from user-supplied input."""
    # Reuse the normal chat-request construction so the token count reflects
    # exactly what a real request would send.
    perform_kwargs = self._chat_perform_args(
        stream=False,
        turns=[user_turn(*args)],
        tools=tools,
        data_model=data_model,
    )

    # count_tokens accepts only a subset of the chat-completion parameters.
    relevant = ("messages", "model", "system", "tools", "tool_choice")
    return {key: perform_kwargs[key] for key in relevant if key in perform_kwargs}

def _as_message_params(self, turns: list[Turn]) -> list["MessageParam"]:
messages: list["MessageParam"] = []
for turn in turns:
Expand Down
87 changes: 87 additions & 0 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,93 @@ def tokens(

return res

def token_count(
    self,
    *args: Content | str,
    data_model: Optional[type[BaseModel]] = None,
) -> int:
    """
    Get an estimated token count for the given input.

    Estimate the token size of input content. This can help determine whether
    input(s) and/or conversation history (i.e., `.get_turns()`) should be
    reduced in size before sending it to the model.

    Parameters
    ----------
    args
        The input to get a token count for.
    data_model
        If the input is meant for data extraction (i.e., `.extract_data()`), then
        this should be the Pydantic model that describes the structure of the data to
        extract.

    Returns
    -------
    int
        The token count for the input.

    Note
    ----
    Remember that the token count is an estimate. Also, models based on
    `ChatOpenAI()` currently do not take tools into account when
    estimating token counts.

    Examples
    --------
    ```python
    from chatlas import ChatAnthropic

    chat = ChatAnthropic()
    # Estimate the token count before sending the input
    print(chat.token_count("What is 2 + 2?"))

    # Once input is sent, you can get the actual input and output
    # token counts from the chat object
    chat.chat("What is 2 + 2?", echo="none")
    print(chat.token_usage())
    ```
    """

    # Tools registered on this chat are always included in the estimate.
    return self.provider.token_count(
        *args,
        tools=self._tools,
        data_model=data_model,
    )

async def token_count_async(
    self,
    *args: Content | str,
    data_model: Optional[type[BaseModel]] = None,
) -> int:
    """
    Get an estimated token count for the given input asynchronously.

    Estimate the token size of input content. This can help determine whether
    input(s) and/or conversation history (i.e., `.get_turns()`) should be
    reduced in size before sending it to the model.

    Parameters
    ----------
    args
        The input to get a token count for.
    data_model
        If this input is meant for data extraction (i.e., `.extract_data_async()`),
        then this should be the Pydantic model that describes the structure of the data
        to extract.

    Returns
    -------
    int
        The token count for the input.
    """

    # Tools registered on this chat are always included in the estimate.
    return await self.provider.token_count_async(
        *args,
        tools=self._tools,
        data_model=data_model,
    )

def app(
self,
*,
Expand Down
54 changes: 53 additions & 1 deletion chatlas/_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
)
from ._logging import log_model_default
from ._provider import Provider
from ._tokens import tokens_log
from ._tools import Tool, basemodel_to_param_schema
from ._turn import Turn, normalize_turns
from ._turn import Turn, normalize_turns, user_turn

if TYPE_CHECKING:
from google.generativeai.types.content_types import (
Expand Down Expand Up @@ -332,6 +333,55 @@ async def stream_turn_async(
def value_turn(self, completion, has_data_model) -> Turn:
return self._as_turn(completion, has_data_model)

def token_count(
    self,
    *args: Content | str,
    tools: dict[str, Tool],
    data_model: Optional[type[BaseModel]],
) -> int:
    # -> int added for consistency with the Provider ABC and other providers.
    """Count input tokens via Google's count_tokens endpoint."""
    kwargs = self._token_count_args(
        *args,
        tools=tools,
        data_model=data_model,
    )

    res = self._client.count_tokens(**kwargs)
    return res.total_tokens

async def token_count_async(
    self,
    *args: Content | str,
    tools: dict[str, Tool],
    data_model: Optional[type[BaseModel]],
) -> int:
    # -> int added for consistency with the Provider ABC and other providers.
    """Async version of `token_count()`; uses the async count_tokens call."""
    kwargs = self._token_count_args(
        *args,
        tools=tools,
        data_model=data_model,
    )

    res = await self._client.count_tokens_async(**kwargs)
    return res.total_tokens

def _token_count_args(
    self,
    *args: Content | str,
    tools: dict[str, Tool],
    data_model: Optional[type[BaseModel]],
) -> dict[str, Any]:
    """Derive kwargs for count_tokens from user-supplied input."""
    # Build the same kwargs a real chat request would use, then keep only
    # the parameters count_tokens understands.
    perform_kwargs = self._chat_perform_args(
        stream=False,
        turns=[user_turn(*args)],
        tools=tools,
        data_model=data_model,
    )

    relevant = ("contents", "tools")
    return {key: perform_kwargs[key] for key in relevant if key in perform_kwargs}

def _google_contents(self, turns: list[Turn]) -> list["ContentDict"]:
contents: list["ContentDict"] = []
for turn in turns:
Expand Down Expand Up @@ -421,6 +471,8 @@ def _as_turn(
usage.candidates_token_count,
)

tokens_log(self, tokens)

finish = message.candidates[0].finish_reason

return Turn(
Expand Down
54 changes: 53 additions & 1 deletion chatlas/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ._chat import Chat
from ._content import (
Content,
ContentImage,
ContentImageInline,
ContentImageRemote,
ContentJson,
Expand All @@ -20,7 +21,7 @@
from ._provider import Provider
from ._tokens import tokens_log
from ._tools import Tool, basemodel_to_param_schema
from ._turn import Turn, normalize_turns
from ._turn import Turn, normalize_turns, user_turn
from ._utils import MISSING, MISSING_TYPE, is_testing

if TYPE_CHECKING:
Expand Down Expand Up @@ -351,6 +352,57 @@ async def stream_turn_async(self, completion, has_data_model, stream):
def value_turn(self, completion, has_data_model) -> Turn:
return self._as_turn(completion, has_data_model)

def token_count(
    self,
    *args: Content | str,
    tools: dict[str, Tool],
    data_model: Optional[type[BaseModel]],
) -> int:
    """
    Estimate input tokens locally using tiktoken.

    Text-like content is encoded from the actual message param that would be
    sent; images use a fixed per-image estimate. Note: tools and data_model
    are currently not factored into the estimate.
    """
    try:
        import tiktoken
    except ImportError:
        raise ImportError(
            "The tiktoken package is required for token counting. "
            "Please install it with `pip install tiktoken`."
        )

    enc = tiktoken.encoding_for_model(self._model)

    turn = user_turn(*args)
    images = [c for c in turn.contents if isinstance(c, ContentImage)]
    others = [c for c in turn.contents if not isinstance(c, ContentImage)]

    # Images get a fixed (conservative) per-image token estimate.
    image_tokens = sum(self._image_token_count(img) for img in images)

    # For everything else, encode the actual message param we'd send.
    message_param = self._as_message_param([Turn("user", others)])
    text_tokens = len(enc.encode(str(message_param)))

    return text_tokens + image_tokens

async def token_count_async(
    self,
    *args: Content | str,
    tools: dict[str, Tool],
    data_model: Optional[type[BaseModel]],
) -> int:
    """Async variant; counting is purely local, so delegate to the sync version."""
    return self.token_count(*args, tools=tools, data_model=data_model)

@staticmethod
def _image_token_count(image: ContentImage) -> int:
    """Fixed token estimate for one image (exact for low detail, else an upper bound)."""
    is_low_detail = isinstance(image, ContentImageRemote) and image.detail == "low"
    if is_low_detail:
        return 85
    # Maximum possible image cost: the highest scaled resolution is
    # 768 x 2048, which fits 8 tiles of 512px, at 170 tokens per tile
    # plus a flat 85.
    # TODO: this is obviously a very conservative estimate and could be improved
    # https://platform.openai.com/docs/guides/vision/calculating-costs
    return 170 * 8 + 85

@staticmethod
def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
from openai.types.chat import (
Expand Down
17 changes: 17 additions & 0 deletions chatlas/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from pydantic import BaseModel

from ._content import Content
from ._tools import Tool
from ._turn import Turn

Expand Down Expand Up @@ -141,3 +142,19 @@ def value_turn(
completion: ChatCompletionT,
has_data_model: bool,
) -> Turn: ...

# Estimate the number of input tokens *args would consume, given the
# registered tools and an optional data-extraction model.
@abstractmethod
def token_count(
    self,
    *args: Content | str,
    tools: dict[str, Tool],
    data_model: Optional[type[BaseModel]],
) -> int: ...

# Async counterpart of token_count(); same contract.
@abstractmethod
async def token_count_async(
    self,
    *args: Content | str,
    tools: dict[str, Tool],
    data_model: Optional[type[BaseModel]],
) -> int: ...
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dev = [
"anthropic[bedrock]",
"google-generativeai>=0.8.3",
"numpy>1.24.4",
"tiktoken",
]
docs = [
"griffe>=1",
Expand Down
14 changes: 12 additions & 2 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from chatlas import ChatOpenAI, Turn
from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAI, Turn
from chatlas._openai import OpenAIAzureProvider, OpenAIProvider
from chatlas._tokens import token_usage, tokens_log, tokens_reset

Expand Down Expand Up @@ -26,10 +26,20 @@ def test_tokens_method():
)

assert chat.tokens(values="discrete") == [2, 10, 2, 10]

assert chat.tokens(values="cumulative") == [None, (2, 10), None, (14, 10)]


def test_token_count_method():
    # NOTE(review): the OpenAI count is computed locally (tiktoken), but the
    # Anthropic and Google counts appear to go through provider endpoints —
    # presumably this test needs API credentials/network; confirm in CI setup.
    # The exact expected values are tokenizer/model-version dependent and may
    # drift when providers update their tokenizers.
    chat = ChatOpenAI(model="gpt-4o-mini")
    assert chat.token_count("What is 1 + 1?") == 31

    chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
    assert chat.token_count("What is 1 + 1?") == 16

    chat = ChatGoogle(model="gemini-1.5-flash")
    assert chat.token_count("What is 1 + 1?") == 9


def test_usage_is_none():
tokens_reset()
assert token_usage() is None
Expand Down

0 comments on commit 200e26c

Please sign in to comment.