File size: 1,963 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import dsp
from .predict import Predict
# TODO: FIXME: Insert this right before the *first* output field. Also rewrite this to use the new signature system.
class ChainOfThoughtWithHint(Predict):
    """Predict variant that adds a ``rationale`` field, plus a ``hint`` field
    when the caller supplies a hint, immediately before the signature's last
    field.

    Two extended templates are precomputed at construction time:

    * ``extended_signature1`` -- original leading fields, ``rationale``, last field.
    * ``extended_signature2`` -- original leading fields, ``hint``, ``rationale``,
      last field.

    NOTE(review): insertion happens before the *last* field, which presumably
    is the single output field; the file-level TODO says it should eventually
    go before the *first* output field instead.

    Args:
        signature: Passed unchanged to ``Predict.__init__``; the extended
            templates are derived from the resulting ``self.signature``.
        rationale_type: ``dsp.Type`` used for the rationale field. Any falsy
            value selects the built-in "Let's think step by step" prompt.
        activated: ``True`` always uses an extended template; ``None`` uses one
            only when ``dsp.settings.lm`` is a ``dsp.GPT3`` instance; any other
            value keeps the plain ``self.signature``.
        **config: Passed unchanged to ``Predict.__init__``.

    Raises:
        ValueError: if ``self.signature.kwargs`` is empty, because the last
            field cannot be unpacked.
    """

    def __init__(self, signature, rationale_type=None, activated=True, **config):
        super().__init__(signature, **config)
        self.activated = activated

        base = self.signature
        fields = base.kwargs
        *leading, final = fields.keys()

        hint_type = dsp.Type(prefix="Hint:", desc="${hint}")
        default_rationale = dsp.Type(
            prefix="Reasoning: Let's think step by step in order to",
            desc="${produce the " + final + "}. We ...",
        )
        rationale = rationale_type or default_rationale

        def build(extra):
            # Copy the leading fields in order, splice `extra` in, then re-append
            # the final field. A name already present is overwritten in place,
            # which matches plain dict-merge semantics.
            merged = {name: fields[name] for name in leading}
            merged.update(extra)
            merged[final] = fields[final]
            return dsp.Template(base.instructions, **merged)

        self.extended_signature1 = build({'rationale': rationale})
        self.extended_signature2 = build({'hint': hint_type, 'rationale': rationale})

    def forward(self, **kwargs):
        """Choose a template and delegate to ``Predict.forward``.

        ``extended_signature2`` is used when ``kwargs['hint']`` is present and
        truthy; otherwise ``extended_signature1`` is used. Both apply only when
        ``self.activated`` enables them (see the class docstring). All
        ``kwargs``, including ``hint``, are forwarded unchanged.
        """
        mode = self.activated
        use_extended = mode is True or (
            mode is None and isinstance(dsp.settings.lm, dsp.GPT3)
        )

        if not use_extended:
            chosen = self.signature
        elif kwargs.get('hint'):
            chosen = self.extended_signature2
        else:
            chosen = self.extended_signature1

        return super().forward(signature=chosen, **kwargs)
"""
TODO: In principle, we can update the field's prefix during forward too to fill any thing based on the input args.
IF the user didn't overwrite our default rationale_type.
""" |