File size: 5,293 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import dsp
import random
from dspy.predict.parameter import Parameter
from dspy.primitives.prediction import Prediction
from dspy.signatures.field import InputField, OutputField
from dspy.signatures.signature import infer_prefix
class Predict(Parameter):
    """Learnable prediction module that fills a signature's output fields with an LM.

    ``signature`` is either an existing ``dsp.Template`` or a compact string
    such as ``"question, context -> answer"``. For a string, a
    ``dsp.Template`` is built with one ``InputField`` per name left of ``->``
    and one ``OutputField`` per name on the right.

    Extra keyword arguments given to the constructor are stored as the
    default generation ``config`` and forwarded to ``dsp.generate``.
    """

    def __init__(self, signature, **config):
        # Random per-instance identifier, passed as ``stage`` to ``dsp.generate``.
        self.stage = random.randbytes(8).hex()
        self.signature = signature
        self.config = config
        self.reset()

        if isinstance(signature, str):
            # Raises ValueError (tuple unpacking) unless there is exactly one "->".
            inputs, outputs = signature.split("->")
            input_names = [name.strip() for name in inputs.split(",")]
            output_names = [name.strip() for name in outputs.split(",")]

            # Each field name must be exactly one whitespace-free token.
            # An explicit raise (not ``assert``) so the check survives ``python -O``.
            for name in input_names + output_names:
                if len(name.split()) != 1:
                    raise ValueError(
                        f"Invalid field name {name!r} in signature {signature!r}: "
                        "field names must be single, non-empty tokens."
                    )

            quoted_inputs = ", ".join(f"`{name}`" for name in input_names)
            quoted_outputs = ", ".join(f"`{name}`" for name in output_names)
            instructions = f"Given the fields {quoted_inputs}, produce the fields {quoted_outputs}."

            input_fields = {name: InputField() for name in input_names}
            output_fields = {name: OutputField() for name in output_names}
            for name, field in (*input_fields.items(), *output_fields.items()):
                field.finalize(name, infer_prefix(name))

            self.signature = dsp.Template(instructions, **input_fields, **output_fields)

    def reset(self):
        """Clear the learnable state: per-module LM, traces, training set and demos."""
        self.lm = None
        self.traces = []
        self.train = []
        self.demos = []

    def dump_state(self):
        """Return a dict snapshot of this module's state.

        Contains ``lm``, ``traces``, ``train`` and ``demos``. These are the
        attribute objects themselves, not copies. The snapshot also holds the
        signature's instructions and the ``name`` of its last field, so that
        ``load_state`` can restore them.
        """
        state = {key: getattr(self, key) for key in ("lm", "traces", "train", "demos")}
        state["signature_instructions"] = self.signature.instructions
        state["signature_prefix"] = self.signature.fields[-1].name
        return state

    def load_state(self, state):
        """Restore state produced by ``dump_state``.

        Every key in ``state`` is set as an attribute on ``self``, including the
        two ``signature_*`` keys. If present, the signature's instructions and
        the ``name`` of its last field are then restored.
        """
        for name, value in state.items():
            setattr(self, name, value)

        # NOTE(review): this mutates the current signature object in place; if
        # it is shared with other modules, they will see the change too.
        if "signature_instructions" in state:
            self.signature.instructions = state["signature_instructions"]

        if "signature_prefix" in state:
            last_field = self.signature.fields[-1]
            self.signature.fields[-1] = last_field._replace(name=state["signature_prefix"])

    def __call__(self, **kwargs):
        return self.forward(**kwargs)

    def forward(self, **kwargs):
        """Generate completions for the given inputs and return a ``Prediction``.

        Privileged keyword arguments are removed before the rest are used as
        inputs:
            new_signature: mapping of field name -> field. When given, a new
                template is built from these fields and the current
                signature's instructions.
            signature: template to use instead of ``self.signature``.
            demos: demonstrations to use instead of ``self.demos``.
            config: per-call generation options. On key collisions these
                override the constructor ``config``.
            lm: LM to use for this call instead of ``self.lm``.
            _trace: if falsy, the call is not appended to
                ``dsp.settings.trace`` (default ``True``).

        All remaining keyword arguments are passed to ``dsp.Example`` as input
        values. Signature fields supplied this way are excluded from the
        returned completions.
        """
        new_signature = kwargs.pop("new_signature", None)
        signature = kwargs.pop("signature", self.signature)
        demos = kwargs.pop("demos", self.demos)
        # Per-call config wins. ``dict(**a, **b)`` would raise TypeError on any shared key.
        config = {**self.config, **kwargs.pop("config", {})}
        # LM explicitly selected for this call; None means "use dsp.settings.lm".
        explicit_lm = kwargs.pop("lm", self.lm)
        lm = explicit_lm or dsp.settings.lm
        # Pop before building the Example so it is not treated as an input field.
        record_trace = kwargs.pop("_trace", True)

        # If the temperature is (near) zero but n > 1, raise it to 0.7.
        # Presumably this is so that the n completions are not near-identical.
        temperature = config.get("temperature")
        if temperature is None:
            temperature = lm.kwargs.get("temperature")

        num_generations = config.get("n")
        if num_generations is None:
            num_generations = lm.kwargs.get("n", lm.kwargs.get("num_generations"))
        if num_generations is None:
            num_generations = 1  # nothing requested: assume a single generation

        if (temperature is None or temperature <= 0.15) and num_generations > 1:
            config["temperature"] = 0.7

        # All other kwargs are presumed to fill a prefix of the signature's fields.
        example = dsp.Example(demos=demos, **kwargs)

        if new_signature is not None:
            signature = dsp.Template(signature.instructions, **new_signature)

        if explicit_lm is None:
            _, raw_completions = dsp.generate(signature, **config)(example, stage=self.stage)
        else:
            with dsp.settings.context(lm=explicit_lm, query_only=True):
                _, raw_completions = dsp.generate(signature, **config)(example, stage=self.stage)

        # Keep only the fields the caller did not already supply as inputs.
        completions = [
            {
                field.output_variable: getattr(completion, field.output_variable)
                for field in signature.fields
                if field.output_variable not in kwargs
            }
            for completion in raw_completions
        ]

        pred = Prediction.from_completions(completions, signature=signature)

        if record_trace and dsp.settings.trace is not None:
            dsp.settings.trace.append((self, {**kwargs}, pred))

        return pred

    def update_config(self, **kwargs):
        """Merge ``kwargs`` into the default generation config (new keys win)."""
        self.config = {**self.config, **kwargs}

    def get_config(self):
        """Return the default generation config dict (not a copy)."""
        return self.config

    def __repr__(self):
        return f"{self.__class__.__name__}({self.signature})"
# TODO: get some defaults during init from the context window?
# TODO: FIXME: Hmm, I guess the expected behavior is that contexts can
# affect execution. We need to decide whether the context, __init__, or forward takes precedence.
# Generally, unless overwritten, we'd see n=None, temperature=None.
# That will eventually mean we have to learn them.
|