|
import dsp |
|
import random |
|
|
|
from dspy.predict.parameter import Parameter |
|
from dspy.primitives.prediction import Prediction |
|
from dspy.signatures.field import InputField, OutputField |
|
from dspy.signatures.signature import infer_prefix |
|
|
|
|
|
class Predict(Parameter):
    """Learnable prediction step that queries an LM according to a signature.

    ``signature`` is either a string of the form ``"in1, in2 -> out1, out2"``,
    which is compiled into a ``dsp.Template`` with generated instructions, or
    any other object, which is stored unchanged (presumably a prebuilt
    ``dsp.Template`` -- TODO confirm against callers). The rest of this class
    reads ``signature.instructions`` and ``signature.fields`` from it.

    Extra keyword arguments are stored in ``self.config`` and forwarded to
    ``dsp.generate`` on every call, merged with any per-call ``config``.
    """

    def __init__(self, signature, **config):
        # Random hex identifier, passed as `stage=` to dsp.generate in forward.
        self.stage = random.randbytes(8).hex()
        self.signature = signature
        self.config = config
        self.reset()

        if isinstance(signature, str):
            self.signature = self._template_from_string(signature)

    @staticmethod
    def _template_from_string(signature):
        """Compile an ``"a, b -> c"`` signature string into a ``dsp.Template``.

        Raises:
            ValueError: if the string does not contain exactly one ``"->"``
                (tuple unpacking of the split fails).
            AssertionError: if any field name is empty or contains whitespace.
        """
        inputs, outputs = signature.split("->")
        inputs = [field.strip() for field in inputs.split(",")]
        outputs = [field.strip() for field in outputs.split(",")]

        # Each field name must be exactly one non-empty, whitespace-free token.
        assert all(len(field.split()) == 1 for field in inputs + outputs), (
            f"Invalid signature {signature!r}: every field name must be a single token."
        )

        inputs_ = ", ".join(f"`{field}`" for field in inputs)
        outputs_ = ", ".join(f"`{field}`" for field in outputs)
        instructions = f"Given the fields {inputs_}, produce the fields {outputs_}."

        input_fields = {name: InputField() for name in inputs}
        output_fields = {name: OutputField() for name in outputs}

        # Finalize inputs first, then outputs, each with a prefix inferred
        # from its name.
        for fields in (input_fields, output_fields):
            for name, field in fields.items():
                field.finalize(name, infer_prefix(name))

        return dsp.Template(instructions, **input_fields, **output_fields)

    def reset(self):
        """Clear the LM override and the traces/train/demos lists."""
        # None means forward() falls back to the globally configured LM.
        self.lm = None
        self.traces = []
        self.train = []
        # Default demonstrations used by forward() unless `demos=` is passed.
        self.demos = []

    def dump_state(self):
        """Return a dict snapshot of this predictor's state.

        Contains ``lm``, ``traces``, ``train`` and ``demos`` (the objects
        themselves, not copies), plus the signature's instructions and the
        ``name`` of its last field. ``load_state`` restores both.
        """
        state = {key: getattr(self, key) for key in ("lm", "traces", "train", "demos")}
        state["signature_instructions"] = self.signature.instructions
        state["signature_prefix"] = self.signature.fields[-1].name
        return state

    def load_state(self, state):
        """Restore a snapshot produced by ``dump_state``.

        Every key in ``state`` is set as an attribute on ``self``, including
        the two ``signature_*`` keys. Those two are additionally written back
        into ``self.signature``.
        """
        for name, value in state.items():
            setattr(self, name, value)

        if "signature_instructions" in state:
            self.signature.instructions = state["signature_instructions"]

        if "signature_prefix" in state:
            last_field = self.signature.fields[-1]
            # Fields are replaced rather than mutated (presumably namedtuples,
            # given `_replace`).
            self.signature.fields[-1] = last_field._replace(name=state["signature_prefix"])

    def __call__(self, **kwargs):
        return self.forward(**kwargs)

    def forward(self, **kwargs):
        """Run the LM on the given input fields and return a ``Prediction``.

        The following reserved keyword arguments are removed from ``kwargs``.
        Everything left over is treated as input fields.

            new_signature: mapping of field names to fields. When given, a new
                template is built from it, reusing the current instructions.
            signature: template to use instead of ``self.signature``.
            demos: demonstrations to use instead of ``self.demos``.
            config: per-call generation config; its keys override ``self.config``.
            lm: LM to use for this call instead of ``self.lm``.
            _trace: pass False to skip recording this call in
                ``dsp.settings.trace``.

        Returns:
            A ``Prediction`` built from one dict per completion, each holding
            the signature's output variables that were not supplied as inputs.
        """
        new_signature = kwargs.pop("new_signature", None)
        signature = kwargs.pop("signature", self.signature)
        demos = kwargs.pop("demos", self.demos)
        # Per-call config overrides instance config. The former
        # dict(**a, **b) raised TypeError whenever both defined the same key.
        config = {**self.config, **(kwargs.pop("config", None) or {})}
        call_lm = kwargs.pop("lm", self.lm)
        # Pop before building the example so `_trace` is not sent as an input.
        record_trace = kwargs.pop("_trace", True)

        # LM whose default kwargs decide the sampling settings below.
        lm = call_lm or dsp.settings.lm

        temperature = config.get("temperature")
        if temperature is None:
            temperature = lm.kwargs.get("temperature")

        num_generations = config.get("n")
        if num_generations is None:
            num_generations = lm.kwargs.get("n", lm.kwargs.get("num_generations"))
        if num_generations is None:
            num_generations = 1

        # When several generations are requested at a low (or unknown)
        # temperature, raise the temperature -- presumably so the samples
        # differ from one another.
        if (temperature is None or temperature <= 0.15) and num_generations > 1:
            config["temperature"] = 0.7

        example = dsp.Example(demos=demos, **kwargs)

        if new_signature is not None:
            signature = dsp.Template(signature.instructions, **new_signature)

        if call_lm is None:
            _, candidates = dsp.generate(signature, **config)(example, stage=self.stage)
        else:
            # A dedicated LM (per-call or self.lm) runs inside a settings
            # context with query_only=True.
            with dsp.settings.context(lm=call_lm, query_only=True):
                _, candidates = dsp.generate(signature, **config)(example, stage=self.stage)

        completions = [
            {
                field.output_variable: getattr(candidate, field.output_variable)
                for field in signature.fields
                if field.output_variable not in kwargs
            }
            for candidate in candidates
        ]

        pred = Prediction.from_completions(completions, signature=signature)

        if record_trace and dsp.settings.trace is not None:
            dsp.settings.trace.append((self, {**kwargs}, pred))

        return pred

    def update_config(self, **kwargs):
        """Merge ``kwargs`` into the stored generation config (new keys win)."""
        self.config = {**self.config, **kwargs}

    def get_config(self):
        """Return the stored generation config dict (not a copy)."""
        return self.config

    def __repr__(self):
        return f"{self.__class__.__name__}({self.signature})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|