File size: 5,952 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import copy
import random

import dsp
import dspy

from dspy.predict.parameter import Parameter
from dspy.predict.predict import Predict
from dspy.primitives.prediction import Prediction
from dspy.signatures.field import InputField, OutputField
from dspy.signatures.signature import infer_prefix

from langchain_core.pydantic_v1 import Extra
from langchain_core.runnables import Runnable


# Signature used by LangChainPredict._build_signature (in this file) to break a
# LangChain prompt template into three parts: task instructions, the names of
# the input variables, and the name of the output variable.
class Template2Signature(dspy.Signature):
    # NOTE(review): this docstring is presumably consumed by dspy as the prompt
    # instructions, so it is runtime text and is kept verbatim. The list is
    # numbered "(i), (ii), (iv)" -- "(iv)" looks like a typo for "(iii)"; confirm
    # before changing, since editing it alters the prompt sent to the LM.
    """You are a processor for prompts. I will give you a prompt template (Python f-string) for an arbitrary task for other LMs.
Your job is to prepare three modular pieces: (i) any essential task instructions or guidelines, (ii) a list of variable names for inputs, (iv) the variable name for output."""

    # Input: the raw template text. The format callable strips surrounding
    # whitespace, wraps the text in a triple-backtick fence, and appends the
    # "Let's now prepare three modular pieces." cue.
    template = dspy.InputField(format=lambda x: f"```\n\n{x.strip()}\n\n```\n\nLet's now prepare three modular pieces.")
    # Output: the task instructions; _build_signature passes this value to
    # dsp.Template as its first argument.
    essential_instructions = dspy.OutputField()
    # Output: input variable names; _build_signature splits this value on ','.
    input_keys = dspy.OutputField(desc='comma-separated list of valid variable names')
    # Output: output variable name; _build_signature also splits this on ','
    # and keeps the last entry as the output field key.
    output_key = dspy.OutputField(desc='a valid variable name')


class ShallowCopyOnly:
    """Proxy whose deep copy only shallow-copies the wrapped object.

    Attribute access that is not satisfied by the proxy itself is forwarded
    to the wrapped object, so the proxy can be used in its place for reads
    and method calls.  ``copy.deepcopy`` on the proxy returns a new proxy
    around ``copy.copy(obj)`` instead of recursing into ``obj``.

    Args:
        obj: The object to wrap; stored as ``self.obj``.
    """

    def __init__(self, obj):
        self.obj = obj

    def __getattr__(self, item):
        """Forward lookups that failed on the proxy to the wrapped object.

        Raises:
            AttributeError: If ``item`` is ``"obj"`` (the proxy has not been
                initialised yet) or the wrapped object lacks ``item``.
        """
        # __getattr__ only runs when normal lookup fails.  If "obj" itself is
        # missing (e.g. on the blank instance that copy.copy / pickle create
        # via __new__ before restoring state), reading self.obj below would
        # re-enter this method forever and end in RecursionError.
        if item == "obj":
            raise AttributeError(item)
        return getattr(self.obj, item)

    def __deepcopy__(self, memo):
        """Return a new proxy around a shallow copy of the wrapped object."""
        return ShallowCopyOnly(copy.copy(self.obj))


