# Source: EureCA / dspy/predict/predict.py
# Last commit f5776d3 by tonneli ("Delete history"); provenance copied from the hosting page.
import dsp
import random
from dspy.predict.parameter import Parameter
from dspy.primitives.prediction import Prediction
from dspy.signatures.field import InputField, OutputField
from dspy.signatures.signature import infer_prefix
class Predict(Parameter):
    """A learnable prediction step that fills a signature's output fields with an LM.

    ``signature`` is either a template object (exposing ``instructions`` and
    ``fields``) or a compact string such as ``"question, context -> answer"``.
    A string is expanded into a ``dsp.Template`` with one ``InputField`` per
    input name and one ``OutputField`` per output name. Extra keyword
    arguments to the constructor become the default generation config that is
    forwarded to ``dsp.generate``.
    """

    def __init__(self, signature, **config):
        """Create a predictor for ``signature``.

        Args:
            signature: a template object, or a string ``"in1, in2 -> out1, out2"``.
            **config: default generation config forwarded to ``dsp.generate``.

        Raises:
            AssertionError: if a string signature contains a field name that is
                empty or contains whitespace.
            ValueError: if a string signature does not contain exactly one ``->``.
        """
        # Random 16-hex-character id, passed to dsp.generate as `stage=`.
        self.stage = random.randbytes(8).hex()
        self.signature = signature
        self.config = config
        self.reset()

        if isinstance(signature, str):
            input_names, output_names = signature.split("->")
            input_names = [name.strip() for name in input_names.split(",")]
            output_names = [name.strip() for name in output_names.split(",")]
            # Every field name must be exactly one whitespace-free token.
            assert all(len(name.split()) == 1 for name in input_names + output_names), (
                f"Invalid field name in signature {signature!r}."
            )

            inputs_ = ", ".join(f"`{name}`" for name in input_names)
            outputs_ = ", ".join(f"`{name}`" for name in output_names)
            instructions = f"Given the fields {inputs_}, produce the fields {outputs_}."

            inputs = {name: InputField() for name in input_names}
            outputs = {name: OutputField() for name in output_names}
            # Finalize inputs first, then outputs (same order as the field dicts).
            for name, field in [*inputs.items(), *outputs.items()]:
                field.finalize(name, infer_prefix(name))

            self.signature = dsp.Template(instructions, **inputs, **outputs)

    def reset(self):
        """Clear the LM override, traces, training examples and demonstrations."""
        self.lm = None
        self.traces = []
        self.train = []
        self.demos = []

    def dump_state(self):
        """Return a dict snapshot of this predictor's state.

        The snapshot holds ``lm``, ``traces``, ``train`` and ``demos``, plus the
        signature's instructions and the name of its last field. ``load_state``
        uses those two extra entries to restore the signature.
        """
        state = {key: getattr(self, key) for key in ["lm", "traces", "train", "demos"]}
        state["signature_instructions"] = self.signature.instructions
        # Stored under "signature_prefix", but the value is the last field's `name`.
        state["signature_prefix"] = self.signature.fields[-1].name
        return state

    def load_state(self, state):
        """Restore attributes from a dict produced by ``dump_state``.

        Every key in ``state`` is set as an attribute on ``self``. If present,
        the signature's instructions and the name of its last field are also
        written back into ``self.signature``.
        """
        for name, value in state.items():
            setattr(self, name, value)

        # Reconstruct the signature.
        if "signature_instructions" in state:
            self.signature.instructions = state["signature_instructions"]

        if "signature_prefix" in state:
            prefix = state["signature_prefix"]
            # `_replace` suggests fields are namedtuples; TODO confirm.
            self.signature.fields[-1] = self.signature.fields[-1]._replace(name=prefix)

    def __call__(self, **kwargs):
        """Alias for ``forward``."""
        return self.forward(**kwargs)

    def forward(self, **kwargs):
        """Run generation on ``kwargs`` and return a ``Prediction``.

        Privileged keyword arguments are removed before the rest are used as
        signature inputs:

            new_signature: mapping of field name -> field; if given, a new
                ``dsp.Template`` is built from it using the instructions of
                the selected signature.
            signature: template to use instead of ``self.signature``.
            demos: demonstrations to use instead of ``self.demos``.
            config: per-call settings merged over ``self.config`` (per-call wins).
            lm: LM whose ``kwargs`` supply default ``temperature`` / ``n``.
            _trace: when falsy, the call is not appended to ``dsp.settings.trace``.

        All remaining keyword arguments become fields of the ``dsp.Example``.
        The returned prediction holds only the signature fields that were not
        passed in as inputs.

        Raises:
            AssertionError: if neither ``lm``, ``self.lm`` nor
                ``dsp.settings.lm`` provides an LM.
        """
        new_signature = kwargs.pop("new_signature", None)
        signature = kwargs.pop("signature", self.signature)
        demos = kwargs.pop("demos", self.demos)
        # Per-call config overrides init-time config. `dict(**a, **b)` would
        # raise TypeError whenever both define the same key.
        config = {**self.config, **kwargs.pop("config", {})}
        # Pop `_trace` now so it is not passed through as an input field.
        record_trace = kwargs.pop("_trace", True)

        # Get the right LM to use.
        lm = kwargs.pop("lm", self.lm) or dsp.settings.lm
        assert lm is not None, "No LM is loaded."

        # When asking for several generations at a temperature at or below
        # 0.15 (or unset), raise it to 0.7 so the samples can differ.
        temperature = config.get("temperature")
        if temperature is None:
            temperature = lm.kwargs.get("temperature")
        num_generations = config.get("n")
        if num_generations is None:
            num_generations = lm.kwargs.get("n", lm.kwargs.get("num_generations"))
        low_temperature = temperature is None or temperature <= 0.15
        if low_temperature and num_generations is not None and num_generations > 1:
            config["temperature"] = 0.7

        # All of the other kwargs are presumed to fill a prefix of the signature.
        example = dsp.Example(demos=demos, **kwargs)

        if new_signature is not None:
            signature = dsp.Template(signature.instructions, **new_signature)

        # NOTE(review): generation runs under self.lm (or the ambient context
        # LM); the `lm` kwarg resolved above only feeds the temperature/n
        # defaults. Confirm whether that is intended.
        if self.lm is None:
            _, raw_completions = dsp.generate(signature, **config)(example, stage=self.stage)
        else:
            with dsp.settings.context(lm=self.lm, query_only=True):
                _, raw_completions = dsp.generate(signature, **config)(example, stage=self.stage)

        # Keep only the fields the caller did not supply.
        completions = [
            {
                field.output_variable: getattr(completion, field.output_variable)
                for field in signature.fields
                if field.output_variable not in kwargs
            }
            for completion in raw_completions
        ]
        pred = Prediction.from_completions(completions, signature=signature)

        if record_trace and dsp.settings.trace is not None:
            dsp.settings.trace.append((self, {**kwargs}, pred))

        return pred

    def update_config(self, **kwargs):
        """Merge ``kwargs`` into the default generation config (new keys win)."""
        self.config = {**self.config, **kwargs}

    def get_config(self):
        """Return the default generation config dict."""
        return self.config

    def __repr__(self):
        return f"{self.__class__.__name__}({self.signature})"
# TODO: get some defaults during init from the context window?
# TODO: FIXME: The expected behavior is presumably that contexts can
# affect execution. We still need to determine whether the context,
# __init__, or forward dominates.
# Generally, unless overwritten, we'd see n=None, temperature=None.
# That will eventually mean we have to learn them.