File size: 2,534 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import copy
import dspy
import dsp
from .predict import Predict
class Retry(Predict):
    """Predictor wrapper that can re-run a module with its past outputs and feedback.

    On construction an extended signature is derived from the wrapped
    module's signature: the original input fields, one ``past_<name>`` input
    field per original output field, a ``feedback`` input field, and finally
    the original output fields.

    Calling the wrapper normally just delegates to the wrapped module. When
    ``dspy.settings.backtrack_to`` compares equal to this instance, the call
    is routed through :meth:`forward` instead, which passes the extended
    signature to the wrapped module's ``forward``.
    """

    def __init__(self, module):
        """Wrap *module* and precompute the extended retry signature.

        Args:
            module: The module to wrap. It must expose ``signature`` (whose
                ``signature`` attribute provides ``input_fields()`` and
                ``output_fields()``) and a ``forward`` method.
        """
        super().__init__(module.signature)
        self.module = module
        self.original_signature = module.signature.signature
        # Keep a direct handle on the wrapped forward so that ``forward``
        # below can call it with the extended signature.
        self.original_forward = module.forward
        self.new_signature = self._create_new_signature(self.original_signature)

    def _create_new_signature(self, original_signature):
        """Build the extended field mapping used for a retry.

        Args:
            original_signature: Object providing ``input_fields()`` and
                ``output_fields()``, each returning a mapping of field name
                to field object (output fields must have ``prefix`` and
                ``format`` attributes).

        Returns:
            A dict whose insertion order is: original inputs, the
            ``past_<name>`` inputs, ``feedback``, then original outputs.
        """
        input_fields = original_signature.input_fields()
        output_fields = original_signature.output_fields()

        # Every original output is mirrored as an *input* named
        # ``past_<name>`` that carries the previous value of that output.
        past_output_fields = {
            f"past_{key}": dspy.InputField(
                prefix="Past " + value.prefix,
                desc="past output with errors",
                format=value.format,
            )
            for key, value in output_fields.items()
        }

        extended_signature = {}
        extended_signature.update(input_fields)
        extended_signature.update(past_output_fields)
        extended_signature["feedback"] = dspy.InputField(
            prefix="Instructions:",
            desc="Some instructions you must satisfy",
            format=str,
        )
        # Outputs go last so they follow all inputs in the mapping.
        extended_signature.update(output_fields)
        return extended_signature

    def forward(self, *args, **kwargs):
        """Run the wrapped module's ``forward`` with the extended signature.

        Each entry of ``kwargs["past_outputs"]`` is copied into ``kwargs``
        under ``past_<name>`` when that name exists in the extended
        signature; entries without a matching field are ignored.
        ``past_outputs`` itself is then removed and ``new_signature`` is
        added before delegating.

        Args:
            *args: Accepted but not forwarded to the wrapped module.
            **kwargs: Keyword arguments for the wrapped ``forward``; must
                contain ``past_outputs``, a mapping of output name to value.

        Returns:
            Whatever the wrapped module's ``forward`` returns.

        Raises:
            KeyError: If ``past_outputs`` is missing from ``kwargs``.
        """
        for key, value in kwargs["past_outputs"].items():
            past_key = f"past_{key}"
            if past_key in self.new_signature:
                kwargs[past_key] = value
        # Deleted only after the loop so the order of effects matches the
        # field expansion above even if a field is itself named
        # ``past_outputs``.
        del kwargs["past_outputs"]
        kwargs["new_signature"] = self.new_signature
        return self.original_forward(**kwargs)

    def __call__(self, **kwargs):
        """Invoke the wrapped module, or the retry path when backtracking.

        Args:
            **kwargs: Keyword arguments for the wrapped module.

        Returns:
            The prediction produced by the wrapped module or by
            :meth:`forward`.
        """
        # Tracing is disabled for the inner call; this wrapper appends its
        # own trace entry below.
        kwargs["_trace"] = False
        kwargs.setdefault("demos", self.demos if self.demos is not None else [])

        # perform backtracking
        if dspy.settings.backtrack_to == self:
            # Explicitly passed arguments win over the stored backtrack
            # arguments.
            for key, value in dspy.settings.backtrack_to_args.items():
                kwargs.setdefault(key, value)
            pred = self.forward(**kwargs)
        else:
            pred = self.module(**kwargs)

        # now pop multiple reserved keys
        # NOTE(shangyin) past_outputs seems not useful to include in demos,
        # therefore dropped
        for key in ("_trace", "demos", "signature", "config", "lm", "past_outputs"):
            kwargs.pop(key, None)

        if dsp.settings.trace is not None:
            trace = dsp.settings.trace
            # Record a shallow copy so later mutation of ``kwargs`` by the
            # caller does not alter the trace entry.
            trace.append((self, {**kwargs}, pred))
        return pred
|