Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chat.tokens() gains a values argument #27

Merged
merged 5 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### New features

* `Chat`'s `.tokens()` method gains a `values` argument. Set it to `"discrete"` to get a result that can be summed to determine the token cost of submitting the current turns. The default (`"cumulative"`), remains the same (the result can be summed to determine the overall token cost of the conversation).

### Bug fixes

* `ChatOllama` no longer fails when a `OPENAI_API_KEY` environment variable is not set.
Expand Down
116 changes: 111 additions & 5 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Optional,
Sequence,
TypeVar,
overload,
)

from pydantic import BaseModel
Expand Down Expand Up @@ -176,17 +177,122 @@ def system_prompt(self, value: str | None):
if value is not None:
self._turns.insert(0, Turn("system", value))

def tokens(self) -> list[tuple[int, int] | None]:
@overload
def tokens(self) -> list[tuple[int, int] | None]: ...

@overload
def tokens(
self,
values: Literal["cumulative"],
) -> list[tuple[int, int] | None]: ...

@overload
def tokens(
self,
values: Literal["discrete"],
) -> list[int]: ...

def tokens(
self,
values: Literal["cumulative", "discrete"] = "discrete",
) -> list[int] | list[tuple[int, int] | None]:
"""
Get the tokens for each turn in the chat.

Parameters
----------
values
If "cumulative" (the default), the result can be summed to get the
chat's overall token usage (helpful for computing overall cost of
the chat). If "discrete", the result can be summed to get the number of
tokens the turns will cost to generate the next response (helpful
for estimating cost of the next response, or for determining if you
are about to exceed the token limit).

Returns
-------
list[tuple[int, int] | None]
A list of tuples, where each tuple contains the start and end token
indices for a turn.
list[int]
A list of token counts for each (non-system) turn in the chat. The
1st turn includes the tokens count for the system prompt (if any).

Raises
------
ValueError
If the chat's turns (i.e., `.get_turns()`) are not in an expected
format. This may happen if the chat history is manually set (i.e.,
`.set_turns()`). In this case, you can inspect the "raw" token
values via the `.get_turns()` method (each turn has a `.tokens`
attribute).
"""
return [turn.tokens for turn in self._turns]

turns = self.get_turns(include_system_prompt=False)

if values == "cumulative":
return [turn.tokens for turn in turns]

if len(turns) == 0:
return []

err_info = (
"This can happen if the chat history is manually set (i.e., `.set_turns()`). "
"Consider getting the 'raw' token values via the `.get_turns()` method "
"(each turn has a `.tokens` attribute)."
)

# Sanity checks for the assumptions made to figure out user token counts
if len(turns) == 1:
raise ValueError(
"Expected at least two turns in the chat history. " + err_info
)

if len(turns) % 2 != 0:
raise ValueError(
"Expected an even number of turns in the chat history. " + err_info
)

if turns[0].role != "user":
raise ValueError(
"Expected the 1st non-system turn to have role='user'. " + err_info
)

if turns[1].role != "assistant":
raise ValueError(
"Expected the 2nd turn non-system to have role='assistant'. " + err_info
)

if turns[1].tokens is None:
raise ValueError(
"Expected the 1st assistant turn to contain token counts. " + err_info
)

res: list[int] = [
# Implied token count for the 1st user input
turns[1].tokens[0],
# The token count for the 1st assistant response
turns[1].tokens[1],
]
for i in range(1, len(turns) - 1, 2):
ti = turns[i]
tj = turns[i + 2]
if ti.role != "assistant" or tj.role != "assistant":
raise ValueError(
"Expected even turns to have role='assistant'." + err_info
)
if ti.tokens is None or tj.tokens is None:
raise ValueError(
"Expected role='assistant' turns to contain token counts."
+ err_info
)
res.extend(
[
# Implied token count for the user input
tj.tokens[0] - sum(ti.tokens),
# The token count for the assistant response
tj.tokens[1],
]
)

return res

def app(
self,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_provider_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_openai_simple_request():
chat.chat("What is 1 + 1?")
turn = chat.get_last_turn()
assert turn is not None
assert turn.tokens == (27, 1)
assert turn.tokens == (27, 2)
assert turn.finish_reason == "stop"


Expand Down
28 changes: 28 additions & 0 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,35 @@
from chatlas import ChatOpenAI, Turn
from chatlas._openai import OpenAIAzureProvider, OpenAIProvider
from chatlas._tokens import token_usage, tokens_log, tokens_reset


def test_tokens_method():
chat = ChatOpenAI()
assert chat.tokens(values="discrete") == []

chat = ChatOpenAI(
turns=[
Turn(role="user", contents="Hi"),
Turn(role="assistant", contents="Hello", tokens=(2, 10)),
]
)

assert chat.tokens(values="discrete") == [2, 10]

chat = ChatOpenAI(
turns=[
Turn(role="user", contents="Hi"),
Turn(role="assistant", contents="Hello", tokens=(2, 10)),
Turn(role="user", contents="Hi"),
Turn(role="assistant", contents="Hello", tokens=(14, 10)),
]
)

assert chat.tokens(values="discrete") == [2, 10, 2, 10]

assert chat.tokens(values="cumulative") == [None, (2, 10), None, (14, 10)]


def test_usage_is_none():
tokens_reset()
assert token_usage() is None
Expand Down
Loading