File size: 5,293 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import dsp
import random

from dspy.predict.parameter import Parameter
from dspy.primitives.prediction import Prediction
from dspy.signatures.field import InputField, OutputField
from dspy.signatures.signature import infer_prefix


class Predict(Parameter):
    """Parameter that queries a language model according to a signature.

    The signature is either an object exposing ``instructions`` and ``fields``
    (e.g. a ``dsp.Template``) or a shorthand string of the form
    ``"input_a, input_b -> output_a"``, which is expanded into a
    ``dsp.Template`` at construction time.
    """

    def __init__(self, signature, **config):
        """Store the signature and default generation config.

        Args:
            signature: A signature object, or a string ``"in1, in2 -> out1"``
                whose comma-separated names must each be a single token.
            **config: Default keyword arguments forwarded to ``dsp.generate``
                on every call (can be overridden per call via ``config=``).

        Raises:
            ValueError: If a string signature does not contain exactly one
                ``->`` (raised by the tuple unpacking of ``split``).
            AssertionError: If any field name contains whitespace or is empty.
        """
        # Random identifier passed to dsp.generate as the `stage` of this predictor.
        self.stage = random.randbytes(8).hex()
        self.signature = signature  # .signature
        self.config = config
        self.reset()

        # If the signature is a string, build a dsp.Template out of it.
        if isinstance(signature, str):
            inputs, outputs = signature.split("->")
            inputs = [field.strip() for field in inputs.split(",")]
            outputs = [field.strip() for field in outputs.split(",")]

            assert all(len(field.split()) == 1 for field in (inputs + outputs)), (
                "each field name in a string signature must be a single token"
            )

            inputs_ = ', '.join([f"`{field}`" for field in inputs])
            outputs_ = ', '.join([f"`{field}`" for field in outputs])

            instructions = f"""Given the fields {inputs_}, produce the fields {outputs_}."""

            inputs = {k: InputField() for k in inputs}
            outputs = {k: OutputField() for k in outputs}

            # Give every field its name and a prefix inferred from that name.
            for k, v in inputs.items():
                v.finalize(k, infer_prefix(k))

            for k, v in outputs.items():
                v.finalize(k, infer_prefix(k))

            self.signature = dsp.Template(instructions, **inputs, **outputs)

    def reset(self):
        """Clear the bound LM and all recorded traces, training data and demos."""
        self.lm = None
        self.traces = []
        self.train = []
        self.demos = []

    def dump_state(self):
        """Return a dict snapshot of this predictor's state.

        The snapshot holds ``lm``, ``traces``, ``train`` and ``demos`` plus two
        signature entries: ``signature_instructions`` and ``signature_prefix``
        (the ``name`` of the signature's last field).
        """
        state_keys = ["lm", "traces", "train", "demos"]
        state = {k: getattr(self, k) for k in state_keys}

        # Cache the signature instructions and the last field's name.
        state["signature_instructions"] = self.signature.instructions
        state["signature_prefix"] = self.signature.fields[-1].name

        return state

    def load_state(self, state):
        """Restore state previously produced by ``dump_state``.

        Every entry of ``state`` is set as an attribute on ``self`` (this
        includes the two ``signature_*`` entries), after which the current
        signature is patched with the saved instructions and last-field name.
        """
        for name, value in state.items():
            setattr(self, name, value)

        # Reconstruct the signature.
        if "signature_instructions" in state:
            instructions = state["signature_instructions"]
            self.signature.instructions = instructions

        if "signature_prefix" in state:
            prefix = state["signature_prefix"]
            # `_replace` returns a new field, so it has to be written back.
            self.signature.fields[-1] = self.signature.fields[-1]._replace(name=prefix)

    def __call__(self, **kwargs):
        """Alias for ``forward``."""
        return self.forward(**kwargs)

    def forward(self, **kwargs):
        """Generate completions for the given inputs and wrap them in a Prediction.

        Privileged keyword arguments (removed from ``kwargs`` before the rest
        is treated as example fields):
            new_signature: Mapping of fields used to build a replacement
                ``dsp.Template`` that keeps the current instructions.
            signature: Signature to use instead of ``self.signature``.
            demos: Demonstrations to use instead of ``self.demos``.
            config: Per-call generation settings; these override the settings
                given at construction time.
            lm: LM whose ``kwargs`` are consulted for temperature / n defaults.
            _trace: If false, do not append this call to ``dsp.settings.trace``.

        Returns:
            The ``Prediction`` built from the generated completions.
        """
        # Extract the privileged keyword arguments.
        new_signature = kwargs.pop("new_signature", None)
        signature = kwargs.pop("signature", self.signature)
        demos = kwargs.pop("demos", self.demos)
        # Merge with a dict display: dict(**a, **b) raises TypeError as soon as
        # a per-call key repeats a constructor key, which is exactly the
        # override case this argument exists for.
        config = {**self.config, **kwargs.pop("config", {})}
        # Pop the trace flag up front so it is not passed into the example below.
        record_trace = kwargs.pop("_trace", True)

        # Get the right LM to use.
        # NOTE(review): a per-call `lm` only influences the temperature / n
        # lookup below; generation itself runs under `self.lm` or the ambient
        # dsp settings. Confirm whether that is intended.
        lm = kwargs.pop("lm", self.lm) or dsp.settings.lm

        # If temperature is 0.0 but its n > 1, set temperature to 0.7.
        temperature = config.get("temperature", None)
        if temperature is None:
            # The LM may not define a temperature at all; treat that as None.
            temperature = lm.kwargs.get("temperature", None)

        num_generations = config.get("n", None)
        if num_generations is None:
            num_generations = lm.kwargs.get('n', lm.kwargs.get('num_generations', None))
        if num_generations is None:
            # Nothing configured anywhere: a single generation, and avoids
            # comparing None with an int below.
            num_generations = 1

        if (temperature is None or temperature <= 0.15) and num_generations > 1:
            config["temperature"] = 0.7
            # print(f"#> Setting temperature to 0.7 since n={num_generations} and prior temperature={temperature}.")

        # All of the other kwargs are presumed to fit a prefix of the signature.
        x = dsp.Example(demos=demos, **kwargs)

        if new_signature is not None:
            signature = dsp.Template(signature.instructions, **new_signature)

        if self.lm is None:
            x, C = dsp.generate(signature, **config)(x, stage=self.stage)
        else:
            with dsp.settings.context(lm=self.lm, query_only=True):
                # print(f"using lm = {self.lm} !")
                x, C = dsp.generate(signature, **config)(x, stage=self.stage)

        # For each completion keep only the fields the caller did not supply.
        completions = [
            {
                field.output_variable: getattr(c, field.output_variable)
                for field in signature.fields
                if field.output_variable not in kwargs
            }
            for c in C
        ]

        pred = Prediction.from_completions(completions, signature=signature)

        if record_trace and dsp.settings.trace is not None:
            trace = dsp.settings.trace
            trace.append((self, {**kwargs}, pred))

        return pred

    def update_config(self, **kwargs):
        """Merge ``kwargs`` into the default config; new values win on conflict."""
        self.config = {**self.config, **kwargs}

    def get_config(self):
        """Return the default generation config dict."""
        return self.config

    def __repr__(self):
        return f"{self.__class__.__name__}({self.signature})"



# TODO: get some defaults during init from the context window?
# TODO: FIXME: Hmm, I guess the expected behavior is that contexts can
# affect execution. We need to determine whether context dominates, __init__ dominates, or forward dominates.
# Generally, unless overwritten, we'd see n=None, temperature=None.
# That will eventually mean we have to learn them.