import dsp
import tqdm
import random
import threading

import dspy
from dspy.predict.retry import Retry

from dspy.primitives import Example

from .teleprompt import Teleprompter
from .vanilla import LabeledFewShot

from dspy.evaluate.evaluate import Evaluate

# TODO: Metrics should basically return an object with __bool__, but it's fine if they're more complex.
# They can also be sortable.

# TODO: Switch here from dsp.Example to dspy.Example. Right now, it's okay because it's internal only (predictors).
# NOTE: Notice the places where we don't shuffle examples. I like that this one doesn't shuffle.
# Optimizers that consider multiple candidates may want to try the unshuffled order first, then shuffle a few times.

# TODO: Use branch_idx (not just temperature) across max_rounds to get past the cache.
# In principle, we can also sample multiple outputs from the final generation step
# (or even each step, in case the validation function just wants *one* thing that works, but nah)
# and try them all. Having a pretty solid guess on the "final step" of each example isn't hard by the second round,
# in the sense that we have the trace from the first round. (Yes it may change but that's an edge case that
# won't hurt our "best effort" guarantees.)

# TODO: When this bootstraps for another teleprompter like finetune, we want all demos we gather.
# But when it's for direct use we may want to sample ONE demo per predictor--example pair. This is important for "multi-use" modules.

# TODO: Add baselines=[...]


class BootstrapFewShot(Teleprompter):
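    """A Teleprompter that composes few-shot demonstrations for a student program.

    A teacher program (by default, a labeled-few-shot copy of the student) is run
    over the trainset; traces of runs that pass the metric are kept as augmented
    demos and attached to the student's predictors, topped up with plain labeled
    examples up to `max_labeled_demos`.
    """
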
    def __init__(self, metric=None, teacher_settings=None, max_bootstrapped_demos=4, max_labeled_demos=16, max_rounds=1, max_errors=5):
        self.metric = metric
        self.teacher_settings = teacher_settings or {}  # avoid a shared mutable default argument

        self.max_bootstrapped_demos = max_bootstrapped_demos
        self.max_labeled_demos = max_labeled_demos
        self.max_rounds = max_rounds
        self.max_errors = max_errors
        self.error_count = 0
        self.error_lock = threading.Lock()

    def compile(self, student, *, teacher=None, trainset, valset=None):
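        """Bootstrap demos and return the compiled student.

        `teacher` defaults to a labeled-few-shot-compiled copy of the student.
        `valset`, if given, replaces the leftover trainset examples as the pool
        of labeled demos.
        """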
        self.trainset = trainset
        self.valset = valset

        self._prepare_student_and_teacher(student, teacher)
        self._prepare_predictor_mappings()
        self._bootstrap()

        self.student = self._train()
        self.student._compiled = True

        # Initialize assertion/suggestion failure counters on the student.
        self.student._assert_failures = 0
        self.student._suggest_failures = 0

        return self.student
    
    def _prepare_student_and_teacher(self, student, teacher):
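        """Reset-copy the student; deep-copy the teacher (defaulting to a reset
        copy of the student) and, unless it is already compiled, give it labeled
        few-shot demos so its traces are worth collecting.
        """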
        self.student = student.reset_copy()
        self.teacher = teacher.deepcopy() if teacher is not None else student.reset_copy()

        assert getattr(self.student, '_compiled', False) is False, "Student must be uncompiled."

        if self.max_labeled_demos and getattr(self.teacher, '_compiled', False) is False:
            teleprompter = LabeledFewShot(k=self.max_labeled_demos)
            self.teacher = teleprompter.compile(self.teacher.reset_copy(), trainset=self.trainset)

    def _prepare_predictor_mappings(self):
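        """Zip student and teacher predictors by name and record id -> name
        mappings, so traced steps (which carry predictor objects) can be
        attributed back to named predictors.
        """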
        name2predictor, predictor2name = {}, {}
        student, teacher = self.student, self.teacher

        assert len(student.predictors()) == len(teacher.predictors()), "Student and teacher must have the same number of predictors."

        for (name1, predictor1), (name2, predictor2) in zip(student.named_predictors(), teacher.named_predictors()):
            assert name1 == name2, "Student and teacher must have the same program structure."
            assert predictor1.signature == predictor2.signature, f"Student and teacher must have the same signatures. {type(predictor1.signature)} != {type(predictor2.signature)}"
            assert id(predictor1) != id(predictor2), "Student and teacher must be different objects."

            name2predictor[name1] = None # dict(student=predictor1, teacher=predictor2)
            predictor2name[id(predictor1)] = name1

            # FIXME(shangyint): This is an ugly hack to bind traces of
            # retry.module to retry
            # if isinstance(predictor1, Retry):
            #     predictor2name[id(predictor1.module)] = name1

            predictor2name[id(predictor2)] = name2

        self.name2predictor = name2predictor
        self.predictor2name = predictor2name

    def _bootstrap(self, *, max_bootstraps=None):
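        """Run the teacher over the trainset for up to `max_rounds`, recording
        traces until `max_bootstraps` examples succeed. Examples that never get
        bootstrapped become the pool of labeled demos, unless `valset` overrides it.
        """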
        max_bootstraps = max_bootstraps or self.max_bootstrapped_demos

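        # Track which trainset indices produced a successful (metric-passing) trace.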
        bootstrapped = {}
        self.name2traces = {name: [] for name in self.name2predictor}

        for round_idx in range(self.max_rounds):
            for example_idx, example in enumerate(tqdm.tqdm(self.trainset)):
                if len(bootstrapped) >= max_bootstraps:
                    break

                if example_idx not in bootstrapped:
                    success = self._bootstrap_one_example(example, round_idx)

                    if success:
                        bootstrapped[example_idx] = True

        print(f'Bootstrapped {len(bootstrapped)} full traces after {example_idx+1} examples in round {round_idx}.')
        
        # Unbootstrapped training examples

        self.validation = [x for idx, x in enumerate(self.trainset) if idx not in bootstrapped]
        random.Random(0).shuffle(self.validation)

        self.validation = self.valset or self.validation

        # NOTE: Can't yet use evaluate because we need to trace *per example*
        # evaluate = Evaluate(program=self.teacher, metric=self.metric, num_threads=12)
        # score = evaluate(self.metric, display_table=False, display_progress=True)
    
    def _bootstrap_one_example(self, example, round_idx=0):
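        """Attempt to bootstrap a single example: run the teacher on it (with the
        example itself removed from every predictor's demos to avoid leakage) and,
        if the metric passes, turn each traced step into an augmented demo.
        """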
        name2traces = self.name2traces
        teacher = self.teacher #.deepcopy()
        predictor_cache = {}

        try:
            with dsp.settings.context(trace=[], **self.teacher_settings):
                lm = dsp.settings.lm
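                # On retry rounds, nudge the temperature per round so the LM call
                # misses the cache and can produce a different trace.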
                lm = lm.copy(temperature=0.7 + 0.001 * round_idx) if round_idx > 0 else lm
                new_settings = dict(lm=lm) if round_idx > 0 else {}

                with dsp.settings.context(**new_settings):
                    for name, predictor in teacher.named_predictors():
                        predictor_cache[name] = predictor.demos
                        predictor.demos = [x for x in predictor.demos if x != example]

                    prediction = teacher(**example.inputs())
                    trace = dsp.settings.trace

                    for name, predictor in teacher.named_predictors():
                        predictor.demos = predictor_cache[name]

                success = (self.metric is None) or self.metric(example, prediction, trace)
                # print(success, example, prediction)
        except Exception as e:
            success = False
            with self.error_lock:
                self.error_count += 1
                current_error_count = self.error_count
            if current_error_count >= self.max_errors:
                raise e
            print(f'Failed to run or to evaluate example {example} with {self.metric} due to {e}.')
        
        if success:
            for step in trace:
                predictor, inputs, outputs = step

                if 'dspy_uuid' in example:
                    demo = Example(augmented=True, dspy_uuid=example.dspy_uuid, **inputs, **outputs)
                else:
                    # TODO: FIXME: This is a hack. RandomSearch will complain for now in this edge case.
                    demo = Example(augmented=True, **inputs, **outputs)

                try:
                    predictor_name = self.predictor2name[id(predictor)]
                except KeyError:
                    # FIXME: Skip trace steps whose predictor is missing from the mapping.
                    # TODO: Look closer into this. It's a bit tricky to reproduce, and it
                    # can happen in a notebook (Jupyter) when rerunning cells redefines
                    # values; restarting the notebook usually helps. The strict behavior
                    # would be:
                    # raise KeyError(f'Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}.')
                    continue

                name2traces[predictor_name].append(demo)
        
        return success

    def _train(self):
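        """Attach demos to each student predictor: up to `max_bootstrapped_demos`
        augmented traces plus labeled examples sampled from the leftover pool
        (their ordering depends on the release flag below).
        """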
        rng = random.Random(0)
        raw_demos = self.validation

        for name, predictor in self.student.named_predictors():
            augmented_demos = self.name2traces[name][:self.max_bootstrapped_demos]
            
            sample_size = min(self.max_labeled_demos - len(augmented_demos), len(raw_demos))
            sample_size = max(0, sample_size)

            # Sample into a local variable so the pool of raw demos isn't shrunk
            # for subsequent predictors.
            sampled_demos = rng.sample(raw_demos, sample_size)

            if dspy.settings.release >= 20230928:
                predictor.demos = sampled_demos + augmented_demos
            else:
                predictor.demos = augmented_demos + sampled_demos

        return self.student
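

# A minimal usage sketch. `MyProgram`, `trainset`, and `exact_match` below are
# hypothetical stand-ins (a dspy.Module subclass, a list of dspy.Example objects
# with inputs marked via .with_inputs(...), and a metric with the signature this
# module calls, i.e. metric(example, prediction, trace)):
#
#     def exact_match(example, prediction, trace=None):
#         return example.answer.lower() == prediction.answer.lower()
#
#     teleprompter = BootstrapFewShot(metric=exact_match, max_bootstrapped_demos=4)
#     compiled_program = teleprompter.compile(MyProgram(), trainset=trainset)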