From 47d653471b46c26f914355013c9d191395699d5c Mon Sep 17 00:00:00 2001 From: hmoazam Date: Sun, 8 Dec 2024 01:38:49 +0000 Subject: [PATCH 1/9] Rebased latest changes --- dspy/clients/__init__.py | 11 ++-- dspy/clients/cache.py | 129 +++++++++++++++++++++++++++++++++++++++ dspy/clients/lm.py | 81 ++---------------------- 3 files changed, 138 insertions(+), 83 deletions(-) create mode 100644 dspy/clients/cache.py diff --git a/dspy/clients/__init__.py b/dspy/clients/__init__.py index 546a96c75..a1c30bcea 100644 --- a/dspy/clients/__init__.py +++ b/dspy/clients/__init__.py @@ -7,15 +7,12 @@ from pathlib import Path from litellm.caching import Cache -DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") -DISK_CACHE_LIMIT = int(os.environ.get("DSPY_CACHE_LIMIT", 3e10)) # 30 GB default -# TODO: There's probably value in getting litellm to support FanoutCache and to separate the limit for -# the LM cache from the embeddings cache. Then we can lower the default 30GB limit. -litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk") +# litellm cache only used for embeddings +LITELLM_CACHE_DIR = os.environ.get("DSPY_LITELLM_CACHEDIR") or os.path.join(Path.home(), ".litellm_cache") -if litellm.cache.cache.disk_cache.size_limit != DISK_CACHE_LIMIT: - litellm.cache.cache.disk_cache.reset('size_limit', DISK_CACHE_LIMIT) +# Litellm cache is only used for embeddings +litellm.cache = Cache(disk_cache_dir=LITELLM_CACHE_DIR, type="disk") litellm.telemetry = False diff --git a/dspy/clients/cache.py b/dspy/clients/cache.py new file mode 100644 index 000000000..5948a0c49 --- /dev/null +++ b/dspy/clients/cache.py @@ -0,0 +1,129 @@ +import os +import pickle +import threading + +from diskcache import FanoutCache +from cachetools import LRUCache +import ujson +import pydantic +from hashlib import sha256 +from typing import Any, Dict, List, Literal, Optional +from functools import wraps +from pathlib import Path + + + + +CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") +print(CACHE_DIR) +CACHE_LIMIT = int(os.environ.get("DSPY_CACHE_LIMIT", 1e10)) # 10 GB default +MEM_CACHE_LIMIT = float(os.environ.get("DSPY_CACHE_LIMIT", float("inf"))) # unlimited by default +# # litellm cache only used for embeddings +# LITELLM_CACHE_DIR = os.environ.get("DSPY_LITELLM_CACHEDIR") or os.path.join(Path.home(), ".litellm_cache") + + +class DspyCache: + """ + dspy's caching interface. It provides 2 levels of caching (in the given order): + 1. An in memory cache - cachetools' lrucache + 2. A disk based cache - diskcache's fanoutcache + + The cache is threadsafe + + Args: + mem_size_limit: The maximum size of the cache. If unspecified, no max size is enforced (cache is unbounded). + disk_size_limit: + """ + def __init__(self, directory, disk_size_limit, mem_size_limit): + self.memory_cache = LRUCache(maxsize=mem_size_limit) + self.fanout_cache = FanoutCache(shards=16, timeout=2, directory=directory, size_limit=disk_size_limit) + self.lock = threading.RLock() + + @staticmethod + def cache_key(request: Dict[str, Any]) -> str: + """ + Obtain a unique cache key for the given request dictionary by hashing its JSON + representation. For request fields having types that are known to be JSON-incompatible, + convert them to a JSON-serializable format before hashing. + + Note: Values that cannot be converted to JSON should *not* be ignored / discarded, since + that would potentially lead to cache collisions. 
For example, consider request A + containing only JSON-convertible values and request B containing the same JSON-convertible + values in addition to one unconvertible value. Discarding the unconvertible value would + lead to a cache collision between requests A and B, even though they are semantically + different. + """ + + def transform_value(value): + if isinstance(value, type) and issubclass(value, pydantic.BaseModel): + return value.model_json_schema() # BaseModel.schema deprecated + elif isinstance(value, pydantic.BaseModel): + return value.model_dump() # BaseModel.dict deprecated + elif callable(value) and hasattr(value, "__code__") and hasattr(value.__code__, "co_code"): + return value.__code__.co_code.decode("utf-8") + else: + # Note: We don't attempt to compute a hash of the value, since the default + # implementation of hash() is id(), which may collide if the same memory address + # is reused for different objects at different times + return value + + params = {k: transform_value(v) for k, v in request.items()} + return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest() + + def get(self, request: Dict[str, Any]) -> Any: + try: + # try/except in case can't compute key + key = self.cache_key(request) + with self.lock: + if key in self.memory_cache: + return self.memory_cache[key] + elif key in self.fanout_cache: + # found on disk but not in memory, add to memory cache + value = self.fanout_cache[key] + self.memory_cache[key] = value + return value + return None + except Exception: + return None + + def set(self, request: Dict[str, Any], value: Any) -> None: + try: + # try/except in case can't compute key + key = self.cache_key(request) + with self.lock: + self.memory_cache[key] = value + self.fanout_cache[key] = value + except Exception: + return None + + @classmethod + def load(cls, file, maxsize: int) -> "LRUCache": + pass + # return cls(pickle.load(file), maxsize) + + @staticmethod + def dump(obj, file) -> None: + pass + # pickle.dump([[k, v] for k, v in obj.items()], file) + + +# Initialize the cache +dspy_cache = DspyCache(directory=CACHE_DIR, disk_size_limit=CACHE_LIMIT, mem_size_limit=MEM_CACHE_LIMIT) + + +def dspy_cache_decorator(cache=dspy_cache): + def decorator(func): + @wraps(func) + def wrapper(request: dict, *args, **kwargs): + cached_result = cache.get(request) + if cached_result is not None: + return cached_result + + result = func(request, *args, **kwargs) + cache.set(request, result) + return result + + return wrapper + + return decorator + diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 349e65339..b593f43d1 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -1,4 +1,3 @@ -import functools import logging import os import threading @@ -20,6 +19,8 @@ from .base_lm import BaseLM +from .cache import dspy_cache_decorator + logger = logging.getLogger(__name__) @@ -219,81 +220,10 @@ def copy(self, **kwargs): return new_instance -def request_cache(maxsize: Optional[int] = None): - """ - A threadsafe decorator to create an in-memory LRU cache for LM inference functions that accept - a dictionary-like LM request. An in-memory cache for LM calls is critical for ensuring - good performance when optimizing and evaluating DSPy LMs (disk caching alone is too slow). - - Args: - maxsize: The maximum size of the cache. If unspecified, no max size is enforced (cache is unbounded). - - Returns: - A decorator that wraps the target function with caching. 
- """ - - def cache_key(request: Dict[str, Any]) -> str: - """ - Obtain a unique cache key for the given request dictionary by hashing its JSON - representation. For request fields having types that are known to be JSON-incompatible, - convert them to a JSON-serializable format before hashing. - - Note: Values that cannot be converted to JSON should *not* be ignored / discarded, since - that would potentially lead to cache collisions. For example, consider request A - containing only JSON-convertible values and request B containing the same JSON-convertible - values in addition to one unconvertible value. Discarding the unconvertible value would - lead to a cache collision between requests A and B, even though they are semantically - different. - """ - - def transform_value(value): - if isinstance(value, type) and issubclass(value, pydantic.BaseModel): - return value.schema() - elif isinstance(value, pydantic.BaseModel): - return value.dict() - elif callable(value) and hasattr(value, "__code__") and hasattr(value.__code__, "co_code"): - return value.__code__.co_code.decode("utf-8") - else: - # Note: We don't attempt to compute a hash of the value, since the default - # implementation of hash() is id(), which may collide if the same memory address - # is reused for different objects at different times - return value - - params = {k: transform_value(v) for k, v in request.items()} - return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest() - - def decorator(func): - @cached( - # NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead - cache=LRUCache(maxsize=maxsize or float("inf")), - key=lambda key, request, *args, **kwargs: key, - # Use a lock to ensure thread safety for the cache when DSPy LMs are queried - # concurrently, e.g. during optimization and evaluation - lock=threading.RLock(), - ) - def func_cached(key: str, request: Dict[str, Any], *args, **kwargs): - return func(request, *args, **kwargs) - - @functools.wraps(func) - def wrapper(request: dict, *args, **kwargs): - try: - key = cache_key(request) - return func_cached(key, request, *args, **kwargs) - except Exception: - # If the cache key cannot be computed (e.g. 
because it contains a value that cannot - # be converted to JSON), bypass the cache and call the target function directly - return func(request, *args, **kwargs) - - return wrapper - - return decorator - - -@request_cache(maxsize=None) +@dspy_cache_decorator() def cached_litellm_completion(request: Dict[str, Any], num_retries: int): return litellm_completion( request, - cache={"no-cache": False, "no-store": False}, num_retries=num_retries, ) @@ -306,12 +236,11 @@ def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cac ) -@request_cache(maxsize=None) +@dspy_cache_decorator() def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int): return litellm_text_completion( request, num_retries=num_retries, - cache={"no-cache": False, "no-store": False}, ) @@ -329,7 +258,7 @@ def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"n prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"]) return litellm.text_completion( - cache=cache, + # cache=cache, model=f"text-completion-openai/{model}", api_key=api_key, api_base=api_base, From 4a9ac55e1edfe3d1e429152d1df104dac7ee5489 Mon Sep 17 00:00:00 2001 From: hmoazam Date: Sun, 8 Dec 2024 20:51:28 +0000 Subject: [PATCH 2/9] WIP --- dspy/clients/cache.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/dspy/clients/cache.py b/dspy/clients/cache.py index 5948a0c49..ef88d9e8a 100644 --- a/dspy/clients/cache.py +++ b/dspy/clients/cache.py @@ -12,15 +12,9 @@ from pathlib import Path - - CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") -print(CACHE_DIR) CACHE_LIMIT = int(os.environ.get("DSPY_CACHE_LIMIT", 1e10)) # 10 GB default MEM_CACHE_LIMIT = float(os.environ.get("DSPY_CACHE_LIMIT", float("inf"))) # unlimited by default -# # litellm cache only used for embeddings -# LITELLM_CACHE_DIR = os.environ.get("DSPY_LITELLM_CACHEDIR") or os.path.join(Path.home(), ".litellm_cache") - class DspyCache: """ From b5a948f8a54fccf6782ad71853d78da1842942cb Mon Sep 17 00:00:00 2001 From: hmoazam Date: Mon, 9 Dec 2024 00:52:31 +0000 Subject: [PATCH 3/9] WIP - initial tests working fine. 
TODO: how/where to expose save, load and reset methods, test with multiple threads --- dspy/clients/cache.py | 28 ++++++++++++++++++---------- dspy/clients/lm.py | 3 +-- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/dspy/clients/cache.py b/dspy/clients/cache.py index ef88d9e8a..d846f5328 100644 --- a/dspy/clients/cache.py +++ b/dspy/clients/cache.py @@ -90,16 +90,24 @@ def set(self, request: Dict[str, Any], value: Any) -> None: except Exception: return None - @classmethod - def load(cls, file, maxsize: int) -> "LRUCache": - pass - # return cls(pickle.load(file), maxsize) - - @staticmethod - def dump(obj, file) -> None: - pass - # pickle.dump([[k, v] for k, v in obj.items()], file) - + def load(self, file_path, maxsize: int) -> "LRUCache": + with open(file_path, "rb") as f: + cache_items = pickle.load(f) + + with self.lock: + self.memory_cache.clear() + for k,v in cache_items: + self.memory_cache[k] = v + + def save(self, file_path: str) -> None: + with self.lock: + with open(file_path, "wb") as f: + cache_items = list(self.memory_cache.items()) + pickle.dump(cache_items, f) + + def reset_memory_cache(self) -> None: + with self.lock: + self.memory_cache.clear() # Initialize the cache dspy_cache = DspyCache(directory=CACHE_DIR, disk_size_limit=CACHE_LIMIT, mem_size_limit=MEM_CACHE_LIMIT) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index b593f43d1..174614d9b 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -244,7 +244,7 @@ def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int): ) -def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}): +def litellm_text_completion(request: Dict[str, Any], num_retries: int): # Extract the provider and model from the model string. 
# TODO: Not all the models are in the format of "provider/model" model = request.pop("model").split("/", 1) @@ -258,7 +258,6 @@ def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"n prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"]) return litellm.text_completion( - # cache=cache, model=f"text-completion-openai/{model}", api_key=api_key, api_base=api_base, From dfdf8f7e6385c6e47240cc48fde8e37a1951e302 Mon Sep 17 00:00:00 2001 From: hmoazam Date: Tue, 10 Dec 2024 21:19:36 +0000 Subject: [PATCH 4/9] Cleaned up structure and added cache to settings --- dsp/utils/settings.py | 1 + dspy/__init__.py | 9 ++++++++ dspy/clients/__init__.py | 18 +++++++++++---- dspy/clients/lm.py | 31 ++++++++++++++----------- dspy/{clients => utils}/cache.py | 39 +++++++++++++------------------- 5 files changed, 57 insertions(+), 41 deletions(-) rename dspy/{clients => utils}/cache.py (82%) diff --git a/dsp/utils/settings.py b/dsp/utils/settings.py index 16ae6f93b..299211bd3 100644 --- a/dsp/utils/settings.py +++ b/dsp/utils/settings.py @@ -24,6 +24,7 @@ backoff_time=10, callbacks=[], async_max_workers=8, + cache=None ) # Global base configuration diff --git a/dspy/__init__.py b/dspy/__init__.py index 9e3e85fd2..9895d4f45 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -15,11 +15,13 @@ from dspy.adapters import * # isort: skip from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging from dspy.utils.asyncify import asyncify +from dspy.utils.cache import Cache settings = dsp.settings configure_dspy_loggers(__name__) + # LM = dsp.LM AzureOpenAI = dsp.AzureOpenAI @@ -64,6 +66,13 @@ configure = settings.configure context = settings.context +CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") +CACHE_LIMIT = int(os.environ.get("DSPY_CACHE_LIMIT", 1e10)) # 10 GB default +MEM_CACHE_LIMIT = float(os.environ.get("DSPY_CACHE_LIMIT", float("inf"))) # unlimited by default + +# Initialize the cache +dspy_cache = Cache(directory=CACHE_DIR, disk_size_limit=CACHE_LIMIT, mem_size_limit=MEM_CACHE_LIMIT) +settings.configure(cache=dspy_cache) import dspy.teleprompt diff --git a/dspy/clients/__init__.py b/dspy/clients/__init__.py index a1c30bcea..6dc438403 100644 --- a/dspy/clients/__init__.py +++ b/dspy/clients/__init__.py @@ -5,14 +5,22 @@ import litellm import os from pathlib import Path -from litellm.caching import Cache +from litellm.caching import Cache as litellm_cache -# litellm cache only used for embeddings -LITELLM_CACHE_DIR = os.environ.get("DSPY_LITELLM_CACHEDIR") or os.path.join(Path.home(), ".litellm_cache") +# ------------------------------ LiteLLM caching ----------------------------------- +DISK_CACHE_DIR = os.environ.get("DSPY_LITELLM_CACHEDIR") or os.path.join(Path.home(), ".dspy_litellm_cache") +DISK_CACHE_LIMIT = int(os.environ.get("DSPY_LITELLM_CACHE_LIMIT", 3e10)) # 30 GB default + +# TODO: There's probably value in separating the limit for +# the LM cache from the embeddings cache. Then we can lower the default 30GB limit. 
+litellm.cache = litellm_cache(disk_cache_dir=DISK_CACHE_DIR, type="disk") + +if litellm.cache.cache.disk_cache.size_limit != DISK_CACHE_LIMIT: + litellm.cache.cache.disk_cache.reset('size_limit', DISK_CACHE_LIMIT) + +# ---------------------------------------------------------------------------------- -# Litellm cache is only used for embeddings -litellm.cache = Cache(disk_cache_dir=LITELLM_CACHE_DIR, type="disk") litellm.telemetry = False diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 174614d9b..e8ce26429 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -19,7 +19,7 @@ from .base_lm import BaseLM -from .cache import dspy_cache_decorator +from dspy.utils.cache import dspy_cache_decorator logger = logging.getLogger(__name__) @@ -219,13 +219,16 @@ def copy(self, **kwargs): return new_instance - -@dspy_cache_decorator() def cached_litellm_completion(request: Dict[str, Any], num_retries: int): - return litellm_completion( - request, - num_retries=num_retries, - ) + from dspy import settings + @dspy_cache_decorator(settings.cache) + def cached_litellm_completion_inner(request: Dict[str, Any], num_retries: int): + return litellm_completion( + request, + cache={"no-cache": False, "no-store": False}, + num_retries=num_retries, + ) + return cached_litellm_completion_inner(request, num_retries) def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}): @@ -235,13 +238,15 @@ def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cac **request, ) - -@dspy_cache_decorator() def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int): - return litellm_text_completion( - request, - num_retries=num_retries, - ) + @dspy_cache_decorator() + def cached_litellm_text_completion_inner(request: Dict[str, Any], num_retries: int): + return litellm_text_completion( + request, + cache={"no-cache": False, "no-store": False}, + num_retries=num_retries, + ) + return cached_litellm_text_completion_inner(request, num_retries) def litellm_text_completion(request: Dict[str, Any], num_retries: int): diff --git a/dspy/clients/cache.py b/dspy/utils/cache.py similarity index 82% rename from dspy/clients/cache.py rename to dspy/utils/cache.py index d846f5328..c34110faa 100644 --- a/dspy/clients/cache.py +++ b/dspy/utils/cache.py @@ -1,7 +1,8 @@ import os import pickle import threading - +import litellm +from litellm.caching import Cache as litellm_cache from diskcache import FanoutCache from cachetools import LRUCache import ujson @@ -11,12 +12,7 @@ from functools import wraps from pathlib import Path - -CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") -CACHE_LIMIT = int(os.environ.get("DSPY_CACHE_LIMIT", 1e10)) # 10 GB default -MEM_CACHE_LIMIT = float(os.environ.get("DSPY_CACHE_LIMIT", float("inf"))) # unlimited by default - -class DspyCache: +class Cache: """ dspy's caching interface. It provides 2 levels of caching (in the given order): 1. 
An in memory cache - cachetools' lrucache @@ -68,27 +64,27 @@ def get(self, request: Dict[str, Any]) -> Any: try: # try/except in case can't compute key key = self.cache_key(request) - with self.lock: - if key in self.memory_cache: - return self.memory_cache[key] - elif key in self.fanout_cache: - # found on disk but not in memory, add to memory cache - value = self.fanout_cache[key] - self.memory_cache[key] = value - return value - return None except Exception: return None + with self.lock: + if key in self.memory_cache: + return self.memory_cache[key] + elif key in self.fanout_cache: + # found on disk but not in memory, add to memory cache + value = self.fanout_cache[key] + self.memory_cache[key] = value + return value + return None def set(self, request: Dict[str, Any], value: Any) -> None: try: # try/except in case can't compute key key = self.cache_key(request) - with self.lock: - self.memory_cache[key] = value - self.fanout_cache[key] = value except Exception: return None + with self.lock: + self.memory_cache[key] = value + self.fanout_cache[key] = value def load(self, file_path, maxsize: int) -> "LRUCache": with open(file_path, "rb") as f: @@ -109,11 +105,8 @@ def reset_memory_cache(self) -> None: with self.lock: self.memory_cache.clear() -# Initialize the cache -dspy_cache = DspyCache(directory=CACHE_DIR, disk_size_limit=CACHE_LIMIT, mem_size_limit=MEM_CACHE_LIMIT) - -def dspy_cache_decorator(cache=dspy_cache): +def dspy_cache_decorator(cache): def decorator(func): @wraps(func) def wrapper(request: dict, *args, **kwargs): From 385014d9bac9897df8172c3a77512b7b89015f00 Mon Sep 17 00:00:00 2001 From: hmoazam Date: Tue, 10 Dec 2024 22:11:45 +0000 Subject: [PATCH 5/9] looks like it works --- dspy/utils/cache.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dspy/utils/cache.py b/dspy/utils/cache.py index c34110faa..3807eb814 100644 --- a/dspy/utils/cache.py +++ b/dspy/utils/cache.py @@ -83,10 +83,10 @@ def set(self, request: Dict[str, Any], value: Any) -> None: except Exception: return None with self.lock: - self.memory_cache[key] = value - self.fanout_cache[key] = value + self.memory_cache[key] = value + self.fanout_cache[key] = value - def load(self, file_path, maxsize: int) -> "LRUCache": + def load(self, file_path, maxsize:float=float("inf")) -> "LRUCache": with open(file_path, "rb") as f: cache_items = pickle.load(f) @@ -105,7 +105,6 @@ def reset_memory_cache(self) -> None: with self.lock: self.memory_cache.clear() - def dspy_cache_decorator(cache): def decorator(func): @wraps(func) From afd025afe579d43514e5a6d92e43dadf5ce40af5 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Wed, 11 Dec 2024 09:28:59 -0800 Subject: [PATCH 6/9] Some WIP adjustments (2 TODOs marked in cache.py!) 
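For context, the request hashing these adjustments build on (Cache.cache_key, introduced earlier in this series) can be exercised on its own. The sketch below is not part of the patch: it inlines the same transform rather than importing it, since the module moves from dspy/clients/cache.py to dspy/utils/cache.py across the series, and the QA model and request fields are purely illustrative.

from hashlib import sha256

import pydantic
import ujson


def transform_value(value):
    # Mirrors Cache.cache_key's handling of JSON-incompatible values.
    if isinstance(value, type) and issubclass(value, pydantic.BaseModel):
        return value.model_json_schema()  # model class -> its JSON schema
    elif isinstance(value, pydantic.BaseModel):
        return value.model_dump()  # model instance -> plain dict
    elif callable(value) and hasattr(value, "__code__") and hasattr(value.__code__, "co_code"):
        return value.__code__.co_code.decode("utf-8")  # callable -> its bytecode, decoded as text (as in the patch)
    return value


def cache_key(request):
    params = {k: transform_value(v) for k, v in request.items()}
    return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()


class QA(pydantic.BaseModel):  # illustrative structured-output model
    question: str
    answer: str


base = {"model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "hi"}]}
key_plain = cache_key(base)
key_structured = cache_key({**base, "response_format": QA})
assert key_plain != key_structured  # the schema is hashed, so the two requests do not collide

Because the key is derived from the JSON content (schemas and dumps) rather than object identity, separate processes agree on the key for the same request, which is what lets the disk cache be reused across interpreter sessions, as the caching tests at the end of this series check.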
--- dspy/__init__.py | 19 +++++++++++-------- dspy/clients/__init__.py | 8 ++------ dspy/clients/lm.py | 38 +++++++++++++++----------------------- dspy/utils/cache.py | 34 ++++++++++++++++++---------------- 4 files changed, 46 insertions(+), 53 deletions(-) diff --git a/dspy/__init__.py b/dspy/__init__.py index ff62b7ad4..84f84091b 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -9,10 +9,10 @@ import dspy.retrievers # Functional must be imported after primitives, predict and signatures -from .functional import * # isort: skip -from dspy.evaluate import Evaluate # isort: skip -from dspy.clients import * # isort: skip -from dspy.adapters import * # isort: skip +from .functional import * # isort: skip +from dspy.evaluate import Evaluate # isort: skip +from dspy.clients import * # isort: skip +from dspy.adapters import * # isort: skip from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging from dspy.utils.asyncify import asyncify from dspy.utils.cache import Cache @@ -67,12 +67,15 @@ configure = settings.configure context = settings.context -CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") -CACHE_LIMIT = int(os.environ.get("DSPY_CACHE_LIMIT", 1e10)) # 10 GB default -MEM_CACHE_LIMIT = float(os.environ.get("DSPY_CACHE_LIMIT", float("inf"))) # unlimited by default +from dspy.clients import DISK_CACHE_DIR, DISK_CACHE_LIMIT +MEM_CACHE_LIMIT = float(os.environ.get("DSPY_CACHE_LIMIT", float("inf"))) # unlimited by default # Initialize the cache -dspy_cache = Cache(directory=CACHE_DIR, disk_size_limit=CACHE_LIMIT, mem_size_limit=MEM_CACHE_LIMIT) +dspy_cache = Cache( + directory=os.path.join(DISK_CACHE_DIR, ".cache_v2_6"), + disk_size_limit=DISK_CACHE_LIMIT, + mem_size_limit=MEM_CACHE_LIMIT +) settings.configure(cache=dspy_cache) import dspy.teleprompt diff --git a/dspy/clients/__init__.py b/dspy/clients/__init__.py index 6dc438403..22fb0acc7 100644 --- a/dspy/clients/__init__.py +++ b/dspy/clients/__init__.py @@ -8,9 +8,8 @@ from litellm.caching import Cache as litellm_cache -# ------------------------------ LiteLLM caching ----------------------------------- -DISK_CACHE_DIR = os.environ.get("DSPY_LITELLM_CACHEDIR") or os.path.join(Path.home(), ".dspy_litellm_cache") -DISK_CACHE_LIMIT = int(os.environ.get("DSPY_LITELLM_CACHE_LIMIT", 3e10)) # 30 GB default +DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") +DISK_CACHE_LIMIT = int(os.environ.get("DSPY_CACHE_LIMIT", 3e10)) # 30 GB default # TODO: There's probably value in separating the limit for # the LM cache from the embeddings cache. Then we can lower the default 30GB limit. @@ -19,9 +18,6 @@ if litellm.cache.cache.disk_cache.size_limit != DISK_CACHE_LIMIT: litellm.cache.cache.disk_cache.reset('size_limit', DISK_CACHE_LIMIT) -# ---------------------------------------------------------------------------------- - - litellm.telemetry = False # Turn off by default to avoid LiteLLM logging during every LM call. 
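One practical consequence of the hunk above: DSPY_CACHEDIR and DSPY_CACHE_LIMIT are read at import time (and, at this point in the series, the same DSPY_CACHE_LIMIT variable also feeds MEM_CACHE_LIMIT in dspy/__init__.py), so any override has to be in place before dspy or dspy.clients is first imported. That is also why the caching tests reload the modules after monkeypatching the environment. A minimal sketch, with an illustrative path and size:

import os

# Must run before the first `import dspy`; both values are read at module import.
os.environ["DSPY_CACHEDIR"] = "/tmp/my_dspy_cache"  # illustrative location
os.environ["DSPY_CACHE_LIMIT"] = str(int(1e9))      # 1 GB; kept as an integer string since the disk limit is parsed with int()

import dspy  # noqa: E402  -- deliberately imported after the environment is configured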
diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 71e35ecbb..076c8bbaa 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -3,13 +3,9 @@ import threading import uuid from datetime import datetime -from hashlib import sha256 from typing import Any, Dict, List, Literal, Optional import litellm -import pydantic -import ujson -from cachetools import LRUCache, cached from dspy.adapters.base import Adapter from dspy.clients.openai import OpenAIProvider @@ -19,7 +15,7 @@ from .base_lm import BaseLM -from dspy.utils.cache import dspy_cache_decorator +from dspy.utils.cache import cache_decorator logger = logging.getLogger(__name__) @@ -229,16 +225,13 @@ def copy(self, **kwargs): return new_instance +@cache_decorator() def cached_litellm_completion(request: Dict[str, Any], num_retries: int): - from dspy import settings - @dspy_cache_decorator(settings.cache) - def cached_litellm_completion_inner(request: Dict[str, Any], num_retries: int): - return litellm_completion( - request, - cache={"no-cache": False, "no-store": False}, - num_retries=num_retries, - ) - return cached_litellm_completion_inner(request, num_retries) + return litellm_completion( + request, + cache={"no-cache": False, "no-store": False}, + num_retries=num_retries, + ) def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}): @@ -248,18 +241,16 @@ def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cac **request, ) +@cache_decorator() def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int): - @dspy_cache_decorator() - def cached_litellm_text_completion_inner(request: Dict[str, Any], num_retries: int): - return litellm_text_completion( - request, - cache={"no-cache": False, "no-store": False}, - num_retries=num_retries, - ) - return cached_litellm_text_completion_inner(request, num_retries) + return litellm_text_completion( + request, + num_retries=num_retries, + cache={"no-cache": False, "no-store": False}, + ) -def litellm_text_completion(request: Dict[str, Any], num_retries: int): +def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}): # Extract the provider and model from the model string. # TODO: Not all the models are in the format of "provider/model" model = request.pop("model").split("/", 1) @@ -273,6 +264,7 @@ def litellm_text_completion(request: Dict[str, Any], num_retries: int): prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"]) return litellm.text_completion( + cache=cache, model=f"text-completion-openai/{model}", api_key=api_key, api_base=api_base, diff --git a/dspy/utils/cache.py b/dspy/utils/cache.py index 3807eb814..bf8add6bf 100644 --- a/dspy/utils/cache.py +++ b/dspy/utils/cache.py @@ -1,16 +1,12 @@ -import os import pickle import threading -import litellm -from litellm.caching import Cache as litellm_cache from diskcache import FanoutCache from cachetools import LRUCache import ujson import pydantic from hashlib import sha256 -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict from functools import wraps -from pathlib import Path class Cache: """ @@ -62,26 +58,26 @@ def transform_value(value): def get(self, request: Dict[str, Any]) -> Any: try: - # try/except in case can't compute key key = self.cache_key(request) except Exception: return None - with self.lock: + + with self.lock: # TODO: Do we need this lock for reads? LRUCache is ambiguous! 
if key in self.memory_cache: - return self.memory_cache[key] - elif key in self.fanout_cache: + return self.memory_cache[key] + + if key in self.fanout_cache: # found on disk but not in memory, add to memory cache value = self.fanout_cache[key] self.memory_cache[key] = value return value - return None def set(self, request: Dict[str, Any], value: Any) -> None: try: - # try/except in case can't compute key key = self.cache_key(request) except Exception: return None + with self.lock: self.memory_cache[key] = value self.fanout_cache[key] = value @@ -91,21 +87,27 @@ def load(self, file_path, maxsize:float=float("inf")) -> "LRUCache": cache_items = pickle.load(f) with self.lock: - self.memory_cache.clear() + # self.memory_cache.clear() for k,v in cache_items: self.memory_cache[k] = v def save(self, file_path: str) -> None: with self.lock: - with open(file_path, "wb") as f: - cache_items = list(self.memory_cache.items()) - pickle.dump(cache_items, f) + cache_items = list(self.memory_cache.items()) + + with open(file_path, "wb") as f: + pickle.dump(cache_items, f) def reset_memory_cache(self) -> None: with self.lock: self.memory_cache.clear() -def dspy_cache_decorator(cache): +def cache_decorator(): + import dspy + cache = dspy.settings.cache + + # TODO: FIXME: The name of the decorated function should be part of the cache key + def decorator(func): @wraps(func) def wrapper(request: dict, *args, **kwargs): From e712a61da3c1228f7e39bab636fb879275671c1d Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Thu, 12 Dec 2024 13:49:38 -0800 Subject: [PATCH 7/9] DSPy cache decorator: use function identifier --- dspy/__init__.py | 2 +- dspy/utils/cache.py | 72 ++++++++++++++++++++++----------------------- 2 files changed, 37 insertions(+), 37 deletions(-) diff --git a/dspy/__init__.py b/dspy/__init__.py index 84f84091b..c7d0e26f1 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -15,7 +15,6 @@ from dspy.adapters import * # isort: skip from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging from dspy.utils.asyncify import asyncify -from dspy.utils.cache import Cache from dspy.utils.saving import load settings = dsp.settings @@ -67,6 +66,7 @@ configure = settings.configure context = settings.context +from dspy.utils.cache import Cache from dspy.clients import DISK_CACHE_DIR, DISK_CACHE_LIMIT MEM_CACHE_LIMIT = float(os.environ.get("DSPY_CACHE_LIMIT", float("inf"))) # unlimited by default diff --git a/dspy/utils/cache.py b/dspy/utils/cache.py index bf8add6bf..56706889e 100644 --- a/dspy/utils/cache.py +++ b/dspy/utils/cache.py @@ -1,26 +1,30 @@ +import ujson import pickle +import pydantic import threading + from diskcache import FanoutCache from cachetools import LRUCache -import ujson -import pydantic from hashlib import sha256 from typing import Any, Dict from functools import wraps + class Cache: """ - dspy's caching interface. It provides 2 levels of caching (in the given order): - 1. An in memory cache - cachetools' lrucache + DSPy's caching interface. It provides 2 levels of caching (in the given order): + 1. An in memory cache - cachetools' lrucache 2. A disk based cache - diskcache's fanoutcache - - The cache is threadsafe - - Args: - mem_size_limit: The maximum size of the cache. If unspecified, no max size is enforced (cache is unbounded). - disk_size_limit: """ + def __init__(self, directory, disk_size_limit, mem_size_limit): + """ + Args: + directory: The directory where the disk cache is stored. 
+ disk_size_limit: The maximum size of the disk cache (in bytes). + mem_size_limit: The maximum size of the in-memory cache (in number of items). + """ + self.memory_cache = LRUCache(maxsize=mem_size_limit) self.fanout_cache = FanoutCache(shards=16, timeout=2, directory=directory, size_limit=disk_size_limit) self.lock = threading.RLock() @@ -31,26 +35,16 @@ def cache_key(request: Dict[str, Any]) -> str: Obtain a unique cache key for the given request dictionary by hashing its JSON representation. For request fields having types that are known to be JSON-incompatible, convert them to a JSON-serializable format before hashing. - - Note: Values that cannot be converted to JSON should *not* be ignored / discarded, since - that would potentially lead to cache collisions. For example, consider request A - containing only JSON-convertible values and request B containing the same JSON-convertible - values in addition to one unconvertible value. Discarding the unconvertible value would - lead to a cache collision between requests A and B, even though they are semantically - different. """ - def transform_value(value): if isinstance(value, type) and issubclass(value, pydantic.BaseModel): return value.model_json_schema() # BaseModel.schema deprecated elif isinstance(value, pydantic.BaseModel): return value.model_dump() # BaseModel.dict deprecated elif callable(value) and hasattr(value, "__code__") and hasattr(value.__code__, "co_code"): + # Represent callable code objects as string return value.__code__.co_code.decode("utf-8") else: - # Note: We don't attempt to compute a hash of the value, since the default - # implementation of hash() is id(), which may collide if the same memory address - # is reused for different objects at different times return value params = {k: transform_value(v) for k, v in request.items()} @@ -62,7 +56,7 @@ def get(self, request: Dict[str, Any]) -> Any: except Exception: return None - with self.lock: # TODO: Do we need this lock for reads? LRUCache is ambiguous! + with self.lock: # lock for thread safety (low overhead) if key in self.memory_cache: return self.memory_cache[key] @@ -82,13 +76,12 @@ def set(self, request: Dict[str, Any], value: Any) -> None: self.memory_cache[key] = value self.fanout_cache[key] = value - def load(self, file_path, maxsize:float=float("inf")) -> "LRUCache": + def load(self, file_path: str): with open(file_path, "rb") as f: cache_items = pickle.load(f) with self.lock: - # self.memory_cache.clear() - for k,v in cache_items: + for k, v in cache_items: self.memory_cache[k] = v def save(self, file_path: str) -> None: @@ -102,24 +95,31 @@ def reset_memory_cache(self) -> None: with self.lock: self.memory_cache.clear() -def cache_decorator(): - import dspy - cache = dspy.settings.cache - - # TODO: FIXME: The name of the decorated function should be part of the cache key +def cache_decorator(): def decorator(func): @wraps(func) def wrapper(request: dict, *args, **kwargs): - cached_result = cache.get(request) + import dspy + cache = dspy.settings.cache + + # Use fully qualified function name for uniqueness + func_identifier = f"{func.__module__}.{func.__qualname__}" + + # Create a modified request that includes the function identifier + # so that it's incorporated into the cache key. 
+ modified_request = dict(request) + modified_request["_func_identifier"] = func_identifier + + # Retrieve from cache if available + cached_result = cache.get(modified_request) if cached_result is not None: return cached_result - + + # Otherwise, compute and store the result result = func(request, *args, **kwargs) - cache.set(request, result) + cache.set(modified_request, result) return result - + return wrapper - return decorator - From 4c6b529fc54cf48c5837020f464db6bb9435d8ed Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Thu, 12 Dec 2024 18:38:53 -0800 Subject: [PATCH 8/9] Fixes --- dsp/utils/settings.py | 1 + dspy/adapters/base.py | 23 ++++++++++------- dspy/clients/lm.py | 58 +++++++++++++++---------------------------- dspy/utils/cache.py | 14 ++++++++--- 4 files changed, 45 insertions(+), 51 deletions(-) diff --git a/dsp/utils/settings.py b/dsp/utils/settings.py index 299211bd3..dacc65cfe 100644 --- a/dsp/utils/settings.py +++ b/dsp/utils/settings.py @@ -55,6 +55,7 @@ def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) cls._instance.lock = threading.Lock() # maintained here for DSPy assertions.py + return cls._instance def __getattr__(self, name): diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index fed1e861c..d472ddda3 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -22,16 +22,21 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): try: for output in outputs: - if type(output) is str: - output_text, output_logprobs = output, None - elif type(output) is dict: - output_text, output_logprobs = output["text"], output["logprobs"] - else: - raise ValueError(f"Expected str or dict but got {type(output)}") - value = self.parse(signature, output_text, _parse_values=_parse_values) - assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}" - value["logprobs"] = output_logprobs + output_logprobs = None + + if isinstance(output, dict): + output, output_logprobs = output["text"], output["logprobs"] + + value = self.parse(signature, output, _parse_values=_parse_values) + + assert set(value.keys()) == set(signature.output_fields.keys()), \ + f"Expected {signature.output_fields.keys()} but got {value.keys()}" + + if output_logprobs is not None: + value["logprobs"] = output_logprobs + values.append(value) + return values except Exception as e: diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 076c8bbaa..05a839d1f 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -86,15 +86,18 @@ def __call__(self, prompt=None, messages=None, **kwargs): kwargs = {**self.kwargs, **kwargs} # Make the request and handle LRU & disk caching. 
- if self.model_type == "chat": - completion = cached_litellm_completion if cache else litellm_completion - else: - completion = cached_litellm_text_completion if cache else litellm_text_completion + completion = litellm_completion if self.model_type == "chat" else litellm_text_completion + wrapped_completion = completion - response = completion( - request=dict(model=self.model, messages=messages, **kwargs), - num_retries=self.num_retries, - ) + if cache: + @cache_decorator(keep=litellm.Cache()._get_relevant_args_to_use_for_cache_key()) + def cached_completion(**kwargs): + return completion(**kwargs, cache={"no-cache": False, "no-store": False}) + + wrapped_completion = cached_completion + + response = wrapped_completion(model=self.model, messages=messages, num_retries=self.num_retries, **kwargs) + if kwargs.get("logprobs"): outputs = [ { @@ -106,7 +109,7 @@ def __call__(self, prompt=None, messages=None, **kwargs): else: outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]] - + # Logging, with removed api key & where `cost` is None on cache hit. kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")} entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response) @@ -225,43 +228,23 @@ def copy(self, **kwargs): return new_instance -@cache_decorator() -def cached_litellm_completion(request: Dict[str, Any], num_retries: int): - return litellm_completion( - request, - cache={"no-cache": False, "no-store": False}, - num_retries=num_retries, - ) - -def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}): - return litellm.completion( - num_retries=num_retries, - cache=cache, - **request, - ) - -@cache_decorator() -def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int): - return litellm_text_completion( - request, - num_retries=num_retries, - cache={"no-cache": False, "no-store": False}, - ) +def litellm_completion(cache={"no-cache": True, "no-store": True}, **kwargs): + return litellm.completion(cache=cache, **kwargs) -def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}): +def litellm_text_completion(cache={"no-cache": True, "no-store": True}, **kwargs): # Extract the provider and model from the model string. # TODO: Not all the models are in the format of "provider/model" - model = request.pop("model").split("/", 1) + model = kwargs.pop("model").split("/", 1) provider, model = model[0] if len(model) > 1 else "openai", model[-1] # Use the API key and base from the request, or from the environment. - api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") - api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") + api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") + api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") # Build the prompt from the messages. 
- prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"]) + prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"]) return litellm.text_completion( cache=cache, @@ -269,6 +252,5 @@ def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"n api_key=api_key, api_base=api_base, prompt=prompt, - num_retries=num_retries, - **request, + **kwargs, ) diff --git a/dspy/utils/cache.py b/dspy/utils/cache.py index 56706889e..17efa8b0b 100644 --- a/dspy/utils/cache.py +++ b/dspy/utils/cache.py @@ -96,10 +96,10 @@ def reset_memory_cache(self) -> None: self.memory_cache.clear() -def cache_decorator(): +def cache_decorator(ignore=None, keep=None): def decorator(func): @wraps(func) - def wrapper(request: dict, *args, **kwargs): + def wrapper(**kwargs): import dspy cache = dspy.settings.cache @@ -108,16 +108,22 @@ def wrapper(request: dict, *args, **kwargs): # Create a modified request that includes the function identifier # so that it's incorporated into the cache key. - modified_request = dict(request) + modified_request = dict(kwargs) modified_request["_func_identifier"] = func_identifier + for key in list(modified_request.keys()): + if ignore and key in ignore: + del modified_request[key] + if keep and key not in keep: + del modified_request[key] + # Retrieve from cache if available cached_result = cache.get(modified_request) if cached_result is not None: return cached_result # Otherwise, compute and store the result - result = func(request, *args, **kwargs) + result = func(**kwargs) cache.set(modified_request, result) return result From 9f3b0900b609dfa9469b2bbc9ef0df7dea2aac06 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Fri, 13 Dec 2024 11:23:07 -0800 Subject: [PATCH 9/9] WIP debugging the cache tests --- dspy/__init__.py | 17 +--- dspy/adapters/chat_adapter.py | 2 +- dspy/adapters/json_adapter.py | 2 +- dspy/clients/__init__.py | 17 +--- dspy/signatures/__init__.py | 2 +- dspy/signatures/signature.py | 9 ++ dspy/signatures/utils.py | 10 -- dspy/utils/cache.py | 33 ++++++- tests/caching/test_caching.py | 180 +++++++++++++++++----------------- 9 files changed, 140 insertions(+), 132 deletions(-) delete mode 100644 dspy/signatures/utils.py diff --git a/dspy/__init__.py b/dspy/__init__.py index a7e262d3a..c39585cd5 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -9,7 +9,7 @@ from .functional import * # isort: skip from dspy.evaluate import Evaluate # isort: skip from dspy.clients import * # isort: skip -from dspy.adapters import * # isort: skip +from dspy.adapters import Adapter, ChatAdapter, JSONAdapter # isort: skip from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging from dspy.utils.asyncify import asyncify from dspy.utils.saving import load @@ -24,17 +24,10 @@ configure = settings.configure context = settings.context -from dspy.utils.cache import Cache -from dspy.clients import DISK_CACHE_DIR, DISK_CACHE_LIMIT -MEM_CACHE_LIMIT = float(os.environ.get("DSPY_CACHE_LIMIT", float("inf"))) # unlimited by default - -# Initialize the cache -dspy_cache = Cache( - directory=os.path.join(DISK_CACHE_DIR, ".cache_v2_6"), - disk_size_limit=DISK_CACHE_LIMIT, - mem_size_limit=MEM_CACHE_LIMIT -) -settings.configure(cache=dspy_cache) +from dspy.utils.cache import DSPY_CACHE + +cache = DSPY_CACHE + import dspy.teleprompt diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index f92ee6595..b4535546c 100644 --- a/dspy/adapters/chat_adapter.py +++ 
b/dspy/adapters/chat_adapter.py @@ -16,7 +16,7 @@ from dspy.adapters.utils import find_enum_member, format_field_value from dspy.signatures.field import OutputField from dspy.signatures.signature import Signature, SignatureMeta -from dspy.signatures.utils import get_dspy_field_type +from dspy.signatures import get_dspy_field_type field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]") diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index 281df5cb4..dd9bf2629 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -18,7 +18,7 @@ from ..adapters.image_utils import Image from ..signatures.signature import SignatureMeta -from ..signatures.utils import get_dspy_field_type +from ..signatures import get_dspy_field_type _logger = logging.getLogger(__name__) diff --git a/dspy/clients/__init__.py b/dspy/clients/__init__.py index 22fb0acc7..a73864737 100644 --- a/dspy/clients/__init__.py +++ b/dspy/clients/__init__.py @@ -1,22 +1,11 @@ +import os +import litellm + from .lm import LM from .provider import Provider, TrainingJob from .base_lm import BaseLM, inspect_history from .embedding import Embedder -import litellm -import os -from pathlib import Path -from litellm.caching import Cache as litellm_cache - - -DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") -DISK_CACHE_LIMIT = int(os.environ.get("DSPY_CACHE_LIMIT", 3e10)) # 30 GB default - -# TODO: There's probably value in separating the limit for -# the LM cache from the embeddings cache. Then we can lower the default 30GB limit. -litellm.cache = litellm_cache(disk_cache_dir=DISK_CACHE_DIR, type="disk") -if litellm.cache.cache.disk_cache.size_limit != DISK_CACHE_LIMIT: - litellm.cache.cache.disk_cache.reset('size_limit', DISK_CACHE_LIMIT) litellm.telemetry = False diff --git a/dspy/signatures/__init__.py b/dspy/signatures/__init__.py index ba4637c83..60cc8cedd 100644 --- a/dspy/signatures/__init__.py +++ b/dspy/signatures/__init__.py @@ -1,2 +1,2 @@ from .field import * -from .signature import * +from .signature import * \ No newline at end of file diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index bd7b35a86..46e10b44f 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -9,6 +9,7 @@ from typing import Any, Dict, Tuple, Type, Union # noqa: UP035 import importlib +from typing import Literal from pydantic import BaseModel, Field, create_model from pydantic.fields import FieldInfo @@ -518,3 +519,11 @@ def infer_prefix(attribute_name: str) -> str: title_cased_words.append(word.capitalize()) return " ".join(title_cased_words) + + + +def get_dspy_field_type(field: FieldInfo) -> Literal["input", "output"]: + field_type = field.json_schema_extra.get("__dspy_field_type") + if field_type is None: + raise ValueError(f"Field {field} does not have a __dspy_field_type") + return field_type diff --git a/dspy/signatures/utils.py b/dspy/signatures/utils.py deleted file mode 100644 index 9f43e35da..000000000 --- a/dspy/signatures/utils.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Literal - -from pydantic.fields import FieldInfo - - -def get_dspy_field_type(field: FieldInfo) -> Literal["input", "output"]: - field_type = field.json_schema_extra.get("__dspy_field_type") - if field_type is None: - raise ValueError(f"Field {field} does not have a __dspy_field_type") - return field_type diff --git a/dspy/utils/cache.py b/dspy/utils/cache.py index 17efa8b0b..d58b141af 100644 --- a/dspy/utils/cache.py +++ 
b/dspy/utils/cache.py @@ -1,15 +1,30 @@ +import os import ujson import pickle +import litellm import pydantic import threading -from diskcache import FanoutCache -from cachetools import LRUCache +from pathlib import Path from hashlib import sha256 -from typing import Any, Dict from functools import wraps +from typing import Any, Dict +from cachetools import LRUCache +from diskcache import FanoutCache +from litellm.caching import Cache as litellm_cache +DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") +DISK_CACHE_LIMIT = int(os.environ.get("DSPY_CACHE_LIMIT", 3e10)) # 30 GB default +MEM_CACHE_LIMIT = float(os.environ.get("DSPY_CACHE_LIMIT", float("inf"))) # unlimited by default + +# TODO: There's probably value in separating the limit for +# the LM cache from the embeddings cache. Then we can lower the default 30GB limit. +litellm.cache = litellm_cache(disk_cache_dir=DISK_CACHE_DIR, type="disk") + +if litellm.cache.cache.disk_cache.size_limit != DISK_CACHE_LIMIT: + litellm.cache.cache.disk_cache.reset('size_limit', DISK_CACHE_LIMIT) + class Cache: """ DSPy's caching interface. It provides 2 levels of caching (in the given order): @@ -74,6 +89,8 @@ def set(self, request: Dict[str, Any], value: Any) -> None: with self.lock: self.memory_cache[key] = value + print(f"Setting cache key: {key}") + print(f"Setting cache value: {value}") self.fanout_cache[key] = value def load(self, file_path: str): @@ -101,7 +118,7 @@ def decorator(func): @wraps(func) def wrapper(**kwargs): import dspy - cache = dspy.settings.cache + cache = dspy.cache # Use fully qualified function name for uniqueness func_identifier = f"{func.__module__}.{func.__qualname__}" @@ -129,3 +146,11 @@ def wrapper(**kwargs): return wrapper return decorator + + +# Initialize the cache +DSPY_CACHE = Cache( + directory=os.path.join(DISK_CACHE_DIR, ".cache_v2_6"), + disk_size_limit=DISK_CACHE_LIMIT, + mem_size_limit=MEM_CACHE_LIMIT +) diff --git a/tests/caching/test_caching.py b/tests/caching/test_caching.py index f890dfad3..9a1a0a2ad 100644 --- a/tests/caching/test_caching.py +++ b/tests/caching/test_caching.py @@ -14,6 +14,7 @@ def temporary_blank_cache_dir(monkeypatch): with tempfile.TemporaryDirectory() as cache_dir_path: monkeypatch.setenv("DSPY_CACHEDIR", cache_dir_path) + importlib.reload(dspy.utils.cache) importlib.reload(dspy.clients) yield cache_dir_path @@ -30,47 +31,48 @@ def temporary_populated_cache_dir(monkeypatch): with tempfile.TemporaryDirectory() as cache_dir_path: shutil.copytree(populated_cache_path, cache_dir_path, dirs_exist_ok=True) monkeypatch.setenv("DSPY_CACHEDIR", cache_dir_path) + importlib.reload(dspy.utils.cache) importlib.reload(dspy.clients) yield cache_dir_path -def test_lm_calls_are_cached_across_lm_instances(litellm_test_server, temporary_blank_cache_dir): - api_base, server_log_file_path = litellm_test_server - - # Call 2 LM instances with the same model & text and verify that only one API request is sent - # to the LiteLLM server - lm1 = dspy.LM( - model="openai/dspy-test-model", - api_base=api_base, - api_key="fakekey", - ) - lm1("Example query") - lm2 = dspy.LM( - model="openai/dspy-test-model", - api_base=api_base, - api_key="fakekey", - ) - lm2("Example query") - request_logs = read_litellm_test_server_request_logs(server_log_file_path) - assert len(request_logs) == 1 - - # Call one of the LMs with new text and verify that a new API request is sent to the - # LiteLLM server - lm1("New query") - request_logs = 
read_litellm_test_server_request_logs(server_log_file_path) - assert len(request_logs) == 2 - - # Create a new LM instance with a different model and query it twice with the original text. - # Verify that one new API request is sent to the LiteLLM server - lm3 = dspy.LM( - model="openai/dspy-test-model-2", - api_base=api_base, - api_key="fakekey", - ) - lm3("Example query") - lm3("Example query") - request_logs = read_litellm_test_server_request_logs(server_log_file_path) - assert len(request_logs) == 3 +# def test_lm_calls_are_cached_across_lm_instances(litellm_test_server, temporary_blank_cache_dir): +# api_base, server_log_file_path = litellm_test_server + +# # Call 2 LM instances with the same model & text and verify that only one API request is sent +# # to the LiteLLM server +# lm1 = dspy.LM( +# model="openai/dspy-test-model", +# api_base=api_base, +# api_key="fakekey", +# ) +# lm1("Example query") +# lm2 = dspy.LM( +# model="openai/dspy-test-model", +# api_base=api_base, +# api_key="fakekey", +# ) +# lm2("Example query") +# request_logs = read_litellm_test_server_request_logs(server_log_file_path) +# assert len(request_logs) == 1 + +# # Call one of the LMs with new text and verify that a new API request is sent to the +# # LiteLLM server +# lm1("New query") +# request_logs = read_litellm_test_server_request_logs(server_log_file_path) +# assert len(request_logs) == 2 + +# # Create a new LM instance with a different model and query it twice with the original text. +# # Verify that one new API request is sent to the LiteLLM server +# lm3 = dspy.LM( +# model="openai/dspy-test-model-2", +# api_base=api_base, +# api_key="fakekey", +# ) +# lm3("Example query") +# lm3("Example query") +# request_logs = read_litellm_test_server_request_logs(server_log_file_path) +# assert len(request_logs) == 3 def test_lm_calls_are_cached_across_interpreter_sessions(litellm_test_server, temporary_populated_cache_dir): @@ -91,24 +93,24 @@ def test_lm_calls_are_cached_across_interpreter_sessions(litellm_test_server, te assert len(request_logs) == 0 -def test_lm_calls_are_cached_in_memory_when_expected(litellm_test_server, temporary_blank_cache_dir): - api_base, server_log_file_path = litellm_test_server +# def test_lm_calls_are_cached_in_memory_when_expected(litellm_test_server, temporary_blank_cache_dir): +# api_base, server_log_file_path = litellm_test_server - lm1 = dspy.LM( - model="openai/dspy-test-model", - api_base=api_base, - api_key="fakekey", - ) - lm1("Example query") - # Remove the disk cache, after which the LM must rely on in-memory caching - shutil.rmtree(temporary_blank_cache_dir) - lm1("Example query2") - lm1("Example query2") - lm1("Example query2") - lm1("Example query2") +# lm1 = dspy.LM( +# model="openai/dspy-test-model", +# api_base=api_base, +# api_key="fakekey", +# ) +# lm1("Example query") +# # Remove the disk cache, after which the LM must rely on in-memory caching +# shutil.rmtree(temporary_blank_cache_dir) +# lm1("Example query2") +# lm1("Example query2") +# lm1("Example query2") +# lm1("Example query2") - request_logs = read_litellm_test_server_request_logs(server_log_file_path) - assert len(request_logs) == 2 +# request_logs = read_litellm_test_server_request_logs(server_log_file_path) +# assert len(request_logs) == 2 def test_lm_calls_skip_in_memory_cache_if_key_not_computable(): @@ -127,39 +129,39 @@ class NonJsonSerializable: assert mock_litellm_completion.call_count == 2 -def test_lm_calls_with_callables_are_cached_as_expected(): - with patch("litellm.completion") as 
mock_completion: - lm_with_callable = dspy.LM( - model="openai/dspy-test-model", - api_base="fakebase", - api_key="fakekey", - # Define a callable kwarg for the LM to use during inference - azure_ad_token_provider=lambda *args, **kwargs: None, - ) - # Invoke the LM twice; the second call should be cached in memory - lm_with_callable("Query") - lm_with_callable("Query") - - # Define and invoke a nearly-identical LM that lacks the callable kwarg, - # which should not hit the in-memory cache - lm_without_callable = dspy.LM( - model="openai/dspy-test-model", - api_base="fakebase", - api_key="fakekey", - ) - lm_without_callable("Query") - - assert mock_completion.call_count == 2 - - -def test_lms_called_expected_number_of_times_for_cache_key_generation_failures(): - with pytest.raises(Exception), patch("litellm.completion") as mock_completion: - mock_completion.side_effect = Exception("Mocked exception") - lm = dspy.LM( - model="openai/dspy-test-model", - api_base="fakebase", - api_key="fakekey", - ) - lm("Do not retry") - - assert mock_completion.call_count == 1 +# def test_lm_calls_with_callables_are_cached_as_expected(): +# with patch("litellm.completion") as mock_completion: +# lm_with_callable = dspy.LM( +# model="openai/dspy-test-model", +# api_base="fakebase", +# api_key="fakekey", +# # Define a callable kwarg for the LM to use during inference +# azure_ad_token_provider=lambda *args, **kwargs: None, +# ) +# # Invoke the LM twice; the second call should be cached in memory +# lm_with_callable("Query") +# lm_with_callable("Query") + +# # Define and invoke a nearly-identical LM that lacks the callable kwarg, +# # which should not hit the in-memory cache +# lm_without_callable = dspy.LM( +# model="openai/dspy-test-model", +# api_base="fakebase", +# api_key="fakekey", +# ) +# lm_without_callable("Query") + +# assert mock_completion.call_count == 2 + + +# def test_lms_called_expected_number_of_times_for_cache_key_generation_failures(): +# with pytest.raises(Exception), patch("litellm.completion") as mock_completion: +# mock_completion.side_effect = Exception("Mocked exception") +# lm = dspy.LM( +# model="openai/dspy-test-model", +# api_base="fakebase", +# api_key="fakekey", +# ) +# lm("Do not retry") + +# assert mock_completion.call_count == 1
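Taken together, the series ends (as of PATCH 9/9) with a process-wide dspy.cache object and a cache_decorator that keys on the decorated function's identifier plus its keyword arguments. Below is a minimal usage sketch against that final state; the fake completion function, its kwargs, and the snapshot path are illustrative, no real LM call is made, and the one cached entry lands in the default DSPY_CACHEDIR location.

import dspy
from dspy.utils.cache import cache_decorator

calls = {"n": 0}

@cache_decorator(ignore=["api_key"])  # keep the (illustrative) api_key out of the cache key
def fake_completion(**kwargs):
    calls["n"] += 1
    return {"text": f"echo: {kwargs['prompt']}"}

r1 = fake_completion(prompt="hello", api_key="secret-1")
r2 = fake_completion(prompt="hello", api_key="secret-2")  # differs only in the ignored field
assert r1 == r2 and calls["n"] == 1  # the second call is answered from dspy.cache

dspy.cache.save("/tmp/dspy_mem_cache.pkl")  # snapshot the in-memory LRU entries
dspy.cache.reset_memory_cache()             # clear memory; the disk cache is untouched
dspy.cache.load("/tmp/dspy_mem_cache.pkl")  # restore the snapshot into memory

Note that the wrapper accepts keyword arguments only, which matches how LM.__call__ passes model, messages, and num_retries as kwargs to the wrapped completion functions in PATCH 8/9.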