File size: 3,549 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import dsp
from .predict import Predict
# TODO: FIXME: Insert this right before the *first* output field. Also rewrite this to use the new signature system.
# TODO: This shouldn't inherit from Predict. It should be a module that has one or two predictors.
# Let's focus on the activated case. It's a predictor with the expanded signature.
# Now, when deactivated, it's a predictor with the original signature.
# When activated is None, though, we need the expanded signature, but during forward we need to pass the right signature.
"""
class ChainOfThought(dspy.Module):
def __init__(self, signature):
input_fields, output_fields = dspy.process_signature(signature)
output_fields = dict(rationale=dspy.OutputField(prefix="Reasoning: Let's think step by step."), **output_fields)
self.signature = dspy.Signature(input_fields, output_fields)
self.predict = dspy.Predict(self.signature)
def forward(self, **kwargs):
return self.predict(**kwargs)
# What this should look like. But also pass signature=simpler_signature to the predict module *if* deactivated.
"""
class ChainOfThought(Predict):
    """Predict variant that can insert a ``rationale`` field before the last field.

    At construction time an "extended" template is built from the base
    signature: same instructions, same fields in the same order, plus a
    ``rationale`` entry placed immediately before the final field.
    ``forward`` then chooses between the base and the extended signature.
    """

    def __init__(self, signature, rationale_type=None, activated=True, **config):
        """Initialise the parent predictor and build the extended signature.

        Args:
            signature: Forwarded unchanged to ``Predict.__init__``.
            rationale_type: Field type used for the ``rationale`` entry. Any
                falsy value selects the built-in default.
            activated: ``True`` makes ``forward`` always use the extended
                signature; ``None`` makes it do so only when
                ``dsp.settings.lm`` is a ``dsp.GPT3``; any other value makes
                it use the base signature.
            **config: Forwarded unchanged to ``Predict.__init__``.
        """
        super().__init__(signature, **config)
        self.activated = activated

        # Work from ``self.signature`` as it stands once the parent
        # initialiser has run, not from the raw constructor argument.
        base = self.signature

        # Split the field names into "everything before the last" and "the
        # last"; the unpacking requires at least one field to be present.
        *leading_names, output_name = base.kwargs.keys()

        # Built-in rationale prompt (French text). Its description refers to
        # the final field through a ``${...}`` placeholder.
        fallback_rationale = dsp.Type(
            prefix="Raisonnement: Réfléchissons étape par étape afin de",
            desc="produire la ${" + output_name + "}. Nous ...",
        )
        if not rationale_type:
            rationale_type = fallback_rationale

        # Rebuild the field mapping in order, slotting the rationale in just
        # ahead of the final field.
        extended_fields = {}
        for name in leading_names:
            extended_fields[name] = base.kwargs[name]
        extended_fields["rationale"] = rationale_type
        extended_fields[output_name] = base.kwargs[output_name]

        self.extended_signature = dsp.Template(base.instructions, **extended_fields)

    def forward(self, **kwargs):
        """Select a signature and delegate to ``Predict.forward``.

        Keyword Args:
            new_signature: Optional. Removed from ``kwargs`` before
                delegating. When it is not ``None`` it is expanded as keyword
                arguments into a fresh ``dsp.Template`` that reuses the base
                signature's instructions, and that template is used.
            **kwargs: Everything else is passed through to the parent.

        Returns:
            Whatever ``Predict.forward`` returns for the chosen signature.
        """
        override = kwargs.pop("new_signature", None)

        if override is not None:
            chosen = dsp.Template(self.signature.instructions, **override)
        elif self.activated is True or (
            self.activated is None and isinstance(dsp.settings.lm, dsp.GPT3)
        ):
            # Explicitly on, or undecided while the configured LM is a GPT3.
            chosen = self.extended_signature
        else:
            chosen = self.signature

        return super().forward(signature=chosen, **kwargs)

    def dump_state(self):
        """Return the parent's state plus two extended-signature entries.

        Adds the extended signature's instructions and the ``name`` of its
        last field under the keys ``extended_signature_instructions`` and
        ``extended_signature_prefix``.
        """
        state = super().dump_state()

        state["extended_signature_instructions"] = self.extended_signature.instructions
        state["extended_signature_prefix"] = self.extended_signature.fields[-1].name

        return state

    def load_state(self, state):
        """Restore the parent's state, then any saved extended-signature entries.

        Each of the two keys written by ``dump_state`` is optional; an entry
        that is absent from ``state`` leaves the matching part of the
        extended signature untouched.
        """
        super().load_state(state)

        if "extended_signature_instructions" in state:
            self.extended_signature.instructions = state["extended_signature_instructions"]

        if "extended_signature_prefix" in state:
            saved_name = state["extended_signature_prefix"]
            extended = self.extended_signature
            # ``_replace`` returns a new field object, so the replacement has
            # to be written back into the list.
            extended.fields[-1] = extended.fields[-1]._replace(name=saved_name)
"""
TODO: In principle, we can also update the field's prefix during forward, filling in anything based on the input args,
but only if the user didn't override our default rationale_type.
"""
|