# EureCA: dspy/teleprompt/random_search.py (commit f5776d3; the file viewer listed user "tonneli").
import dsp
import tqdm
import random
from dspy.teleprompt.teleprompt import Teleprompter
from .bootstrap import BootstrapFewShot
from .vanilla import LabeledFewShot
from dspy.evaluate.evaluate import Evaluate
# TODO: Don't forget dealing with the raw demos.
# TODO: Deal with the (pretty common) case of having a metric for filtering and a separate metric for eval.
# The metric itself may tell though by the presence of trace.
# TODO: This function should take a max_budget and max_teacher_budget. That's in the number of program calls.
# In this case, max_student_budget is max_budget - max_teacher_budget.
# For max_teacher_budget, this will just limit the total number of things we bootstrap.
# This can end up implicitly defining the number of candidate programs (i.e., stop when runs out). Cap at 16.
# For max_student_budget, this will be a more upfront calculation.
# Right now, it can also just induce the number of candidate programs. Later, it could be used more interestingly
# for selective early stopping.
# Progressive elimination sounds about right: after 50 examples, drop bottom third, after 100, another third, etc.
# until only 3--5 are left for the end. Could also be systematic and add (earlier) stopping based on error bounds.
# In general, though, the early filtering is just saying: either there are some really bad ones, or some really really
# good ones, or most things are pretty close. In all of these cases, dropping the bottom third is not going to hurt.
class BootstrapFewShotWithRandomSearch(Teleprompter):
    """Random search over few-shot demo sets for a student program.

    ``compile`` builds a series of candidate programs from the student, scores each
    one on a validation set with ``Evaluate``, and returns the best-scoring one.
    Candidates are identified by an integer ``seed``:

    * ``-3``: ``student.reset_copy()`` (zero-shot),
    * ``-2``: ``LabeledFewShot`` with ``k=max_labeled_demos`` (labels only),
    * ``-1``: ``BootstrapFewShot`` on the unshuffled training set,
    * ``0 .. num_candidate_programs - 1``: ``BootstrapFewShot`` on a training set
      shuffled with ``seed``, with a bootstrapped-demo budget drawn with ``seed``.
    """

    def __init__(self, metric, teacher_settings=None, max_bootstrapped_demos=4, max_labeled_demos=16,
                 max_rounds=1, num_candidate_programs=16, num_threads=6, stop_at_score=None):
        """
        Args:
            metric: Scoring function passed to ``BootstrapFewShot`` and to ``Evaluate``.
            teacher_settings: Settings forwarded to ``BootstrapFewShot``. ``None`` means ``{}``.
            max_bootstrapped_demos: Upper bound on the bootstrapped-demo budget per candidate.
            max_labeled_demos: Forwarded to ``LabeledFewShot`` (as ``k``) and ``BootstrapFewShot``.
            max_rounds: Forwarded to ``BootstrapFewShot``.
            num_candidate_programs: Number of randomly seeded candidates (seeds ``0..N-1``),
                tried in addition to the three fixed baselines (seeds ``-3..-1``).
            num_threads: Thread count passed to ``Evaluate``.
            stop_at_score: If not ``None``, stop the search as soon as a candidate's score
                is ``>=`` this value.
        """
        self.metric = metric
        # Fresh dict per instance; a `{}` default would be shared by every instance.
        self.teacher_settings = {} if teacher_settings is None else teacher_settings
        self.max_rounds = max_rounds
        self.num_threads = num_threads
        self.stop_at_score = stop_at_score
        # Keep the sampling range valid for randint() even when max_bootstrapped_demos is 0.
        self.min_num_samples = min(1, max_bootstrapped_demos)
        self.max_num_samples = max_bootstrapped_demos
        self.num_candidate_sets = num_candidate_programs
        self.max_labeled_demos = max_labeled_demos

        print("Going to sample between", self.min_num_samples, "and", self.max_num_samples, "traces per predictor.")
        print("Will attempt to train", self.num_candidate_sets, "candidate sets.")

    def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None):
        """Evaluate every candidate program and return the best-scoring one.

        Args:
            student: Program from which every candidate is derived.
            teacher: Optional teacher program forwarded to ``BootstrapFewShot``.
            trainset: Examples used to build demos.
            valset: Examples used to score candidates; falls back to ``trainset`` when falsy.
            restrict: Optional collection of seeds; when given, only these seeds are tried.

        Returns:
            The best-scoring candidate. Its ``candidate_programs`` attribute holds
            ``(score, subscores, seed, program)`` tuples for every evaluated candidate,
            sorted by score in descending order.

        Raises:
            ValueError: If ``restrict`` excludes every seed, so no candidate is evaluated.
        """
        self.trainset = trainset
        self.valset = valset or trainset  # TODO: FIXME: Note this choice.

        # The devset and evaluation settings are identical for every candidate.
        evaluate = Evaluate(devset=self.valset, metric=self.metric, num_threads=self.num_threads,
                            display_table=False, display_progress=True)

        scores = []
        score_data = []
        best_program = None

        for seed in range(-3, self.num_candidate_sets):
            if (restrict is not None) and (seed not in restrict):
                print(seed, restrict)
                continue

            program2 = self._build_candidate(student, teacher, seed)
            score, subscores = evaluate(program2, return_all_scores=True)
            score = self._apply_assertion_penalties(program2, score)

            print('Score:', score, 'for set:', [len(predictor.demos) for predictor in program2.predictors()])

            if not scores or score > max(scores):
                print('New best score:', score, 'for seed', seed)
                best_program = program2

            scores.append(score)
            print(f"Scores so far: {scores}")
            print('Best score:', max(scores))

            score_data.append((score, subscores, seed, program2))
            if len(score_data) > 2:  # Only report once at least 3 candidates have been scored.
                self._report_top_k_best_per_example(score_data)

            if self.stop_at_score is not None and score >= self.stop_at_score:
                print(f"Stopping early because score {score} is >= stop_at_score {self.stop_at_score}")
                break

        if best_program is None:
            raise ValueError(
                f"No candidate program was evaluated: restrict={restrict!r} excludes every seed "
                f"in range(-3, {self.num_candidate_sets})."
            )

        # Attach all evaluated candidates to the best program, highest score first.
        best_program.candidate_programs = sorted(score_data, key=lambda x: x[0], reverse=True)
        print(len(best_program.candidate_programs), "candidate programs found.")
        return best_program

    def _build_candidate(self, student, teacher, seed):
        """Build the candidate program for ``seed`` (see the class docstring for the seed scheme)."""
        trainset2 = list(self.trainset)

        if seed == -3:
            # zero-shot
            return student.reset_copy()

        if seed == -2:
            # labels only
            teleprompter = LabeledFewShot(k=self.max_labeled_demos)
            return teleprompter.compile(student, trainset=trainset2)

        if seed == -1:
            # unshuffled few-shot
            teleprompter = BootstrapFewShot(metric=self.metric, max_bootstrapped_demos=self.max_num_samples,
                                            max_labeled_demos=self.max_labeled_demos,
                                            teacher_settings=self.teacher_settings, max_rounds=self.max_rounds)
            return teleprompter.compile(student, teacher=teacher, trainset=trainset2)

        assert seed >= 0, seed
        # The shuffle and the budget draw each use a fresh RNG seeded with `seed`,
        # so a given seed always yields the same trainset order and demo budget.
        random.Random(seed).shuffle(trainset2)
        size = random.Random(seed).randint(self.min_num_samples, self.max_num_samples)
        teleprompter = BootstrapFewShot(metric=self.metric, max_bootstrapped_demos=size,
                                        max_labeled_demos=self.max_labeled_demos,
                                        teacher_settings=self.teacher_settings,
                                        max_rounds=self.max_rounds)
        return teleprompter.compile(student, teacher=teacher, trainset=trainset2)

    @staticmethod
    def _apply_assertion_penalties(program, score):
        """Return ``score`` adjusted by the assertion counters ``program`` may carry.

        Each suggestion failure (``_suggest_failures``) subtracts 0.2; any assertion
        failure (``_assert_failures > 0``) sets the score to 0.
        """
        if hasattr(program, '_suggest_failures'):
            score = score - program._suggest_failures * 0.2
        if hasattr(program, '_assert_failures'):
            score = 0 if program._assert_failures > 0 else score
        return score

    @staticmethod
    def _report_top_k_best_per_example(score_data):
        """Print, for several k, the average of the per-entry maximum subscore over the top-k candidates.

        ``score_data`` holds ``(score, subscores, seed, program)`` tuples. The sum of the
        per-entry maxima is divided by the number of subscores of the top candidate.
        """
        # sorted() is stable, so slicing one sorted list matches re-sorting for each k.
        ranked = sorted(score_data, key=lambda x: x[0], reverse=True)
        num_entries = len(ranked[0][1])
        if num_entries == 0:
            # Nothing to average; the unguarded division would raise ZeroDivisionError.
            return

        for k in [1, 2, 3, 5, 8, 9999]:
            top_k = ranked[:k]
            # Transpose the subscores to get max per entry and then calculate their average
            transposed_subscores = zip(*[subscores for _, subscores, *_ in top_k if subscores])
            avg_of_max_per_entry = sum(max(entry) for entry in transposed_subscores) / num_entries
            print(f'Average of max per entry across top {k} scores: {avg_of_max_per_entry}')
# sample between 4 and 10 examples from traces
# TODO: FIXME: The max number of demos should be determined in part by the LM's tokenizer + max_length.
# This does require executing the program, or at least the predictor.
# (Actually, we can just combine the token counts of the traces, when formatted via signature/adapter.)
# Alternatively, we can keep track of the (zero-shot) number of tokens when we bootstrap.
# As another option, we can just try a wide range and handle failures as penalties on the score.
# The number "24" of traces to collect can also be affected. If we only need 3x10, some overlap is ok.
# We can also consider having short_demos and long_demos.