|
import dsp |
|
import tqdm |
|
import random |
|
import threading |
|
|
|
import dspy |
|
from dspy.predict.retry import Retry |
|
|
|
from dspy.primitives import Example |
|
|
|
from .teleprompt import Teleprompter |
|
from .vanilla import LabeledFewShot |
|
|
|
from dspy.evaluate.evaluate import Evaluate |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BootstrapFewShot(Teleprompter):
    """Teleprompter that bootstraps few-shot demos for a student program.

    A teacher program (by default a copy of the student, seeded with labeled
    demos) is run over ``trainset``. Each execution trace that passes ``metric``
    is converted into "augmented" demos for the matching student predictors;
    remaining capacity up to ``max_labeled_demos`` is filled with raw labeled
    examples sampled from the leftover training data.
    """

    def __init__(self, metric=None, teacher_settings=None, max_bootstrapped_demos=4,
                 max_labeled_demos=16, max_rounds=1, max_errors=5):
        """
        Args:
            metric: Callable ``metric(example, prediction, trace)`` returning a
                truthy value on success; ``None`` accepts every completed trace.
            teacher_settings: dsp settings overrides active while the teacher
                runs (e.g. a different LM). Defaults to no overrides.
                NOTE: the previous default was a shared mutable ``{}``; a
                ``None`` sentinel avoids cross-instance aliasing.
            max_bootstrapped_demos: Max augmented (bootstrapped) demos per predictor.
            max_labeled_demos: Max total demos (augmented + labeled) per predictor.
            max_rounds: Passes over the trainset; after round 0 the teacher LM
                temperature is nudged to diversify outputs.
            max_errors: Re-raise once this many examples have errored.
        """
        self.metric = metric
        self.teacher_settings = {} if teacher_settings is None else teacher_settings

        self.max_bootstrapped_demos = max_bootstrapped_demos
        self.max_labeled_demos = max_labeled_demos
        self.max_rounds = max_rounds
        self.max_errors = max_errors
        self.error_count = 0
        # Bootstrapping may run under a multi-threaded driver; serialize updates
        # to the shared error counter.
        self.error_lock = threading.Lock()

    def compile(self, student, *, teacher=None, trainset, valset=None):
        """Return a compiled copy of ``student`` with bootstrapped demos attached."""
        self.trainset = trainset
        self.valset = valset

        self._prepare_student_and_teacher(student, teacher)
        self._prepare_predictor_mappings()
        self._bootstrap()

        self.student = self._train()
        self.student._compiled = True

        # Counters consumed by dspy's assertion/suggestion machinery.
        setattr(self.student, '_assert_failures', 0)
        setattr(self.student, '_suggest_failures', 0)

        return self.student

    def _prepare_student_and_teacher(self, student, teacher):
        """Create working copies: a clean student and a (possibly seeded) teacher."""
        self.student = student.reset_copy()
        self.teacher = teacher.deepcopy() if teacher is not None else student.reset_copy()

        assert getattr(self.student, '_compiled', False) is False, "Student must be uncompiled."

        # Seed an uncompiled teacher with labeled demos so its traces are grounded.
        if self.max_labeled_demos and getattr(self.teacher, '_compiled', False) is False:
            teleprompter = LabeledFewShot(k=self.max_labeled_demos)
            self.teacher = teleprompter.compile(self.teacher.reset_copy(), trainset=self.trainset)

    def _prepare_predictor_mappings(self):
        """Build name/predictor maps; assert student and teacher are structural twins."""
        name2predictor, predictor2name = {}, {}
        student, teacher = self.student, self.teacher

        assert len(student.predictors()) == len(teacher.predictors()), "Student and teacher must have the same number of predictors."

        for (name1, predictor1), (name2, predictor2) in zip(student.named_predictors(), teacher.named_predictors()):
            assert name1 == name2, "Student and teacher must have the same program structure."
            assert predictor1.signature == predictor2.signature, f"Student and teacher must have the same signatures. {type(predictor1.signature)} != {type(predictor2.signature)}"
            assert id(predictor1) != id(predictor2), "Student and teacher must be different objects."

            name2predictor[name1] = None  # dict used as an ordered set of names
            # Map BOTH objects' ids to the shared name so traces from either
            # program resolve to the same student predictor.
            predictor2name[id(predictor1)] = name1
            predictor2name[id(predictor2)] = name2

        self.name2predictor = name2predictor
        self.predictor2name = predictor2name

    def _bootstrap(self, *, max_bootstraps=None):
        """Run the teacher over the trainset, collecting traces that pass the metric."""
        max_bootstraps = max_bootstraps or self.max_bootstrapped_demos

        bootstrapped = {}  # trainset index -> True once a full trace succeeded
        self.name2traces = {name: [] for name in self.name2predictor}

        for round_idx in range(self.max_rounds):
            for example_idx, example in enumerate(tqdm.tqdm(self.trainset)):
                if len(bootstrapped) >= max_bootstraps:
                    break

                if example_idx not in bootstrapped:
                    success = self._bootstrap_one_example(example, round_idx)

                    if success:
                        bootstrapped[example_idx] = True

        print(f'Bootstrapped {len(bootstrapped)} full traces after {example_idx+1} examples in round {round_idx}.')

        # Unused examples become the labeled-demo pool / validation set,
        # shuffled deterministically for reproducibility.
        self.validation = [x for idx, x in enumerate(self.trainset) if idx not in bootstrapped]
        random.Random(0).shuffle(self.validation)

        self.validation = self.valset or self.validation

    def _bootstrap_one_example(self, example, round_idx=0):
        """Run the teacher on one example; record its trace if the metric passes."""
        name2traces = self.name2traces
        teacher = self.teacher
        predictor_cache = {}

        try:
            with dsp.settings.context(trace=[], **self.teacher_settings):
                # After round 0, nudge temperature slightly per round so retries
                # explore different teacher outputs.
                lm = dsp.settings.lm
                lm = lm.copy(temperature=0.7 + 0.001 * round_idx) if round_idx > 0 else lm
                new_settings = dict(lm=lm) if round_idx > 0 else {}

                with dsp.settings.context(**new_settings):
                    # Temporarily remove the example itself from each predictor's
                    # demos so its label cannot leak into the prompt.
                    for name, predictor in teacher.named_predictors():
                        predictor_cache[name] = predictor.demos
                        predictor.demos = [x for x in predictor.demos if x != example]

                    prediction = teacher(**example.inputs())
                    trace = dsp.settings.trace

                    # Restore the cached demos regardless of outcome.
                    for name, predictor in teacher.named_predictors():
                        predictor.demos = predictor_cache[name]

                success = (self.metric is None) or self.metric(example, prediction, trace)

        except Exception as e:
            success = False
            with self.error_lock:
                self.error_count += 1
                current_error_count = self.error_count
            if current_error_count >= self.max_errors:
                raise e
            print(f'Failed to run or to evaluate example {example} with {self.metric} due to {e}.')

        if success:
            for step in trace:
                predictor, inputs, outputs = step

                # Preserve the example's uuid so demos can be traced back to it.
                if 'dspy_uuid' in example:
                    demo = Example(augmented=True, dspy_uuid=example.dspy_uuid, **inputs, **outputs)
                else:
                    demo = Example(augmented=True, **inputs, **outputs)

                try:
                    predictor_name = self.predictor2name[id(predictor)]
                except KeyError:
                    # Trace steps from predictors outside the mapping (e.g.
                    # wrapped/Retry modules, or stale objects after notebook
                    # cell re-runs) are skipped. NOTE(review): the original had
                    # an unreachable diagnostic + raise after this `continue`;
                    # that dead code has been removed.
                    continue

                name2traces[predictor_name].append(demo)

        return success

    def _train(self):
        """Attach demos to each student predictor: bootstrapped + sampled labeled."""
        rng = random.Random(0)
        raw_demos = self.validation

        for name, predictor in self.student.named_predictors():
            augmented_demos = self.name2traces[name][:self.max_bootstrapped_demos]

            # Top up with labeled demos, keeping the total within max_labeled_demos.
            sample_size = min(self.max_labeled_demos - len(augmented_demos), len(raw_demos))
            sample_size = max(0, sample_size)

            # BUG FIX: sample into a fresh local instead of rebinding `raw_demos`,
            # which previously shrank the pool for every subsequent predictor
            # (later predictors only ever sampled from the first sample).
            sampled_demos = rng.sample(raw_demos, sample_size)

            # Demo ordering flipped in the 2023-09-28 release; support both.
            if dspy.settings.release >= 20230928:
                predictor.demos = sampled_demos + augmented_demos
            else:
                predictor.demos = augmented_demos + sampled_demos

        return self.student
|
|