MIPROv2 KNN #1888

Open · wants to merge 39 commits into base: bootstrap-knn-few-shot-with-random-search
Commits (39)
627afb9
Parallel Bootstrap
CyrusNuevoDia Nov 25, 2024
254c8e5
Merge branch 'bootstrap-knn-few-shot-with-random-search' into paralle…
CyrusNuevoDia Nov 25, 2024
7816ba3
dspy.streamify
CyrusNuevoDia Nov 29, 2024
d8bc33c
Update docs
CyrusNuevoDia Nov 29, 2024
48a2376
Merge branch 'main' into streaming
CyrusNuevoDia Nov 29, 2024
ab8b28e
Fix ruff lint error
CyrusNuevoDia Nov 29, 2024
7c251a9
Bring back send_stream to settings
CyrusNuevoDia Nov 29, 2024
00c6e76
Improve doc
CyrusNuevoDia Nov 29, 2024
c02a7fc
Bring back request_cache setting
CyrusNuevoDia Nov 29, 2024
707e5a3
sse => streaming_response
CyrusNuevoDia Nov 29, 2024
e852b4c
Simplify dsp.utils.settings diff
CyrusNuevoDia Nov 29, 2024
1e8dcf8
WIP
CyrusNuevoDia Nov 30, 2024
8cc7a15
Merge branch 'bootstrap-knn-few-shot-with-random-search' into miprov2…
CyrusNuevoDia Dec 1, 2024
bf58b36
Fixed MIPROv2KNN bug
CyrusNuevoDia Dec 2, 2024
d4064e1
Add load/dump to LRUCache + drop callable request params
CyrusNuevoDia Dec 2, 2024
60a95bc
Merge branch 'streaming' into miprov2-knn
CyrusNuevoDia Dec 2, 2024
e87afcd
Remove testing results.csv
CyrusNuevoDia Dec 3, 2024
3111b12
ujson => pickle for dump/load
CyrusNuevoDia Dec 3, 2024
dc2b06e
EOD
CyrusNuevoDia Dec 3, 2024
26814a9
Parallel proposer
CyrusNuevoDia Dec 3, 2024
a698611
Remove double logger init
CyrusNuevoDia Dec 3, 2024
64a926e
Refactor + caching dspy.Embedder
CyrusNuevoDia Dec 4, 2024
2386fc3
Add lazy vectorization to dspy.KNN
CyrusNuevoDia Dec 4, 2024
ce59dc0
Parallel GroundedProposer
CyrusNuevoDia Dec 4, 2024
9082927
Update BootstrapKNNRS to use lazy KNN
CyrusNuevoDia Dec 4, 2024
acaba56
Improve progress logging for Bootstrap*
CyrusNuevoDia Dec 4, 2024
a50a9f5
Simplify OptimizerTester
CyrusNuevoDia Dec 4, 2024
1da5c4f
Formatting and pass through num_threads to GroundedProposer
CyrusNuevoDia Dec 4, 2024
fe4c5c9
MIPROv2KNN works!
CyrusNuevoDia Dec 4, 2024
c5b9c50
Merge branch 'bootstrap-knn-few-shot-with-random-search' into miprov2…
CyrusNuevoDia Dec 4, 2024
b049a4c
Fix lazy loading for KNN
CyrusNuevoDia Dec 5, 2024
b0a64f9
Make num_candidate_programs actually = num_candidate_programs
CyrusNuevoDia Dec 5, 2024
5074b13
Clarify embedding log
CyrusNuevoDia Dec 5, 2024
96844ef
Remove unnecessary import
CyrusNuevoDia Dec 5, 2024
2e651a0
Improve printing
CyrusNuevoDia Dec 5, 2024
5156f72
Add 2-module HoVeR task
CyrusNuevoDia Dec 6, 2024
9a42713
Answer exact match and semantic f1
CyrusNuevoDia Dec 18, 2024
a50e1ef
Merge branch 'bootstrap-knn-few-shot-with-random-search' into miprov2…
CyrusNuevoDia Dec 18, 2024
0337e74
Typed DemoCandidate
CyrusNuevoDia Dec 18, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -16,6 +16,7 @@
 /ScoNe/
 testing/outputs/
 testing/playbook.ipynb
+testing/outputs/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
142 changes: 142 additions & 0 deletions dsp/utils/settings.py
@@ -0,0 +1,142 @@
import copy
import threading
from contextlib import contextmanager
from dsp.utils.utils import dotdict

DEFAULT_CONFIG = dotdict(
    lm=None,
    adapter=None,
    rm=None,
    branch_idx=0,
    reranker=None,
    compiled_lm=None,
    force_reuse_cached_compilation=False,
    compiling=False,
    skip_logprobs=False,
    trace=[],
    release=0,
    bypass_assert=False,
    bypass_suggest=False,
    assert_failures=0,
    suggest_failures=0,
    langchain_history=[],
    experimental=False,
    backoff_time=10,
    callbacks=[],
    async_max_workers=8,
    request_cache=None,
    send_stream=None,
)

# Global base configuration
main_thread_config = copy.deepcopy(DEFAULT_CONFIG)


class ThreadLocalOverrides(threading.local):
    def __init__(self):
        self.overrides = dotdict()  # Initialize thread-local overrides


# Create the thread-local storage
thread_local_overrides = ThreadLocalOverrides()


class Settings:
    """
    A singleton class for DSPy configuration settings.

    This is thread-safe. User threads are supported both through ParallelExecutor and native threading.
    - If native threading is used, the thread inherits the initial config from the main thread.
    - If ParallelExecutor is used, the thread inherits the initial config from its parent thread.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.lock = threading.Lock()  # maintained here for DSPy assertions.py
        return cls._instance

    def __getattr__(self, name):
        overrides = getattr(thread_local_overrides, 'overrides', dotdict())
        if name in overrides:
            return overrides[name]
        elif name in main_thread_config:
            return main_thread_config[name]
        else:
            raise AttributeError(f"'Settings' object has no attribute '{name}'")

    def __setattr__(self, name, value):
        if name in ('_instance',):
            super().__setattr__(name, value)
        else:
            self.configure(**{name: value})

    # Dictionary-like access

    def __getitem__(self, key):
        return self.__getattr__(key)

    def __setitem__(self, key, value):
        self.__setattr__(key, value)

    def __contains__(self, key):
        overrides = getattr(thread_local_overrides, 'overrides', dotdict())
        return key in overrides or key in main_thread_config

    def get(self, key, default=None):
        try:
            return self[key]
        except AttributeError:
            return default

    def copy(self):
        overrides = getattr(thread_local_overrides, 'overrides', dotdict())
        return dotdict({**main_thread_config, **overrides})

    @property
    def config(self):
        config = self.copy()
        if 'lock' in config:
            del config['lock']
        return config

    # Configuration methods

    def configure(self, **kwargs):
        global main_thread_config

        # Get or initialize thread-local overrides
        overrides = getattr(thread_local_overrides, 'overrides', dotdict())
        thread_local_overrides.overrides = dotdict(
            {**copy.deepcopy(DEFAULT_CONFIG), **main_thread_config, **overrides, **kwargs}
        )

        # Update main_thread_config, in the main thread only
        if threading.current_thread() is threading.main_thread():
            main_thread_config = thread_local_overrides.overrides

    @contextmanager
    def context(self, **kwargs):
        """Context manager for temporary configuration changes."""
        global main_thread_config
        original_overrides = getattr(thread_local_overrides, 'overrides', dotdict()).copy()
        original_main_thread_config = main_thread_config.copy()

        self.configure(**kwargs)
        try:
            yield
        finally:
            thread_local_overrides.overrides = original_overrides

            if threading.current_thread() is threading.main_thread():
                main_thread_config = original_main_thread_config

    def __repr__(self):
        overrides = getattr(thread_local_overrides, 'overrides', dotdict())
        combined_config = {**main_thread_config, **overrides}
        return repr(combined_config)


settings = Settings()
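
A minimal sketch of the thread-inheritance rules described in the `Settings` docstring, using the `configure`/`context` aliases exported from `dspy/__init__.py` below (the LM name is a placeholder, not part of this diff):

```python
import threading

import dspy

lm = dspy.LM("openai/gpt-4o-mini")  # placeholder model name
dspy.configure(lm=lm)  # main thread: updates the global main_thread_config


def worker():
    # A native thread starts with empty overrides, so lookups fall back to
    # the main thread's config.
    assert dspy.settings.lm is lm

    # Overrides made inside `context` are thread-local and restored on exit,
    # without touching main_thread_config from a non-main thread.
    with dspy.settings.context(experimental=True):
        assert dspy.settings.experimental is True
    assert dspy.settings.experimental is False


t = threading.Thread(target=worker)
t.start()
t.join()
```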
3 changes: 2 additions & 1 deletion dspy/__init__.py
@@ -5,7 +5,6 @@
 from dspy.teleprompt import *
 
 import dspy.retrievers
-import dspy.teleprompt
 
 from dspy.evaluate import Evaluate  # isort: skip
 from dspy.clients import *  # isort: skip
@@ -25,6 +24,7 @@
 configure = settings.configure
 context = settings.context
 
+import dspy.teleprompt
 
 LabeledFewShot = dspy.teleprompt.LabeledFewShot
 BootstrapFewShot = dspy.teleprompt.BootstrapFewShot
@@ -36,4 +36,5 @@
 BetterTogether = dspy.teleprompt.BetterTogether
 COPRO = dspy.teleprompt.COPRO
 MIPROv2 = dspy.teleprompt.MIPROv2
+MIPROv2KNN = dspy.teleprompt.MIPROv2KNN
 Ensemble = dspy.teleprompt.Ensemble
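
With the new `MIPROv2KNN` export, a hedged usage sketch; its constructor and `compile` signature are not shown in this diff, so the MIPROv2-style arguments below are assumptions to verify against `dspy/teleprompt`:

```python
import dspy


def exact_match(example, pred, trace=None):
    # Toy metric; assumes examples and predictions expose an `answer` field.
    return example.answer == pred.answer


# Assumed MIPROv2-like interface.
optimizer = dspy.MIPROv2KNN(metric=exact_match)
# compiled = optimizer.compile(student_program, trainset=trainset)
```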
14 changes: 8 additions & 6 deletions dspy/clients/__init__.py
@@ -1,11 +1,13 @@
-from dspy.clients.lm import LM
-from dspy.clients.provider import Provider, TrainingJob
-from dspy.clients.base_lm import BaseLM, inspect_history
-from dspy.clients.embedding import Embedder
-import litellm
-import os
 from pathlib import Path
+import os
 
 from litellm.caching import Cache
+import litellm
+
+from dspy.clients.base_lm import BaseLM, inspect_history
+from dspy.clients.embedding import Embedder
+from dspy.clients.lm import LM
+from dspy.clients.provider import Provider, TrainingJob
+
 DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
 DISK_CACHE_LIMIT = int(os.environ.get("DSPY_CACHE_LIMIT", 3e10))  # 30 GB default
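
The cache location and size cap come from the two environment variables read above; a small sketch with illustrative values (they must be set before `dspy` is imported, since the module reads them at import time):

```python
import os

os.environ["DSPY_CACHEDIR"] = "/tmp/dspy_cache"       # default: ~/.dspy_cache
os.environ["DSPY_CACHE_LIMIT"] = str(5_000_000_000)   # 5 GB instead of the 30 GB default

import dspy  # noqa: E402 -- must come after the overrides above
```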
85 changes: 43 additions & 42 deletions dspy/clients/embedding.py
@@ -1,6 +1,9 @@
+from typing import Callable, List, Optional, Union
 import litellm
 import numpy as np
 
+from .lm import request_cache
+
 
 class Embedder:
     """DSPy embedding class.
@@ -56,13 +59,28 @@ def my_embedder(texts):
     ```
     """
 
-    def __init__(self, model, batch_size=200, caching=True, **kwargs):
+    def __init__(self, model: Union[str, Callable], batch_size=200, **kwargs):
+        if not isinstance(model, str) and not callable(model):
+            raise ValueError(f"`model` in `dspy.Embedder` must be a string or a callable, but got {type(model)}.")
+
         self.model = model
         self.batch_size = batch_size
-        self.caching = caching
         self.default_kwargs = kwargs
 
-    def __call__(self, inputs, batch_size=None, caching=None, **kwargs):
+    def _embed(self, inputs: List[str], cache: bool, **kwargs):
+        if callable(self.model):
+            return self.model(inputs, **kwargs)
+
+        response = litellm_embedding({"model": self.model, "input": inputs, **kwargs}).data
+        return [data["embedding"] for data in response]
+
+    def __call__(
+        self,
+        inputs: Union[str, List[str]],
+        batch_size: Optional[int] = None,
+        cache: Optional[bool] = None,
+        **kwargs,
+    ) -> np.ndarray:
         """Compute embeddings for the given inputs.
 
         Args:
@@ -76,46 +94,29 @@ def __call__(self, inputs, batch_size=None, caching=None, **kwargs):
         If the input is a list of strings, returns a 2D numpy array of embeddings, one embedding per row.
         """
 
-        if isinstance(inputs, str):
-            is_single_input = True
+        multi_input = isinstance(inputs, list)
+        if not multi_input:
             inputs = [inputs]
-        else:
-            is_single_input = False
 
         assert all(isinstance(inp, str) for inp in inputs), "All inputs must be strings."
 
-        if batch_size is None:
-            batch_size = self.batch_size
-        if caching is None:
-            caching = self.caching
-
-        merged_kwargs = self.default_kwargs.copy()
-        merged_kwargs.update(kwargs)
-
-        embeddings_list = []
-
-        def chunk(inputs_list, size):
-            for i in range(0, len(inputs_list), size):
-                yield inputs_list[i : i + size]
-
-        for batch_inputs in chunk(inputs, batch_size):
-            if isinstance(self.model, str):
-                embedding_response = litellm.embedding(
-                    model=self.model, input=batch_inputs, caching=caching, **merged_kwargs
-                )
-                batch_embeddings = [data["embedding"] for data in embedding_response.data]
-            elif callable(self.model):
-                batch_embeddings = self.model(batch_inputs, **merged_kwargs)
-            else:
-                raise ValueError(
-                    f"`model` in `dspy.Embedder` must be a string or a callable, but got {type(self.model)}."
-                )
-
-            embeddings_list.extend(batch_embeddings)
-
-        embeddings = np.array(embeddings_list, dtype=np.float32)
-
-        if is_single_input:
-            return embeddings[0]
-        else:
-            return embeddings
+        batch_size = batch_size or self.batch_size
+        kwargs = {**self.default_kwargs, **kwargs}
+
+        embeddings = flatten([self._embed(c, cache, **kwargs) for c in chunk(inputs, batch_size)])
+        embeddings = embeddings if multi_input else embeddings[0]
+        return np.array(embeddings, dtype=np.float32)
+
+
+def chunk(inputs_list, size):
+    for i in range(0, len(inputs_list), size):
+        yield inputs_list[i : i + size]
+
+
+def flatten(list_of_lists):
+    return [item for sublist in list_of_lists for item in sublist]
+
+
+@request_cache(maxsize=None)
+def litellm_embedding(request):
+    return litellm.embedding(**request, cache={"no-cache": False, "no-store": False})
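
A usage sketch for the refactored `Embedder`; the litellm model name is illustrative, and the callable variant is handy for tests since `_embed` bypasses litellm (and its request cache) on that path:

```python
import numpy as np

import dspy

# String model: inputs are chunked by batch_size and embedded via litellm.
embedder = dspy.Embedder("openai/text-embedding-3-small", batch_size=100)
single = embedder("hello world")       # 1D float32 array
batch = embedder(["hello", "world"])   # 2D float32 array, one row per input

# Callable model: skips litellm entirely.
fake = dspy.Embedder(lambda texts, **kwargs: np.zeros((len(texts), 8)))
assert fake(["a", "b"]).shape == (2, 8)
```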
23 changes: 16 additions & 7 deletions dspy/evaluate/auto_evaluation.py
@@ -45,13 +45,15 @@ def __init__(self, threshold=0.66, decompositional=False):
         self.module = dspy.ChainOfThought(SemanticRecallPrecision)
 
     def forward(self, example, pred, trace=None):
-        scores = self.module(question=example.question, ground_truth=example.response, system_response=pred.response)
+        ground_truth = example.response if hasattr(example, "response") else getattr(example, "answer", None)
+        system_response = pred.response if hasattr(pred, "response") else getattr(pred, "answer", None)
+
+        scores = self.module(question=example.question, ground_truth=ground_truth, system_response=system_response)
         score = f1_score(scores.precision, scores.recall)
 
         return score if trace is None else score >= self.threshold
 
-
 
 ###########
 
 
@@ -70,7 +72,6 @@ class AnswerCompleteness(dspy.Signature):
     completeness: float = dspy.OutputField(desc="fraction (out of 1.0) of ground truth covered by the system response")
 
-
 
 class AnswerGroundedness(dspy.Signature):
     """
     Estimate the groundedness of a system's responses, against real retrieved documents written by people.
@@ -81,9 +82,13 @@ class AnswerGroundedness(dspy.Signature):
     question: str = dspy.InputField()
     retrieved_context: str = dspy.InputField()
     system_response: str = dspy.InputField()
-    system_response_claims: str = dspy.OutputField(desc="enumeration of non-trivial or check-worthy claims in the system response")
+    system_response_claims: str = dspy.OutputField(
+        desc="enumeration of non-trivial or check-worthy claims in the system response"
+    )
     discussion: str = dspy.OutputField(desc="discussion of how supported the claims are by the retrieved context")
-    groundedness: float = dspy.OutputField(desc="fraction (out of 1.0) of system response supported by the retrieved context")
+    groundedness: float = dspy.OutputField(
+        desc="fraction (out of 1.0) of system response supported by the retrieved context"
+    )
 
 
 class CompleteAndGrounded(dspy.Module):
@@ -93,8 +98,12 @@ def __init__(self, threshold=0.66):
         self.groundedness_module = dspy.ChainOfThought(AnswerGroundedness)
 
     def forward(self, example, pred, trace=None):
-        completeness = self.completeness_module(question=example.question, ground_truth=example.response, system_response=pred.response)
-        groundedness = self.groundedness_module(question=example.question, retrieved_context=pred.context, system_response=pred.response)
+        completeness = self.completeness_module(
+            question=example.question, ground_truth=example.response, system_response=pred.response
+        )
+        groundedness = self.groundedness_module(
+            question=example.question, retrieved_context=pred.context, system_response=pred.response
+        )
         score = f1_score(groundedness.groundedness, completeness.completeness)
 
         return score if trace is None else score >= self.threshold
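
Both metrics above reduce their two sub-scores with `f1_score`; per its use elsewhere in `dspy.evaluate`, that is the harmonic mean with a division-by-zero guard, sketched here under that assumption:

```python
def f1_score(precision: float, recall: float) -> float:
    # Harmonic mean; 0.0 when both inputs are 0 to avoid division by zero.
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# CompleteAndGrounded-style combination: the score is high only when the
# response is both well grounded and complete.
assert round(f1_score(0.8, 0.5), 3) == 0.615
```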
3 changes: 0 additions & 3 deletions dspy/evaluate/evaluate.py
@@ -38,9 +38,6 @@ def HTML(x: str) -> str:
 logger = logging.getLogger(__name__)
 
-
-logger = logging.getLogger(__name__)
-
 
 class Evaluate:
     def __init__(
         self,