import random
import threading

import tqdm

import dsp
import dspy
from dspy.evaluate.evaluate import Evaluate
from dspy.predict.retry import Retry
from dspy.primitives import Example

from .teleprompt import Teleprompter
from .vanilla import LabeledFewShot

# TODO: Metrics should return an object with __bool__ basically, but it's fine if they're more complex.
# They can also be sortable.

# TODO: Switch here from dsp.Example to dspy.Example. Right now, it's okay because it's internal only (predictors).

# NOTE: Notice the places where we don't shuffle examples. I do like that this one doesn't shuffle.
# Other ones that consider options may want to use both unshuffled and then shuffle a few times,
# when considering candidates.

# TODO: Use branch_idx, not just temperature, to get past the cache for max_rounds.
# In principle, we can also sample multiple outputs from the final generation step
# (or even each step, in case the validation function just wants *one* thing that works, but nah)
# and try them all. Having a pretty solid guess on the "final step" of each example isn't hard by the
# second round, in the sense that we have the trace from the first round. (Yes, it may change, but
# that's an edge case that won't hurt our "best effort" guarantees.)

# TODO: When this bootstraps for another teleprompter like finetune, we want all the demos we gather.
# But when it's for direct use, we may want to sample ONE demo per predictor--example pair.
# This is important for "multi-use" modules.

# TODO: Add baselines=[...]
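
# Illustrative sketch of the metric-result object described in the first TODO above (hypothetical;
# nothing in this module uses it): a richer metric value that is still truthy/falsy via __bool__ and
# sortable via __lt__, so it can be filtered like a bool and ranked like a score.
class _ScoredMetricResult:
    def __init__(self, score, threshold=0.5):
        self.score = score
        self.threshold = threshold

    def __bool__(self):
        # A result "passes" when its score clears the threshold.
        return self.score >= self.threshold

    def __lt__(self, other):
        # Sorting by raw score lets callers rank candidate demos.
        return self.score < other.score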


class BootstrapFewShot(Teleprompter):
    def __init__(self, metric=None, teacher_settings={}, max_bootstrapped_demos=4,
                 max_labeled_demos=16, max_rounds=1, max_errors=5):
        self.metric = metric
        self.teacher_settings = teacher_settings

        self.max_bootstrapped_demos = max_bootstrapped_demos
        self.max_labeled_demos = max_labeled_demos
        self.max_rounds = max_rounds
        self.max_errors = max_errors
        self.error_count = 0
        self.error_lock = threading.Lock()

    def compile(self, student, *, teacher=None, trainset, valset=None):
        self.trainset = trainset
        self.valset = valset

        self._prepare_student_and_teacher(student, teacher)
        self._prepare_predictor_mappings()
        self._bootstrap()

        self.student = self._train()
        self.student._compiled = True

        # Set assert_failures and suggest_failures as attributes of the student, with value 0.
        setattr(self.student, '_assert_failures', 0)
        setattr(self.student, '_suggest_failures', 0)

        return self.student

    def _prepare_student_and_teacher(self, student, teacher):
        self.student = student.reset_copy()
        self.teacher = teacher.deepcopy() if teacher is not None else student.reset_copy()

        assert getattr(self.student, '_compiled', False) is False, "Student must be uncompiled."

        if self.max_labeled_demos and getattr(self.teacher, '_compiled', False) is False:
            teleprompter = LabeledFewShot(k=self.max_labeled_demos)
            self.teacher = teleprompter.compile(self.teacher.reset_copy(), trainset=self.trainset)

    def _prepare_predictor_mappings(self):
        name2predictor, predictor2name = {}, {}
        student, teacher = self.student, self.teacher

        assert len(student.predictors()) == len(teacher.predictors()), \
            "Student and teacher must have the same number of predictors."

        for (name1, predictor1), (name2, predictor2) in zip(student.named_predictors(), teacher.named_predictors()):
            assert name1 == name2, "Student and teacher must have the same program structure."
            assert predictor1.signature == predictor2.signature, (
                f"Student and teacher must have the same signatures. "
                f"{type(predictor1.signature)} != {type(predictor2.signature)}"
            )
            assert id(predictor1) != id(predictor2), "Student and teacher must be different objects."

            name2predictor[name1] = None  # dict(student=predictor1, teacher=predictor2)
            predictor2name[id(predictor1)] = name1

            # FIXME(shangyint): This is an ugly hack to bind traces of retry.module to retry.
            # if isinstance(predictor1, Retry):
            #     predictor2name[id(predictor1.module)] = name1

            predictor2name[id(predictor2)] = name2

        self.name2predictor = name2predictor
        self.predictor2name = predictor2name

    def _bootstrap(self, *, max_bootstraps=None):
        max_bootstraps = max_bootstraps or self.max_bootstrapped_demos

        bootstrapped = {}
        self.name2traces = {name: [] for name in self.name2predictor}

        for round_idx in range(self.max_rounds):
            for example_idx, example in enumerate(tqdm.tqdm(self.trainset)):
                if len(bootstrapped) >= max_bootstraps:
                    break

                if example_idx not in bootstrapped:
                    success = self._bootstrap_one_example(example, round_idx)

                    if success:
                        bootstrapped[example_idx] = True

        print(f'Bootstrapped {len(bootstrapped)} full traces after {example_idx + 1} examples in round {round_idx}.')

        # Unbootstrapped training examples become the validation pool.
        self.validation = [x for idx, x in enumerate(self.trainset) if idx not in bootstrapped]
        random.Random(0).shuffle(self.validation)

        self.validation = self.valset or self.validation

        # NOTE: Can't yet use Evaluate because we need to trace *per example*.
        # evaluate = Evaluate(program=self.teacher, metric=self.metric, num_threads=12)
        # score = evaluate(self.metric, display_table=False, display_progress=True)

    def _bootstrap_one_example(self, example, round_idx=0):
        name2traces = self.name2traces
        teacher = self.teacher  # .deepcopy()
        predictor_cache = {}

        try:
            with dsp.settings.context(trace=[], **self.teacher_settings):
                lm = dsp.settings.lm
                lm = lm.copy(temperature=0.7 + 0.001 * round_idx) if round_idx > 0 else lm
                new_settings = dict(lm=lm) if round_idx > 0 else {}

                with dsp.settings.context(**new_settings):
                    # Temporarily drop the current example from every predictor's demos,
                    # so the teacher can't simply copy it back out.
                    for name, predictor in teacher.named_predictors():
                        predictor_cache[name] = predictor.demos
                        predictor.demos = [x for x in predictor.demos if x != example]

                    prediction = teacher(**example.inputs())
                    trace = dsp.settings.trace

                    for name, predictor in teacher.named_predictors():
                        predictor.demos = predictor_cache[name]

            success = (self.metric is None) or self.metric(example, prediction, trace)
            # print(success, example, prediction)
        except Exception as e:
            success = False

            with self.error_lock:
                self.error_count += 1
                current_error_count = self.error_count

            if current_error_count >= self.max_errors:
                raise e

            print(f'Failed to run or to evaluate example {example} with {self.metric} due to {e}.')

        if success:
            for step in trace:
                predictor, inputs, outputs = step

                if 'dspy_uuid' in example:
                    demo = Example(augmented=True, dspy_uuid=example.dspy_uuid, **inputs, **outputs)
                else:
                    # TODO: FIXME: This is a hack. RandomSearch will complain for now in this edge case.
                    demo = Example(augmented=True, **inputs, **outputs)

                try:
                    predictor_name = self.predictor2name[id(predictor)]
                except KeyError as e:
                    continue  # FIXME: !

                    # TODO: Look closer into this. It's a bit tricky to reproduce.
                    # NOTE: The diagnostics below are currently unreachable because of the `continue` above.
                    print(f'Failed to find predictor {predictor} in {self.predictor2name}.')
                    print('Are you doing this in a notebook (Jupyter)? '
                          'This might be caused by redefining values by rerunning cells.')
                    print('Try restarting the notebook, or open an issue.')
                    raise KeyError(f'Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}.') from e

                name2traces[predictor_name].append(demo)

        return success

    def _train(self):
        rng = random.Random(0)
        raw_demos = self.validation

        for name, predictor in self.student.named_predictors():
            augmented_demos = self.name2traces[name][:self.max_bootstrapped_demos]

            sample_size = min(self.max_labeled_demos - len(augmented_demos), len(raw_demos))
            sample_size = max(0, sample_size)

            raw_demos = rng.sample(raw_demos, sample_size)

            if dspy.settings.release >= 20230928:
                predictor.demos = raw_demos + augmented_demos
            else:
                predictor.demos = augmented_demos + raw_demos

        return self.student
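

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). It shows how
# BootstrapFewShot.compile is typically invoked. The toy program, metric, and
# trainset below are hypothetical, and an LM is assumed to have already been
# configured, e.g. via dspy.settings.configure(lm=...).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class SimpleQA(dspy.Module):
        """Hypothetical single-predictor program, used only for illustration."""

        def __init__(self):
            super().__init__()
            self.generate_answer = dspy.Predict("question -> answer")

        def forward(self, question):
            return self.generate_answer(question=question)

    def answer_match(example, prediction, trace=None):
        # The metric receives (example, prediction, trace) and returns a truthy value
        # when the traced run should be kept as a bootstrapped demo.
        return example.answer.lower() in prediction.answer.lower()

    trainset = [
        dspy.Example(question="What color is a clear daytime sky?", answer="blue").with_inputs("question"),
        dspy.Example(question="How many legs does a spider have?", answer="eight").with_inputs("question"),
    ]

    teleprompter = BootstrapFewShot(metric=answer_match, max_bootstrapped_demos=2, max_labeled_demos=4)
    compiled_qa = teleprompter.compile(SimpleQA(), trainset=trainset)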