import copy
import random

import dsp
import dspy
from dspy.predict.parameter import Parameter
from dspy.predict.predict import Predict
from dspy.primitives.prediction import Prediction
from dspy.signatures.field import InputField, OutputField
from dspy.signatures.signature import infer_prefix
from langchain_core.pydantic_v1 import Extra
from langchain_core.runnables import Runnable


# NOTE(review): the docstring below enumerates "(i), (ii), (iv)". It is kept
# verbatim because it is presumably consumed as the LM instructions for this
# signature, so editing it would change the prompt -- confirm before fixing.
class Template2Signature(dspy.Signature):
    """You are a processor for prompts. I will give you a prompt template (Python f-string) for an arbitrary task for other LMs. Your job is to prepare three modular pieces: (i) any essential task instructions or guidelines, (ii) a list of variable names for inputs, (iv) the variable name for output."""

    template = dspy.InputField(format=lambda x: f"```\n\n{x.strip()}\n\n```\n\nLet's now prepare three modular pieces.")
    essential_instructions = dspy.OutputField()
    input_keys = dspy.OutputField(desc='comma-separated list of valid variable names')
    output_key = dspy.OutputField(desc='a valid variable name')


class ShallowCopyOnly:
    """Attribute-forwarding wrapper whose deep copy only shallow-copies the wrapped object."""

    def __init__(self, obj):
        self.obj = obj

    def __getattr__(self, item):
        """Forward attribute lookups that fail on the wrapper to the wrapped object."""
        # __getattr__ only runs after normal lookup fails. If `obj` itself is
        # missing (e.g. on the blank instance copy/pickle create before state
        # is restored), touching `self.obj` would re-enter this method and
        # recurse until RecursionError. Fail cleanly instead.
        if item == "obj":
            raise AttributeError(item)
        return getattr(self.obj, item)

    def __deepcopy__(self, memo):
        """Return a new wrapper around a shallow copy of the wrapped object."""
        return ShallowCopyOnly(copy.copy(self.obj))


class LangChainPredict(Predict, Runnable):  #, RunnableBinding):
    """Predict/Runnable hybrid that renders a dsp.Template prompt and sends it to a LangChain LLM."""

    class Config:
        extra = Extra.allow  # Allow extra attributes that are not defined in the model

    def __init__(self, prompt, llm, **config):
        """Wrap `llm` and derive a signature from `prompt`.

        Args:
            prompt: object exposing either `.messages` (each with
                `.prompt.template`) or a single `.template` string.
            llm: object with an `invoke(prompt, **config)` method.
            **config: default keyword arguments passed to `llm.invoke`.
        """
        Runnable.__init__(self)
        Parameter.__init__(self)

        self.langchain_llm = ShallowCopyOnly(llm)

        # Multi-message prompts are flattened into one newline-joined template;
        # anything without `.messages` falls back to its single `.template`.
        try:
            langchain_template = '\n'.join([msg.prompt.template for msg in prompt.messages])
        except AttributeError:
            langchain_template = prompt.template

        self.stage = random.randbytes(8).hex()
        self.signature, self.output_field_key = self._build_signature(langchain_template)
        self.config = config
        self.reset()

    def reset(self):
        """Clear the LM, traces, training data and demos."""
        self.lm = None
        self.traces = []
        self.train = []
        self.demos = []

    def dump_state(self):
        """Return a dict with the current `lm`, `traces`, `train` and `demos` attributes."""
        state_keys = ["lm", "traces", "train", "demos"]
        return {k: getattr(self, k) for k in state_keys}

    def load_state(self, state):
        """Set every entry of `state` as an attribute, then rebuild demos as `dspy.Example`s."""
        for name, value in state.items():
            setattr(self, name, value)

        self.demos = [dspy.Example(**x) for x in self.demos]

    def __call__(self, *arg, **kwargs):
        """Call `forward`; a first positional mapping is merged under the keyword arguments."""
        if arg:
            kwargs = {**arg[0], **kwargs}
        return self.forward(**kwargs)

    def _build_signature(self, template):
        """Ask an LM to split `template` into instructions, input keys and an output key.

        Returns:
            A `(dsp.Template, output_field_key)` tuple, where the key is the
            last output field name reported by the LM.

        Raises:
            ValueError: if the LM response names no output field.
        """
        extractor_lm = dspy.OpenAI(model='gpt-4-1106-preview', max_tokens=4000, model_type='chat')

        with dspy.context(lm=extractor_lm):
            parts = dspy.Predict(Template2Signature)(template=template)

        # Drop blank names (empty response, trailing comma) so no field ends up
        # registered under the empty string.
        inputs = {k.strip(): InputField() for k in parts.input_keys.split(',') if k.strip()}
        outputs = {k.strip(): OutputField() for k in parts.output_key.split(',') if k.strip()}

        if not outputs:
            raise ValueError(
                f"Could not determine an output key for the template (got {parts.output_key!r})."
            )

        for k, v in inputs.items():
            v.finalize(k, infer_prefix(k))  # TODO: Generate from the template at dspy.Predict(Template2Signature)

        for k, v in outputs.items():
            output_field_key = k
            v.finalize(k, infer_prefix(k))

        return dsp.Template(parts.essential_instructions, **inputs, **outputs), output_field_key

    def forward(self, **kwargs):
        """Render the prompt, invoke the LangChain LLM and record the prediction.

        The keyword arguments `signature`, `demos` and `config` are consumed
        here; the rest are passed to the template as example fields.

        Returns:
            The raw object returned by `langchain_llm.invoke` (not the
            `Prediction`, which is only recorded in history/trace).
        """
        # Extract the three privileged keyword arguments.
        signature = kwargs.pop("signature", self.signature)
        demos = kwargs.pop("demos", self.demos)
        # A dict display lets per-call config override the defaults;
        # dict(**a, **b) raises TypeError when both define the same key.
        config = {**self.config, **kwargs.pop("config", {})}

        prompt = signature(dsp.Example(demos=demos, **kwargs))
        output = self.langchain_llm.invoke(prompt, **config)

        # Use `.content` when the output has it; otherwise use the output as is.
        try:
            content = output.content
        except AttributeError:
            content = output

        pred = Prediction.from_completions([{self.output_field_key: content}], signature=signature)

        # print('#> len(demos) =', len(demos))
        # print(f"#> {prompt}")
        # print(f"#> PRED = {content}\n\n\n")
        dspy.settings.langchain_history.append((prompt, pred))

        if dsp.settings.trace is not None:
            trace = dsp.settings.trace
            trace.append((self, {**kwargs}, pred))

        return output

    def invoke(self, d, *args, **kwargs):
        """Call `forward` with the mapping `d` unpacked; extra arguments are ignored."""
        # print(d)
        return self.forward(**d)


