|
import dsp |
|
|
|
from .predict import Predict |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
class ChainOfThought(dspy.Module): |
|
def __init__(self, signature): |
|
|
|
input_fields, output_fields = dspy.process_signature(signature) |
|
output_fields = dict(rationale=dspy.OutputField(prefix="Reasoning: Let's think step by step."), **output_fields) |
|
self.signature = dspy.Signature(input_fields, output_fields) |
|
|
|
self.predict = dspy.Predict(self.signature) |
|
|
|
def forward(self, **kwargs): |
|
return self.predict(**kwargs) |
|
|
|
# This is what it should look like, but it should also pass signature=simpler_signature to the predict module *if* deactivated.
|
""" |
|
|
|
|
|
class ChainOfThought(Predict):
    """``Predict`` variant that adds a ``rationale`` field before the last field.

    During construction an ``extended_signature`` (a ``dsp.Template``) is built
    from ``self.signature``. It keeps every field of the original signature in
    order and inserts a ``rationale`` entry immediately before the last field.

    Args:
        signature: Forwarded unchanged to ``Predict.__init__``.
        rationale_type: ``dsp.Type`` used for the ``rationale`` field. When it
            is falsy, a built-in default prefix/description is used instead.
        activated: Selects the signature used by ``forward``. ``True`` always
            uses the extended signature. ``None`` uses it only when
            ``dsp.settings.lm`` is an instance of ``dsp.GPT3``. Any other value
            uses the plain ``self.signature``.
        **config: Extra keyword arguments forwarded to ``Predict.__init__``.
    """

    def __init__(self, signature, rationale_type=None, activated=True, **config):
        super().__init__(signature, **config)

        self.activated = activated

        base = self.signature
        fields = base.kwargs

        # Split off the final field; the rationale is inserted right before it.
        # (Raises ValueError if the signature has no fields at all.)
        *leading, final = fields.keys()

        # NOTE(review): this default prompt text is French, while the older
        # draft in this module uses English ("Reasoning: ...") — confirm intended.
        default_rationale = dsp.Type(
            prefix="Raisonnement: Réfléchissons étape par étape afin de",
            desc="produire la ${" + final + "}. Nous ...",
        )
        rationale = rationale_type or default_rationale

        # Keep leading fields in their original order, then set "rationale",
        # then re-add the final field. Plain key assignment means an existing
        # "rationale" key keeps its position but gets the new value. If the
        # final field is itself named "rationale", the signature's own field
        # wins.
        extended = {name: fields[name] for name in leading}
        extended["rationale"] = rationale
        extended[final] = fields[final]

        self.extended_signature = dsp.Template(base.instructions, **extended)

    def forward(self, **kwargs):
        """Run ``Predict.forward`` with the appropriate signature.

        An optional ``new_signature`` keyword (a mapping of field name to field
        type) is removed from ``kwargs``. When present, it is combined with
        ``self.signature.instructions`` into a fresh ``dsp.Template``.
        Otherwise the choice between ``extended_signature`` and ``signature``
        follows ``self.activated``. All remaining ``kwargs`` are passed through.
        """
        override = kwargs.pop("new_signature", None)

        if override is not None:
            chosen = dsp.Template(self.signature.instructions, **override)
        elif self.activated is True or (
            self.activated is None and isinstance(dsp.settings.lm, dsp.GPT3)
        ):
            chosen = self.extended_signature
        else:
            chosen = self.signature

        return super().forward(signature=chosen, **kwargs)

    def dump_state(self):
        """Return the parent state plus the extended signature's customizations.

        Adds the extended signature's ``instructions`` and the ``name`` of its
        last field under the ``extended_signature_*`` keys.
        """
        state = super().dump_state()

        ext = self.extended_signature
        state["extended_signature_instructions"] = ext.instructions
        state["extended_signature_prefix"] = ext.fields[-1].name

        return state

    def load_state(self, state):
        """Restore parent state, then re-apply any saved extended-signature values.

        Each ``extended_signature_*`` key is optional; a missing key leaves the
        corresponding part of ``extended_signature`` untouched.
        """
        super().load_state(state)

        ext = self.extended_signature

        if "extended_signature_instructions" in state:
            ext.instructions = state["extended_signature_instructions"]

        if "extended_signature_prefix" in state:
            # Fields are replaced (via ``_replace``) rather than mutated in place.
            last = ext.fields[-1]
            ext.fields[-1] = last._replace(name=state["extended_signature_prefix"])
|
|
""" |
|
TODO: In principle, we could also update the field's prefix during forward() to fill in anything based on the input args.
|
|
|
This applies only if the user didn't override our default rationale_type.
|
""" |
|
|