import copy
import random
import dsp
import dspy
from dspy.predict.parameter import Parameter
from dspy.predict.predict import Predict
from dspy.primitives.prediction import Prediction
from dspy.signatures.field import InputField, OutputField
from dspy.signatures.signature import infer_prefix
from langchain_core.pydantic_v1 import Extra
from langchain_core.runnables import Runnable

class Template2Signature(dspy.Signature):
    """You are a processor for prompts. I will give you a prompt template (Python f-string) for an arbitrary task for other LMs.
    Your job is to prepare three modular pieces: (i) any essential task instructions or guidelines, (ii) a list of variable names for inputs, (iii) the variable name for output."""

    template = dspy.InputField(format=lambda x: f"```\n\n{x.strip()}\n\n```\n\nLet's now prepare three modular pieces.")
    essential_instructions = dspy.OutputField()
    input_keys = dspy.OutputField(desc='comma-separated list of valid variable names')
    output_key = dspy.OutputField(desc='a valid variable name')
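
# Illustration (a hedged sketch; this comment is not part of the runtime path): given a
# prompt template such as "Given {context}, answer the question `{question}`.", a call like
#     dspy.Predict(Template2Signature)(template=...)
# is expected to return essential_instructions as free text, input_keys roughly
# "context, question", and output_key roughly "answer"; the exact strings depend on the LM
# that fills this signature (see _build_signature below).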

# Deep copies of this wrapper only shallow-copy the wrapped object (e.g., an LLM client),
# so cloning a module that holds one does not attempt to deep-copy the client itself.
class ShallowCopyOnly:
    def __init__(self, obj): self.obj = obj
    def __getattr__(self, item): return getattr(self.obj, item)
    def __deepcopy__(self, memo): return ShallowCopyOnly(copy.copy(self.obj))

class LangChainPredict(Predict, Runnable): #, RunnableBinding):
    class Config: extra = Extra.allow  # Allow extra attributes that are not defined in the model

    def __init__(self, prompt, llm, **config):
        Runnable.__init__(self)
        Parameter.__init__(self)

        # Wrap the LLM so deep copies of this module only shallow-copy the client.
        self.langchain_llm = ShallowCopyOnly(llm)

        # Flatten the LangChain prompt into a single template string.
        try: langchain_template = '\n'.join([msg.prompt.template for msg in prompt.messages])
        except AttributeError: langchain_template = prompt.template

        self.stage = random.randbytes(8).hex()
        self.signature, self.output_field_key = self._build_signature(langchain_template)
        self.config = config
        self.reset()
    def reset(self):
        self.lm = None
        self.traces = []
        self.train = []
        self.demos = []

    def dump_state(self):
        state_keys = ["lm", "traces", "train", "demos"]
        return {k: getattr(self, k) for k in state_keys}

    def load_state(self, state):
        for name, value in state.items():
            setattr(self, name, value)

        self.demos = [dspy.Example(**x) for x in self.demos]

    def __call__(self, *arg, **kwargs):
        if len(arg) > 0: kwargs = {**arg[0], **kwargs}
        return self.forward(**kwargs)
    def _build_signature(self, template):
        # At construction time, use a strong LM to split the template into
        # instructions, input field names, and the output field name.
        gpt4T = dspy.OpenAI(model='gpt-4-1106-preview', max_tokens=4000, model_type='chat')

        with dspy.context(lm=gpt4T): parts = dspy.Predict(Template2Signature)(template=template)

        inputs = {k.strip(): InputField() for k in parts.input_keys.split(',')}
        outputs = {k.strip(): OutputField() for k in parts.output_key.split(',')}

        for k, v in inputs.items():
            v.finalize(k, infer_prefix(k))  # TODO: Generate from the template at dspy.Predict(Template2Signature)

        for k, v in outputs.items():
            output_field_key = k
            v.finalize(k, infer_prefix(k))

        return dsp.Template(parts.essential_instructions, **inputs, **outputs), output_field_key
    def forward(self, **kwargs):
        # Extract the three privileged keyword arguments.
        signature = kwargs.pop("signature", self.signature)
        demos = kwargs.pop("demos", self.demos)
        config = dict(**self.config, **kwargs.pop("config", {}))

        prompt = signature(dsp.Example(demos=demos, **kwargs))
        output = self.langchain_llm.invoke(prompt, **config)

        try: content = output.content
        except AttributeError: content = output

        pred = Prediction.from_completions([{self.output_field_key: content}], signature=signature)

        # print('#> len(demos) =', len(demos))
        # print(f"#> {prompt}")
        # print(f"#> PRED = {content}\n\n\n")
        dspy.settings.langchain_history.append((prompt, pred))

        if dsp.settings.trace is not None:
            trace = dsp.settings.trace
            trace.append((self, {**kwargs}, pred))

        return output

    def invoke(self, d, *args, **kwargs):
        # print(d)
        return self.forward(**d)
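
# Note (hedged): LangChainPredict is simultaneously a DSPy Predict parameter (so DSPy
# optimizers can attach demos to it) and a LangChain Runnable (so it can sit inside an
# LCEL pipeline). Its Runnable entry point, invoke(d), simply dispatches to forward(**d);
# a full chain-level usage sketch appears at the end of this file.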

# Almost good but need output parsing for the fields!
# TODO: Use template.extract(example, p)
# class LangChainOfThought(LangChainPredict):
#     def __init__(self, signature, **config):
#         super().__init__(signature, **config)
#
#         signature = self.signature
#         *keys, last_key = signature.kwargs.keys()
#         rationale_type = dsp.Type(prefix="Reasoning: Let's think step by step in order to",
#                                   desc="${produce the " + last_key + "}. We ...")
#
#         extended_kwargs = {key: signature.kwargs[key] for key in keys}
#         extended_kwargs.update({"rationale": rationale_type, last_key: signature.kwargs[last_key]})
#
#         self.extended_signature = dsp.Template(signature.instructions, **extended_kwargs)
#
#     def forward(self, **kwargs):
#         signature = self.extended_signature
#         return super().forward(signature=signature, **kwargs)

class LangChainModule(dspy.Module):
    def __init__(self, lcel):
        super().__init__()

        modules = []
        for name, node in lcel.get_graph().nodes.items():
            if isinstance(node.data, LangChainPredict): modules.append(node.data)

        self.modules = modules
        self.chain = lcel

    def forward(self, **kwargs):
        output_keys = ['output', self.modules[-1].output_field_key]
        output = self.chain.invoke(dict(**kwargs))

        try: output = output.content
        except Exception: pass

        return dspy.Prediction({k: output for k in output_keys})

    def invoke(self, d, *args, **kwargs):
        return self.forward(**d).output
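
# End-to-end usage sketch (hedged; the prompt, model names, and chain shape below are
# illustrative assumptions, not mandated by this module):
#
#     from langchain_openai import ChatOpenAI
#     from langchain_core.prompts import PromptTemplate
#     from langchain_core.output_parsers import StrOutputParser
#
#     prompt = PromptTemplate.from_template("Given {context}, answer the question `{question}`.")
#     llm = ChatOpenAI(model="gpt-3.5-turbo")
#
#     # Build an LCEL chain whose LM step is a LangChainPredict, then wrap the chain as a
#     # DSPy module so its predictors (and their demos) are visible to DSPy optimizers.
#     chain = LangChainPredict(prompt, llm) | StrOutputParser()
#     module = LangChainModule(chain)
#
#     module.invoke({"context": "...", "question": "..."})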