Prompt caching #5603

Draft
wants to merge 13 commits into base: main
9 changes: 9 additions & 0 deletions config.template.toml
@@ -110,6 +110,15 @@ api_key = "your-api-key"
 # Cost per output token
 #output_cost_per_token = 0.0

+# Discount to apply for cache hits (e.g., 0.8 for an 80% discount)
+# Applied to input_cost_per_token; ignored if input_cost_per_token is not set.
+#cache_hit_discount = 0.8
+
+# Premium to apply for cache writes (e.g., 0.2 for a 20% premium)
+# Applied to input_cost_per_token; ignored if input_cost_per_token is not set.
+# Only used for Anthropic models.
+#cache_write_premium = 0.2
+
 # Custom LLM provider
 #custom_llm_provider = ""

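For reference, a minimal standalone sketch (not part of the PR) of what these two options mean for the effective per-token price when input_cost_per_token is set; the prices below are hypothetical and simply mirror the comments above:

# Hypothetical pricing, following the option descriptions in config.template.toml above.
input_cost_per_token = 0.000003   # $3 per 1M fresh input tokens
cache_hit_discount = 0.8          # cache-hit tokens cost 80% less than fresh input tokens
cache_write_premium = 0.2         # cache-write tokens cost 20% more (Anthropic only)

effective_cache_hit_price = input_cost_per_token * (1 - cache_hit_discount)
effective_cache_write_price = input_cost_per_token * (1 + cache_write_premium)

print(f'cache hit token:   ${effective_cache_hit_price:.9f}')    # $0.000000600
print(f'cache write token: ${effective_cache_write_price:.9f}')  # $0.000003600
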
3 changes: 1 addition & 2 deletions openhands/agenthub/codeact_agent/codeact_agent.py
@@ -417,8 +417,7 @@ def _get_messages(self, state: State) -> list[Message]:
             messages.append(
                 Message(
                     role='user',
-                    content=[TextContent(text=example_message)],
-                    cache_prompt=self.llm.is_caching_prompt_active(),
+                    content=[TextContent(text=example_message)]
                 )
             )

4 changes: 4 additions & 0 deletions openhands/core/config/llm_config.py
@@ -36,6 +36,8 @@ class LLMConfig:
         max_output_tokens: The maximum number of output tokens. This is sent to the LLM.
         input_cost_per_token: The cost per input token. This will be available in logs for the user to check.
         output_cost_per_token: The cost per output token. This will be available in logs for the user to check.
+        cache_hit_discount: The discount to apply for cache hits (e.g., 0.8 for an 80% discount).
+        cache_write_premium: The premium to apply for cache writes (e.g., 0.2 for a 20% premium).
         ollama_base_url: The base URL for the OLLAMA API.
         drop_params: Drop any unmapped (unsupported) params without causing an exception.
         disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction).
@@ -70,6 +72,8 @@ class LLMConfig:
     max_output_tokens: int | None = None
     input_cost_per_token: float | None = None
     output_cost_per_token: float | None = None
+    cache_hit_discount: float = 1  # Discount for cache hits (e.g., 0.8 for 80% discount)
+    cache_write_premium: float = 1  # Premium for cache writes (e.g., 0.2 for 20% premium)
     ollama_base_url: str | None = None
     drop_params: bool = True
     disable_vision: bool | None = None
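As a quick illustration of the defaults above, a standalone sketch (not the PR's code) that reuses the (1 - discount) / (1 + premium) formulas from openhands/llm/llm.py further down, with a hypothetical price:

# With the default of 1 for both fields, the llm.py formulas below price
# cache-hit tokens at zero and cache-write tokens at double the input price
# (the write premium only applies to Anthropic models).
input_cost_per_token = 0.000003
cache_hit_discount = 1.0    # default
cache_write_premium = 1.0   # default

print(input_cost_per_token * (1 - cache_hit_discount))   # 0.0   -> cache hits are free
print(input_cost_per_token * (1 + cache_write_premium))  # 6e-06 -> cache writes cost 2x
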
135 changes: 93 additions & 42 deletions openhands/llm/llm.py
@@ -439,38 +439,34 @@ def _post_completion(self, response: ModelResponse) -> float:
         usage: Usage | None = response.get('usage')

         if usage:
-            # keep track of the input and output tokens
-            input_tokens = usage.get('prompt_tokens')
+            # Get total input tokens and output tokens
+            total_input_tokens = usage.get('prompt_tokens')
             output_tokens = usage.get('completion_tokens')

-            if input_tokens:
-                stats += 'Input tokens: ' + str(input_tokens)
-
-            if output_tokens:
-                stats += (
-                    (' | ' if input_tokens else '')
-                    + 'Output tokens: '
-                    + str(output_tokens)
-                    + '\n'
-                )
-
-            # read the prompt cache hit, if any
+            # Get cache hits and writes
             prompt_tokens_details: PromptTokensDetails = usage.get(
                 'prompt_tokens_details'
             )
-            cache_hit_tokens = (
-                prompt_tokens_details.cached_tokens if prompt_tokens_details else None
+            cache_hits = (
+                prompt_tokens_details.cached_tokens if prompt_tokens_details else 0
             )
-            if cache_hit_tokens:
-                stats += 'Input tokens (cache hit): ' + str(cache_hit_tokens) + '\n'

             # For Anthropic, the cache writes have a different cost than regular input tokens
             # but litellm doesn't separate them in the usage stats
             # so we can read it from the provider-specific extra field
             model_extra = usage.get('model_extra', {})
-            cache_write_tokens = model_extra.get('cache_creation_input_tokens')
-            if cache_write_tokens:
-                stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'
+            cache_writes = model_extra.get('cache_creation_input_tokens', 0)

+            # Calculate actual input tokens (excluding cache hits/writes)
+            input_tokens = (
+                total_input_tokens - (cache_hits + cache_writes)
+                if total_input_tokens
+                else 0
+            )
+
+            # Format stats
+            if total_input_tokens:
+                stats += f'Input tokens: {total_input_tokens}\n'
+            if cache_hits or cache_writes:
+                stats += f'Cache hits: {cache_hits} - cache writes: {cache_writes} - input: {input_tokens}\n'
+            if output_tokens:
+                stats += f'Output tokens: {output_tokens}\n'

         # log the stats
         if stats:
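To make the accounting above concrete, here is a standalone sketch with hypothetical usage numbers; it mirrors how _post_completion splits usage.prompt_tokens into cache hits (prompt_tokens_details.cached_tokens), Anthropic cache writes (cache_creation_input_tokens), and fresh input:

# Hypothetical usage numbers, following the token split in _post_completion above.
total_input_tokens = 12000  # usage.prompt_tokens
cache_hits = 8000           # usage.prompt_tokens_details.cached_tokens
cache_writes = 3000         # usage.model_extra['cache_creation_input_tokens'] (Anthropic)
output_tokens = 500         # usage.completion_tokens

# Fresh input tokens, i.e. prompt tokens that were neither cache hits nor cache writes.
input_tokens = (
    total_input_tokens - (cache_hits + cache_writes) if total_input_tokens else 0
)

stats = ''
if total_input_tokens:
    stats += f'Input tokens: {total_input_tokens}\n'
if cache_hits or cache_writes:
    stats += f'Cache hits: {cache_hits} - cache writes: {cache_writes} - input: {input_tokens}\n'
if output_tokens:
    stats += f'Output tokens: {output_tokens}\n'
print(stats)
# Input tokens: 12000
# Cache hits: 8000 - cache writes: 3000 - input: 1000
# Output tokens: 500
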
@@ -521,27 +517,82 @@ def _completion_cost(self, response) -> float:
         if not self.cost_metric_supported:
             return 0.0

-        extra_kwargs = {}
-        if (
-            self.config.input_cost_per_token is not None
-            and self.config.output_cost_per_token is not None
-        ):
-            cost_per_token = CostPerToken(
-                input_cost_per_token=self.config.input_cost_per_token,
-                output_cost_per_token=self.config.output_cost_per_token,
-            )
-            logger.debug(f'Using custom cost per token: {cost_per_token}')
-            extra_kwargs['custom_cost_per_token'] = cost_per_token
-
         try:
             # try directly get response_cost from response
             cost = getattr(response, '_hidden_params', {}).get('response_cost', None)
-            if cost is None:
-                cost = litellm_completion_cost(
-                    completion_response=response, **extra_kwargs
-                )
-            self.metrics.add_cost(cost)
-            return cost
+            if cost is not None:
+                self.metrics.add_cost(cost)
+                return cost
+
+            # Get usage details
+            usage: Usage | None = response.get('usage')
+            if not usage:
+                return 0.0
+
+            # Get token counts
+            total_input_tokens = usage.get('prompt_tokens', 0)
+            output_tokens = usage.get('completion_tokens', 0)
+
+            # Get cache details
+            prompt_tokens_details: PromptTokensDetails = usage.get(
+                'prompt_tokens_details'
+            )
+            cache_hits = (
+                prompt_tokens_details.cached_tokens if prompt_tokens_details else 0
+            )
+            model_extra = usage.get('model_extra', {})
+            cache_writes = model_extra.get('cache_creation_input_tokens', 0)
+
+            # Calculate actual input tokens (excluding cache hits/writes)
+            input_tokens = (
+                total_input_tokens - (cache_hits + cache_writes)
+                if total_input_tokens
+                else 0
+            )
+
+            # Get cost per token configuration
+            if (
+                self.config.input_cost_per_token is not None
+                and self.config.output_cost_per_token is not None
+                and self.config.cache_hit_discount is not None
+                and self.config.cache_write_premium is not None
+            ):
+                # Use custom pricing with configured cache discounts/premiums
+                input_cost = input_tokens * self.config.input_cost_per_token
+                output_cost = output_tokens * self.config.output_cost_per_token
+                cache_hit_cost = cache_hits * (
+                    self.config.input_cost_per_token
+                    * (1 - self.config.cache_hit_discount)
+                )
+                cache_write_cost = 0
+                if 'anthropic' in self.config.model.lower():
+                    cache_write_cost = cache_writes * (
+                        self.config.input_cost_per_token
+                        * (1 + self.config.cache_write_premium)
+                    )
+            else:
+                # Use litellm's pricing with CostPerToken if custom costs are provided
+                custom_cost_per_token = None
+                if (
+                    self.config.input_cost_per_token is not None
+                    and self.config.output_cost_per_token is not None
+                ):
+                    custom_cost_per_token = CostPerToken(
+                        input_cost_per_token=self.config.input_cost_per_token,
+                        output_cost_per_token=self.config.output_cost_per_token,
+                    )
+
+                total_cost = litellm_completion_cost(
+                    completion_response=response,
+                    custom_cost_per_token=custom_cost_per_token,
+                )
+                self.metrics.add_cost(total_cost)
+                return total_cost
+
+            total_cost = input_cost + output_cost + cache_hit_cost + cache_write_cost
+            self.metrics.add_cost(total_cost)
+            return total_cost

         except Exception:
             self.cost_metric_supported = False
             logger.debug('Cost calculation not supported for this model.')
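For a concrete reading of the custom-pricing branch, a standalone sketch (not the PR's code, just the same arithmetic) with hypothetical prices and the token split from the earlier example:

# Hypothetical prices; the formulas mirror the custom-pricing branch above.
input_cost_per_token = 0.000003    # $3 / 1M fresh input tokens
output_cost_per_token = 0.000015   # $15 / 1M output tokens
cache_hit_discount = 0.9           # cache hits billed at 10% of a fresh input token
cache_write_premium = 0.25         # cache writes billed at 125% of a fresh input token
is_anthropic_model = True          # the write premium is only applied to Anthropic models

input_tokens, cache_hits, cache_writes, output_tokens = 1000, 8000, 3000, 500

input_cost = input_tokens * input_cost_per_token
output_cost = output_tokens * output_cost_per_token
cache_hit_cost = cache_hits * input_cost_per_token * (1 - cache_hit_discount)
cache_write_cost = (
    cache_writes * input_cost_per_token * (1 + cache_write_premium)
    if is_anthropic_model
    else 0
)

total_cost = input_cost + output_cost + cache_hit_cost + cache_write_cost
print(f'${total_cost:.6f}')  # $0.024150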