Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel BootstrapFewShot, BootstrapKNN (and +WithRandomSearch) #1858

Open
wants to merge 2 commits into
base: bootstrap-knn-few-shot-with-random-search
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 95 additions & 74 deletions dspy/teleprompt/bootstrap.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, Future, as_completed
import logging
import random
import threading
Expand Down Expand Up @@ -44,6 +46,7 @@ def __init__(
max_labeled_demos=16,
max_rounds=1,
max_errors=5,
num_threads=6,
):
"""
A Teleprompter class that composes a set of demos/examples to go into a predictor's prompt.
Expand Down Expand Up @@ -75,14 +78,16 @@ def __init__(
self.max_labeled_demos = max_labeled_demos
self.max_rounds = max_rounds
self.max_errors = max_errors

self.num_threads = num_threads

self.error_count = 0
self.error_lock = threading.Lock()
self._lock = threading.Lock()

def compile(self, student, *, teacher=None, trainset):
self.trainset = trainset

self._prepare_student_and_teacher(student, teacher)
self._prepare_predictor_mappings()
self._bootstrap()

self.student = self._train()
Expand All @@ -106,10 +111,9 @@ def _prepare_student_and_teacher(self, student, teacher):
teleprompter = LabeledFewShot(k=self.max_labeled_demos)
self.teacher = teleprompter.compile(self.teacher.reset_copy(), trainset=self.trainset)

def _prepare_predictor_mappings(self):
name2predictor, predictor2name = {}, {}
student, teacher = self.student, self.teacher
self._assert_student_teacher_compatibility(self.student, self.teacher)

def _assert_student_teacher_compatibility(self, student, teacher):
assert len(student.predictors()) == len(
teacher.predictors(),
), "Student and teacher must have the same number of predictors."
Expand All @@ -131,58 +135,79 @@ def _prepare_predictor_mappings(self):
)
assert id(predictor1) != id(predictor2), "Student and teacher must be different objects."

name2predictor[name1] = None # dict(student=predictor1, teacher=predictor2)
predictor2name[id(predictor1)] = name1

# FIXME(shangyint): This is an ugly hack to bind traces of
# retry.module to retry
# if isinstance(predictor1, Retry):
# predictor2name[id(predictor1.module)] = name1
def _prepare_predictor_mappings(self, student, teacher):
    """Build a lookup from predictor identity to predictor name.

    Args:
        student: Program exposing ``named_predictors()``, an iterable of
            ``(name, predictor)`` pairs.
        teacher: Program exposing the same ``named_predictors()`` interface.

    Returns:
        dict: ``id(predictor) -> name`` for every predictor of both programs.
        Keys are ``id()`` values, so the mapping is only meaningful while the
        given ``student`` and ``teacher`` objects are alive.
    """
    # Deliberately pure: nothing is written to ``self``. This helper is called
    # from ``_bootstrap_one_example`` for a per-call teacher copy, so storing
    # the mapping on the shared instance would let concurrent calls overwrite
    # each other's entries.
    predictor2name = {}

    # Register each program on its own rather than zipping them together, so
    # that no predictor is silently dropped if the two programs ever differ
    # in length.
    for program in (student, teacher):
        for name, predictor in program.named_predictors():
            predictor2name[id(predictor)] = name

    return predictor2name

def _bootstrap(self, *, max_bootstraps=None):
    """Collect bootstrapped traces by running training examples in parallel.

    Each round submits every not-yet-bootstrapped training example to a
    thread pool running ``self._bootstrap_one_example(example, round_idx)``,
    which returns ``(success, name2traces)``. Successful traces are merged
    into ``self.name2traces``. Bootstrapping stops once ``max_bootstraps``
    examples have succeeded, all examples have succeeded, or
    ``self.max_rounds`` rounds have run.

    Args:
        max_bootstraps: Number of successfully bootstrapped examples to stop
            at. Falls back to ``self.max_bootstrapped_demos`` when falsy.

    Side effects:
        Sets ``self.name2traces`` (predictor name -> list of demos) and
        ``self.validation`` (the training examples that were not
        bootstrapped, shuffled with a fixed seed).

    Raises:
        Whatever ``_bootstrap_one_example`` raises in a worker; the exception
        is re-raised here by ``Future.result()``.
    """
    max_bootstraps = max_bootstraps or self.max_bootstrapped_demos
    bootstrapped = set()  # indices of examples that produced a full trace
    attempted = set()  # indices of examples with at least one finished attempt
    attempts = 0  # finished attempts across all rounds
    self.name2traces = defaultdict(list)

    for round_idx in range(self.max_rounds):
        if len(bootstrapped) >= max_bootstraps:
            break

        # Only retry examples that have not succeeded yet; resubmitting a
        # successful example would append duplicate traces for it.
        pending = [(idx, example) for idx, example in enumerate(self.trainset) if idx not in bootstrapped]
        if not pending:
            break

        progress_bar = tqdm.tqdm(total=len(pending))
        try:
            with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
                futures = {
                    executor.submit(self._bootstrap_one_example, example, round_idx): idx
                    for idx, example in pending
                }
                try:
                    for future in as_completed(futures):
                        example_idx = futures[future]
                        attempts += 1
                        attempted.add(example_idx)
                        progress_bar.update(1)

                        success, name2traces = future.result()
                        if not success:
                            continue

                        bootstrapped.add(example_idx)
                        for name, traces in name2traces.items():
                            self.name2traces[name].extend(traces)

                        if len(bootstrapped) >= max_bootstraps:
                            break
                finally:
                    # Drop work that has not started, both on early stop and
                    # when a worker raised, so leaving the executor only
                    # waits for calls that are already running. cancel() is
                    # a no-op for futures that are running or finished.
                    for future in futures:
                        future.cancel()
        finally:
            progress_bar.close()

    print(
        f"Bootstrapped {len(bootstrapped)} full traces after {len(attempted)} examples "
        f"for up to {self.max_rounds} rounds, amounting to {attempts} attempts."
    )

    # Unbootstrapped training examples
    self.validation = [x for idx, x in enumerate(self.trainset) if idx not in bootstrapped]
    random.Random(0).shuffle(self.validation)

    # NOTE: Can't yet use evaluate because we need to trace *per example*
    # evaluate = Evaluate(program=self.teacher, metric=self.metric, num_threads=12)
    # score = evaluate(self.metric, display_table=False, display_progress=True)

def _bootstrap_one_example(self, example, round_idx=0):
name2traces = {} # self.name2traces
teacher = self.teacher # .deepcopy()
teacher = self.teacher.deepcopy()
predictor2name = self._prepare_predictor_mappings(self.student, teacher)

predictor_cache = {}

trace = []

try:
with dspy.settings.context(trace=[], **self.teacher_settings):
lm = dspy.settings.lm
Expand All @@ -195,7 +220,7 @@ def _bootstrap_one_example(self, example, round_idx=0):
predictor.demos = [x for x in predictor.demos if x != example]

prediction = teacher(**example.inputs())
trace = dspy.settings.trace
trace = dspy.settings.trace[:]

for name, predictor in teacher.named_predictors():
predictor.demos = predictor_cache[name]
Expand All @@ -210,48 +235,40 @@ def _bootstrap_one_example(self, example, round_idx=0):
success = True
except Exception as e:
success = False
with self.error_lock:
with self._lock:
self.error_count += 1
current_error_count = self.error_count
if current_error_count >= self.max_errors:
raise e
logger.error(f"Failed to run or to evaluate example {example} with {self.metric} due to {e}.")

if success:
for step in trace:
predictor, inputs, outputs = step
demo = dspy.Example(augmented=True, **inputs, **outputs).with_inputs(*list(inputs.keys()))

try:
predictor_name = self.predictor2name[id(predictor)]
except KeyError:
continue # FIXME: !

# # TODO: Look closer into this. It's a bit tricky to reproduce.
# print(f"Failed to find predictor {predictor} in {self.predictor2name}.")
# print(
# "Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells.",
# )
# print("Try restarting the notebook, or open an issue.")
# raise KeyError(
# f"Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}.",
# ) from e

name2traces[predictor_name] = name2traces.get(predictor_name, [])
name2traces[predictor_name].append(demo)

# Update the traces
for name, demos in name2traces.items():
from datasets.fingerprint import Hasher

# If there are multiple traces for the same predictor in the sample example,
# sample 50/50 from the first N-1 traces or the last trace.
if len(demos) > 1:
rng = random.Random(Hasher.hash(tuple(demos)))
demos = [rng.choice(demos[:-1]) if rng.random() < 0.5 else demos[-1]]
self.name2traces[name].extend(demos)

return success
if not success:
return False, {}

assert trace, "No trace found."

name2traces = defaultdict(list)
for step in trace:
predictor, inputs, outputs = step
demo = dspy.Example(augmented=True, **inputs, **outputs).with_inputs(*list(inputs.keys()))

predictor_name = predictor2name[id(predictor)]
name2traces[predictor_name].append(demo)

# Update the traces
final_demos = {}
for name, demos in name2traces.items():
from datasets.fingerprint import Hasher

# If there are multiple traces for the same predictor in the sample example,
# sample 50/50 from the first N-1 traces or the last trace.
if len(demos) > 1:
rng = random.Random(Hasher.hash(tuple(demos)))
demos = [rng.choice(demos[:-1]) if rng.random() < 0.5 else demos[-1]]

final_demos[name] = demos

return True, final_demos

def _train(self):
rng = random.Random(0)
Expand Down Expand Up @@ -282,6 +299,7 @@ def __init__(
teacher_settings: Optional[Dict] = None,
max_bootstrapped_demos=64,
num_static_demos=0,
num_threads=6,
max_labeled_demos=16,
max_rounds=1,
max_errors=10,
Expand All @@ -297,24 +315,27 @@ def __init__(
max_labeled_demos=max_labeled_demos,
max_rounds=max_rounds,
max_errors=max_errors,
num_threads=num_threads,
)
self.num_static_demos = num_static_demos
self.embedder = embedder
self.num_static_demos = num_static_demos
self.random_seed = random_seed

def _train(self):
    """Attach static demos and a KNN demo retriever to each student predictor.

    For every named predictor of ``self.student``, the first
    ``self.max_bootstrapped_demos`` traces in ``self.name2traces[name]`` are
    split into:

    * static demos: up to ``self.num_static_demos`` traces sampled with a
      fixed-seed RNG and stored in ``predictor.demos``;
    * dynamic demos: the remaining traces, passed as ``trainset`` to
      ``dspy.KNN`` and stored in ``predictor.retrieve_demos``.

    Returns:
        ``self.student`` with its predictors updated in place.
    """
    rng = random.Random(0)
    # Demo slots left for retrieved demos once the static ones are placed.
    k = self.max_labeled_demos - self.num_static_demos

    for name, predictor in self.student.named_predictors():
        augmented_demos = self.name2traces[name][: self.max_bootstrapped_demos]

        # Bootstrapping may yield fewer traces than requested; random.sample
        # raises ValueError when asked for more items than the population has.
        num_static = min(self.num_static_demos, len(augmented_demos))
        static_demos = rng.sample(augmented_demos, k=num_static)
        dynamic_demos = [x for x in augmented_demos if x not in static_demos]

        predictor.random_seed = self.random_seed
        predictor.demos = static_demos
        predictor.retrieve_demos = dspy.KNN(
            k=k,
            trainset=dynamic_demos,
            vectorizer=self.embedder,
        )

    return self.student
4 changes: 4 additions & 0 deletions dspy/teleprompt/random_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None
teacher_settings=self.teacher_settings,
max_rounds=self.max_rounds,
max_errors=self.max_errors,
num_threads=self.num_threads,
)
program = optimizer.compile(student, teacher=teacher, trainset=trainset_copy)

Expand All @@ -103,6 +104,7 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None
teacher_settings=self.teacher_settings,
max_rounds=self.max_rounds,
max_errors=self.max_errors,
num_threads=self.num_threads,
)

program = optimizer.compile(student, teacher=teacher, trainset=trainset_copy)
Expand Down Expand Up @@ -217,6 +219,7 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None
teacher_settings=self.teacher_settings,
max_rounds=self.max_rounds,
max_errors=self.max_errors,
num_threads=self.num_threads,
)
program = optimizer.compile(student, teacher=teacher, trainset=trainset_copy)

Expand All @@ -233,6 +236,7 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None
max_rounds=self.max_rounds,
max_errors=self.max_errors,
num_static_demos=num_static_demos,
num_threads=self.num_threads,
random_seed=seed,
)

Expand Down
Loading