Feature/caching #1922

Open · wants to merge 11 commits into base: v2.6
12 changes: 8 additions & 4 deletions dspy/__init__.py
@@ -6,10 +6,10 @@
import dspy.retrievers

# Functional must be imported after primitives, predict and signatures
from .functional import * # isort: skip
from dspy.evaluate import Evaluate # isort: skip
from dspy.clients import * # isort: skip
from dspy.adapters import * # isort: skip
from .functional import * # isort: skip
from dspy.evaluate import Evaluate # isort: skip
from dspy.clients import * # isort: skip
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter # isort: skip
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
from dspy.utils.asyncify import asyncify
from dspy.utils.saving import load
@@ -24,6 +24,10 @@
configure = settings.configure
context = settings.context

from dspy.utils.cache import DSPY_CACHE

cache = DSPY_CACHE


import dspy.teleprompt

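Reviewer note: this hunk re-exports the shared cache object at the package level, so user code can reach it as `dspy.cache`. A minimal sketch of what that looks like in practice; the `DSPY_CACHE` object itself is defined in `dspy/utils/cache.py`, which is not part of this diff, so nothing beyond the alias is assumed here:

```python
# Sketch only: dspy.cache is now a package-level alias for DSPY_CACHE, whose
# implementation lives in dspy/utils/cache.py (not shown in this diff).
import dspy

shared_cache = dspy.cache  # same object as dspy.utils.cache.DSPY_CACHE
print(shared_cache)
```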
2 changes: 1 addition & 1 deletion dspy/adapters/chat_adapter.py
@@ -16,7 +16,7 @@
from dspy.adapters.utils import find_enum_member, format_field_value
from dspy.signatures.field import OutputField
from dspy.signatures.signature import Signature, SignatureMeta
from dspy.signatures.utils import get_dspy_field_type
from dspy.signatures import get_dspy_field_type

field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")

2 changes: 1 addition & 1 deletion dspy/adapters/json_adapter.py
@@ -18,7 +18,7 @@

from ..adapters.image_utils import Image
from ..signatures.signature import SignatureMeta
from ..signatures.utils import get_dspy_field_type
from ..signatures import get_dspy_field_type

_logger = logging.getLogger(__name__)

16 changes: 3 additions & 13 deletions dspy/clients/__init__.py
@@ -1,21 +1,11 @@
import os
import litellm

from .lm import LM
from .provider import Provider, TrainingJob
from .base_lm import BaseLM, inspect_history
from .embedding import Embedder
import litellm
import os
from pathlib import Path
from litellm.caching import Cache

DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
DISK_CACHE_LIMIT = int(os.environ.get("DSPY_CACHE_LIMIT", 3e10)) # 30 GB default

# TODO: There's probably value in getting litellm to support FanoutCache and to separate the limit for
# the LM cache from the embeddings cache. Then we can lower the default 30GB limit.
litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk")

if litellm.cache.cache.disk_cache.size_limit != DISK_CACHE_LIMIT:
litellm.cache.cache.disk_cache.reset('size_limit', DISK_CACHE_LIMIT)

litellm.telemetry = False

137 changes: 22 additions & 115 deletions dspy/clients/lm.py
@@ -1,16 +1,11 @@
import functools
import logging
import os
import threading
import uuid
from datetime import datetime
from hashlib import sha256
from typing import Any, Dict, List, Literal, Optional

import litellm
import pydantic
import ujson
from cachetools import LRUCache, cached

from dspy.adapters.base import Adapter
from dspy.clients.openai import OpenAIProvider
@@ -20,6 +15,8 @@

from .base_lm import BaseLM

from dspy.utils.cache import cache_decorator

logger = logging.getLogger(__name__)


@@ -89,15 +86,18 @@ def __call__(self, prompt=None, messages=None, **kwargs):
kwargs = {**self.kwargs, **kwargs}

# Make the request and handle LRU & disk caching.
if self.model_type == "chat":
completion = cached_litellm_completion if cache else litellm_completion
else:
completion = cached_litellm_text_completion if cache else litellm_text_completion
completion = litellm_completion if self.model_type == "chat" else litellm_text_completion
wrapped_completion = completion

response = completion(
request=dict(model=self.model, messages=messages, **kwargs),
num_retries=self.num_retries,
)
if cache:
@cache_decorator(keep=litellm.Cache()._get_relevant_args_to_use_for_cache_key())
def cached_completion(**kwargs):
return completion(**kwargs, cache={"no-cache": False, "no-store": False})

wrapped_completion = cached_completion

response = wrapped_completion(model=self.model, messages=messages, num_retries=self.num_retries, **kwargs)

if kwargs.get("logprobs"):
outputs = [
{
@@ -109,7 +109,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
else:
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]


# Logging, with removed api key & where `cost` is None on cache hit.
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)
@@ -229,121 +229,28 @@ def copy(self, **kwargs):
return new_instance


def request_cache(maxsize: Optional[int] = None):
"""
A threadsafe decorator to create an in-memory LRU cache for LM inference functions that accept
a dictionary-like LM request. An in-memory cache for LM calls is critical for ensuring
good performance when optimizing and evaluating DSPy LMs (disk caching alone is too slow).

Args:
maxsize: The maximum size of the cache. If unspecified, no max size is enforced (cache is unbounded).

Returns:
A decorator that wraps the target function with caching.
"""
def litellm_completion(cache={"no-cache": True, "no-store": True}, **kwargs):
return litellm.completion(cache=cache, **kwargs)

def cache_key(request: Dict[str, Any]) -> str:
"""
Obtain a unique cache key for the given request dictionary by hashing its JSON
representation. For request fields having types that are known to be JSON-incompatible,
convert them to a JSON-serializable format before hashing.

Note: Values that cannot be converted to JSON should *not* be ignored / discarded, since
that would potentially lead to cache collisions. For example, consider request A
containing only JSON-convertible values and request B containing the same JSON-convertible
values in addition to one unconvertible value. Discarding the unconvertible value would
lead to a cache collision between requests A and B, even though they are semantically
different.
"""

def transform_value(value):
if isinstance(value, type) and issubclass(value, pydantic.BaseModel):
return value.schema()
elif isinstance(value, pydantic.BaseModel):
return value.dict()
elif callable(value) and hasattr(value, "__code__") and hasattr(value.__code__, "co_code"):
return value.__code__.co_code.decode("utf-8")
else:
# Note: We don't attempt to compute a hash of the value, since the default
# implementation of hash() is id(), which may collide if the same memory address
# is reused for different objects at different times
return value

params = {k: transform_value(v) for k, v in request.items()}
return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()

def decorator(func):
@cached(
# NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead
cache=LRUCache(maxsize=maxsize or float("inf")),
key=lambda key, request, *args, **kwargs: key,
# Use a lock to ensure thread safety for the cache when DSPy LMs are queried
# concurrently, e.g. during optimization and evaluation
lock=threading.RLock(),
)
def func_cached(key: str, request: Dict[str, Any], *args, **kwargs):
return func(request, *args, **kwargs)

@functools.wraps(func)
def wrapper(request: dict, *args, **kwargs):
try:
key = cache_key(request)
except Exception:
# If the cache key cannot be computed (e.g. because it contains a value that cannot
# be converted to JSON), bypass the cache and call the target function directly
return func(request, *args, **kwargs)
return func_cached(key, request, *args, **kwargs)

return wrapper

return decorator


@request_cache(maxsize=None)
def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
return litellm_completion(
request,
cache={"no-cache": False, "no-store": False},
num_retries=num_retries,
)


def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
return litellm.completion(
num_retries=num_retries,
cache=cache,
**request,
)


@request_cache(maxsize=None)
def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
return litellm_text_completion(
request,
num_retries=num_retries,
cache={"no-cache": False, "no-store": False},
)


def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
def litellm_text_completion(cache={"no-cache": True, "no-store": True}, **kwargs):
# Extract the provider and model from the model string.
# TODO: Not all the models are in the format of "provider/model"
model = request.pop("model").split("/", 1)
model = kwargs.pop("model").split("/", 1)
provider, model = model[0] if len(model) > 1 else "openai", model[-1]

# Use the API key and base from the request, or from the environment.
api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")

# Build the prompt from the messages.
prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])
prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])

return litellm.text_completion(
cache=cache,
model=f"text-completion-openai/{model}",
api_key=api_key,
api_base=api_base,
prompt=prompt,
num_retries=num_retries,
**request,
**kwargs,
)
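Reviewer note: `cache_decorator` is imported from `dspy.utils.cache`, which is not included in this diff. The sketch below is an assumption about how such a decorator could behave, inferred only from the call site above (`keep=` limiting which kwargs feed the cache key); it is not the PR's actual implementation.

```python
# Hypothetical sketch of a cache_decorator compatible with the call site above.
# Assumptions: kwargs named in `keep` are hashed into a key, results are memoized
# in a thread-safe in-memory dict, and unhashable requests bypass the cache.
import threading
from hashlib import sha256

import ujson


def cache_decorator(keep=None):
    store, lock = {}, threading.RLock()

    def decorator(func):
        def wrapper(**kwargs):
            relevant = {k: v for k, v in kwargs.items() if keep is None or k in keep}
            try:
                key = sha256(ujson.dumps(relevant, sort_keys=True).encode()).hexdigest()
            except Exception:
                # If a cache key cannot be computed, call through rather than risk collisions.
                return func(**kwargs)
            with lock:
                if key in store:
                    return store[key]
            result = func(**kwargs)
            with lock:
                store[key] = result
            return result

        return wrapper

    return decorator
```

In the diff itself, the decorated call still forwards `cache={"no-cache": False, "no-store": False}` to litellm, so a miss in the in-memory layer can fall through to litellm's own cache.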
2 changes: 2 additions & 0 deletions dspy/dsp/utils/settings.py
@@ -24,6 +24,7 @@
backoff_time=10,
callbacks=[],
async_max_workers=8,
cache=None
)

# Global base configuration
@@ -54,6 +55,7 @@ def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.lock = threading.Lock() # maintained here for DSPy assertions.py

return cls._instance

def __getattr__(self, name):
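Reviewer note: the new `cache=None` default config entry suggests the active cache can also be swapped through DSPy's settings; the snippet below is a sketch under that assumption, since the accepted value type is not visible in this diff.

```python
import dspy

# Assumption: whatever object dspy.cache exposes is also a valid settings value.
dspy.configure(cache=dspy.cache)   # set the cache globally via settings

with dspy.context(cache=None):     # or override it inside a scoped block
    ...
```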
2 changes: 1 addition & 1 deletion dspy/signatures/__init__.py
@@ -1,2 +1,2 @@
from .field import *
from .signature import *
from .signature import *
9 changes: 9 additions & 0 deletions dspy/signatures/signature.py
@@ -9,6 +9,7 @@
from typing import Any, Dict, Tuple, Type, Union # noqa: UP035
import importlib

from typing import Literal
from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo

@@ -518,3 +519,11 @@ def infer_prefix(attribute_name: str) -> str:
title_cased_words.append(word.capitalize())

return " ".join(title_cased_words)



def get_dspy_field_type(field: FieldInfo) -> Literal["input", "output"]:
field_type = field.json_schema_extra.get("__dspy_field_type")
if field_type is None:
raise ValueError(f"Field {field} does not have a __dspy_field_type")
return field_type
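Reviewer note: `get_dspy_field_type` moved from `dspy/signatures/utils.py` into `signature.py` (the old module is deleted below). A small usage sketch, assuming the usual `InputField`/`OutputField` machinery that stores `__dspy_field_type` in each field's `json_schema_extra`:

```python
import dspy
from dspy.signatures import get_dspy_field_type


class QA(dspy.Signature):
    """Answer the question."""

    question: str = dspy.InputField()
    answer: str = dspy.OutputField()


# Expected to print "input" then "output".
print(get_dspy_field_type(QA.input_fields["question"]))
print(get_dspy_field_type(QA.output_fields["answer"]))
```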
10 changes: 0 additions & 10 deletions dspy/signatures/utils.py

This file was deleted.