# Almost good but need output parsing for the fields!
# TODO: Use template.extract(example, p)

# class LangChainOfThought(LangChainPredict):
#     def __init__(self, signature, **config):
#         super().__init__(signature, **config)
#
#         signature = self.signature
#         *keys, last_key = signature.kwargs.keys()
#         rationale_type = dsp.Type(prefix="Reasoning: Let's think step by step in order to",
#                                   desc="${produce the " + last_key + "}. We ...")
#
#         extended_kwargs = {key: signature.kwargs[key] for key in keys}
#         extended_kwargs.update({"rationale": rationale_type, last_key: signature.kwargs[last_key]})
#
#         self.extended_signature = dsp.Template(signature.instructions, **extended_kwargs)
#
#     def forward(self, **kwargs):
#         signature = self.extended_signature
#         return super().forward(signature=signature, **kwargs)


class LangChainModule(dspy.Module):
    """dspy.Module wrapper around a chain that exposes its LangChainPredict nodes."""

    def __init__(self, lcel):
        """Store `lcel` and collect every LangChainPredict found in its graph nodes.

        Args:
            lcel: chain object providing `get_graph()` and `invoke(...)`.
        """
        super().__init__()

        # Keep graph order; forward() reads the output key of the last one.
        self.modules = [
            node.data
            for node in lcel.get_graph().nodes.values()
            if isinstance(node.data, LangChainPredict)
        ]
        self.chain = lcel

    def forward(self, **kwargs):
        """Invoke the chain with `kwargs` and wrap the result in a `dspy.Prediction`.

        The result is stored under 'output' and, when the chain contains at
        least one LangChainPredict, also under the last one's output field key.
        """
        output_keys = ['output']
        # A chain without any LangChainPredict still yields an 'output' field
        # instead of failing with IndexError on an empty module list.
        if self.modules:
            output_keys.append(self.modules[-1].output_field_key)

        output = self.chain.invoke(dict(**kwargs))

        # Use `.content` when the output has it; only a missing attribute is
        # tolerated, so unrelated errors are no longer silently swallowed.
        try:
            output = output.content
        except AttributeError:
            pass

        return dspy.Prediction({k: output for k in output_keys})

    def invoke(self, d, *args, **kwargs):
        """Run `forward` on the mapping `d` and return its `output` field; extra arguments are ignored."""
        return self.forward(**d).output