class LangChainPredict(Predict, Runnable): #, RunnableBinding):
    """Predictor that renders a DSPy template and sends it to a LangChain LLM.

    The LangChain prompt template is converted once, at construction time,
    into a ``dsp.Template`` (see ``_build_signature``).  Calls render that
    template with the supplied inputs and demos and pass the resulting prompt
    to ``llm.invoke``.

    Args:
        prompt: A LangChain prompt.  Either exposes ``messages`` whose items
            have ``prompt.template``, or exposes ``template`` directly.
        llm: Object with an ``invoke(prompt, **config)`` method.  It is
            wrapped in ``ShallowCopyOnly`` so deep copies of this predictor
            do not deep-copy the LLM.
        **config: Default keyword arguments passed to ``llm.invoke``.
    """

    class Config:
        extra = Extra.allow  # Allow extra attributes that are not defined in the model

    def __init__(self, prompt, llm, **config):
        Runnable.__init__(self)
        Parameter.__init__(self)

        self.langchain_llm = ShallowCopyOnly(llm)

        # Chat-style prompts carry one template per message; plain prompts
        # carry a single ``template`` attribute.
        try:
            langchain_template = '\n'.join([msg.prompt.template for msg in prompt.messages])
        except AttributeError:
            langchain_template = prompt.template

        self.stage = random.randbytes(8).hex()
        self.signature, self.output_field_key = self._build_signature(langchain_template)
        self.config = config
        self.reset()

    def reset(self):
        """Clear the LM override and the collected traces, train set and demos."""
        self.lm = None
        self.traces = []
        self.train = []
        self.demos = []

    def dump_state(self):
        """Return a dict with the ``lm``, ``traces``, ``train`` and ``demos`` attributes."""
        state_keys = ["lm", "traces", "train", "demos"]
        return {k: getattr(self, k) for k in state_keys}

    def load_state(self, state):
        """Set every ``state`` item as an attribute, then rebuild demos as ``dspy.Example``."""
        for name, value in state.items():
            setattr(self, name, value)

        self.demos = [dspy.Example(**x) for x in self.demos]

    def __call__(self, *arg, **kwargs):
        """Call ``forward``; a first positional mapping is merged under ``kwargs``."""
        if arg:
            kwargs = {**arg[0], **kwargs}
        return self.forward(**kwargs)

    def _build_signature(self, template):
        """Derive a ``dsp.Template`` and the output field name from ``template``.

        Runs ``Template2Signature`` on a GPT-4 Turbo client to extract the
        instructions, the input keys and the output key.

        Returns:
            A ``(dsp.Template, output_field_key)`` tuple.  When the extracted
            output key contains several comma-separated names, the last one
            is returned as ``output_field_key``.
        """
        gpt4T = dspy.OpenAI(model='gpt-4-1106-preview', max_tokens=4000, model_type='chat')

        with dspy.context(lm=gpt4T):
            parts = dspy.Predict(Template2Signature)(template=template)

        inputs = {k.strip(): InputField() for k in parts.input_keys.split(',')}
        outputs = {k.strip(): OutputField() for k in parts.output_key.split(',')}

        for k, v in inputs.items():
            v.finalize(k, infer_prefix(k))  # TODO: Generate from the template at dspy.Predict(Template2Signature)

        for k, v in outputs.items():
            output_field_key = k
            v.finalize(k, infer_prefix(k))

        return dsp.Template(parts.essential_instructions, **inputs, **outputs), output_field_key

    def forward(self, **kwargs):
        """Render the prompt, invoke the LangChain LLM and record the result.

        Args:
            **kwargs: Template inputs, plus the optional privileged keys
                ``signature`` (template override), ``demos`` (demo override)
                and ``config`` (per-call overrides for ``llm.invoke``).

        Returns:
            The raw value returned by ``llm.invoke`` (not the ``Prediction``);
            the ``Prediction`` is only appended to the history and the trace.
        """
        # Extract the three privileged keyword arguments.
        signature = kwargs.pop("signature", self.signature)
        demos = kwargs.pop("demos", self.demos)
        # Merge with a dict display so per-call values override the defaults;
        # dict(**a, **b) raises TypeError as soon as a key appears in both.
        config = {**self.config, **kwargs.pop("config", {})}

        prompt = signature(dsp.Example(demos=demos, **kwargs))
        output = self.langchain_llm.invoke(prompt, **config)

        # Message-like results expose the text as ``content``; otherwise use
        # the result itself.
        try:
            content = output.content
        except AttributeError:
            content = output

        pred = Prediction.from_completions([{self.output_field_key: content}], signature=signature)

        dspy.settings.langchain_history.append((prompt, pred))

        if dsp.settings.trace is not None:
            trace = dsp.settings.trace
            trace.append((self, {**kwargs}, pred))

        return output

    def invoke(self, d, *args, **kwargs):
        """Call ``forward`` with mapping ``d``; extra arguments are ignored."""
        return self.forward(**d)


# Almost good, but needs output parsing for the fields!
# TODO: Use template.extract(example, p)

# class LangChainOfThought(LangChainPredict):
#     def __init__(self, signature, **config):
#         super().__init__(signature, **config)

#         signature = self.signature
#         *keys, last_key = signature.kwargs.keys()
#         rationale_type = dsp.Type(prefix="Reasoning: Let's think step by step in order to",
#                                   desc="${produce the " + last_key + "}. We ...")

#         extended_kwargs = {key: signature.kwargs[key] for key in keys}
#         extended_kwargs.update({"rationale": rationale_type, last_key: signature.kwargs[last_key]})
#         self.extended_signature = dsp.Template(signature.instructions, **extended_kwargs)

#     def forward(self, **kwargs):
#         signature = self.extended_signature
#         return super().forward(signature=signature, **kwargs)


class LangChainModule(dspy.Module):
    """``dspy.Module`` wrapper around a LangChain (LCEL) chain.

    Args:
        lcel: The chain.  Must provide ``get_graph()`` (whose ``nodes``
            mapping is scanned for ``LangChainPredict`` instances) and
            ``invoke(dict)``.
    """

    def __init__(self, lcel):
        super().__init__()

        # Collect the LangChainPredict nodes in graph order; the node names
        # are not needed.
        self.modules = [
            node.data
            for node in lcel.get_graph().nodes.values()
            if isinstance(node.data, LangChainPredict)
        ]
        self.chain = lcel

    def forward(self, **kwargs):
        """Invoke the chain and return its output as a ``dspy.Prediction``.

        The output is stored under both ``'output'`` and the output field key
        of the last collected ``LangChainPredict``.

        Raises:
            IndexError: If the chain contains no ``LangChainPredict`` node.
        """
        output_keys = ['output', self.modules[-1].output_field_key]
        output = self.chain.invoke(dict(**kwargs))

        # Unwrap message-like results.  Only a missing ``content`` attribute
        # is tolerated, so other errors are no longer silently swallowed.
        try:
            output = output.content
        except AttributeError:
            pass

        return dspy.Prediction({k: output for k in output_keys})

    def invoke(self, d, *args, **kwargs):
        """Call ``forward`` with mapping ``d`` and return the ``output`` field."""
        return self.forward(**d).output