File size: 2,534 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import copy
import dspy
import dsp

from .predict import Predict


class Retry(Predict):
    """Wrap a predictor so it can be re-run with its past outputs and feedback as inputs.

    At construction an extended field mapping (``self.new_signature``) is derived
    from the wrapped module's signature. When ``dspy.settings.backtrack_to`` is
    this instance, ``__call__`` routes through ``forward`` using that extended
    mapping; otherwise it simply delegates to the wrapped module.
    """

    def __init__(self, module):
        """Store the wrapped ``module`` and precompute the extended signature.

        Args:
            module: the predictor to wrap. It must expose ``signature`` (whose
                ``.signature`` attribute provides ``input_fields()`` and
                ``output_fields()``) and a ``forward`` method.
        """
        super().__init__(module.signature)
        self.module = module
        self.original_signature = module.signature.signature
        # Kept so the retry path can call the wrapped module's forward directly.
        self.original_forward = module.forward
        self.new_signature = self._create_new_signature(self.original_signature)

    def _create_new_signature(self, original_signature):
        """Build the field mapping used on the retry path.

        Returns:
            A dict whose insertion order is: the original input fields, one
            ``past_<name>`` input field per original output field, a
            ``feedback`` input field, and finally the original output fields.
        """
        extended_signature = {}
        input_fields = original_signature.input_fields()
        output_fields = original_signature.output_fields()

        # Each previous output is re-exposed as an *input* so the model can see
        # what it produced last time.
        past_output_fields = {
            f"past_{key}": dspy.InputField(
                prefix="Past " + value.prefix,
                desc="past output with errors",
                format=value.format,
            )
            for key, value in output_fields.items()
        }

        extended_signature.update(input_fields)
        extended_signature.update(past_output_fields)

        extended_signature["feedback"] = dspy.InputField(
            prefix="Instructions:",
            desc="Some instructions you must satisfy",
            format=str,
        )
        # Outputs go last so they follow all inputs in the mapping's order.
        extended_signature.update(output_fields)

        return extended_signature

    def forward(self, *args, **kwargs):
        """Invoke the wrapped module's forward with past outputs exposed as inputs.

        ``kwargs["past_outputs"]`` must be a mapping from output-field name to
        its previous value; a ``KeyError`` is raised if it is absent. Each entry
        whose ``past_<name>`` key exists in ``self.new_signature`` is passed on
        under that key, and other entries are dropped. ``new_signature`` is set
        to the extended mapping. Positional ``args`` are accepted but not
        forwarded.
        """
        # Remove the raw mapping *before* adding past_* keys. Deleting it
        # afterwards would discard the value for an output field named "outputs",
        # whose derived key is also "past_outputs".
        past_outputs = kwargs.pop("past_outputs")
        for key, value in past_outputs.items():
            past_key = f"past_{key}"
            if past_key in self.new_signature:
                kwargs[past_key] = value
        kwargs["new_signature"] = self.new_signature
        return self.original_forward(**kwargs)

    def __call__(self, **kwargs):
        """Run the wrapped module, taking the retry path when backtracking to self.

        Always passes ``_trace=False`` to the inner call (presumably so the
        inner module does not record its own trace entry; confirm). If
        ``dspy.settings.backtrack_to`` is this instance, keys from
        ``dspy.settings.backtrack_to_args`` fill in any missing kwargs and
        ``self.forward`` is used. Otherwise ``self.module`` is called directly.
        When ``dsp.settings.trace`` is not None, one ``(self, inputs, pred)``
        entry is appended, where ``inputs`` excludes the reserved keys below.

        Returns:
            The prediction produced by the inner call.
        """
        kwargs["_trace"] = False
        kwargs.setdefault("demos", self.demos if self.demos is not None else [])

        # perform backtracking
        if dspy.settings.backtrack_to == self:
            for key, value in dspy.settings.backtrack_to_args.items():
                kwargs.setdefault(key, value)
            pred = self.forward(**kwargs)
        else:
            pred = self.module(**kwargs)

        # now pop multiple reserved keys
        # NOTE(shangyin) past_outputs seems not useful to include in demos,
        # therefore dropped
        for key in ["_trace", "demos", "signature", "config", "lm", "past_outputs"]:
            kwargs.pop(key, None)

        if dsp.settings.trace is not None:
            trace = dsp.settings.trace
            trace.append((self, {**kwargs}, pred))
        return pred