File size: 4,171 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import dsp
import dspy
from ..primitives.program import Module
from .predict import Predict

# TODO: Simplify a lot.
# TODO: Divide Action and Action Input like langchain does for ReAct.


class ReAct(Module):
    """Module that answers by interleaving Thought, Action and Observation steps.

    One predictor is built per iteration (``self.react[i]``); the predictor for
    step ``i`` exposes the Thought/Action/Observation fields of all earlier
    steps plus the Thought/Action fields of the current one. Each generated
    Action is parsed as ``Name[input]`` and dispatched to the tool registered
    under ``Name``; the special ``Finish[...]`` action ends the loop and its
    bracketed content becomes the value of the output field.
    """

    def __init__(self, signature, max_iters=5, num_results=3, tools=None):
        """Build the instructions and the per-iteration predictors.

        Args:
            signature: Anything accepted by ``dspy.Predict``; the resulting
                ``.signature`` is stored. Its last field is treated as the
                single output, all preceding fields as inputs.
            max_iters: Maximum number of Thought/Action rounds.
            num_results: ``k`` passed to the default ``dspy.Retrieve`` tool
                when ``tools`` is not given.
            tools: Optional list of tools (or a ``{name: tool}`` mapping). Each
                tool must expose ``name``, ``input_variable`` and ``desc`` and
                be callable with the action input string.
        """
        super().__init__()
        self.signature = signature = dspy.Predict(signature).signature
        self.max_iters = max_iters

        tools = tools or [dspy.Retrieve(k=num_results)]
        if isinstance(tools, dict):
            # Copy so that registering 'Finish' below never mutates the caller's dict.
            self.tools = dict(tools)
        else:
            self.tools = {tool.name: tool for tool in tools}

        self.input_fields = {k: v for k, v in self.signature.kwargs.items() if isinstance(v, dspy.InputField)}
        self.output_fields = {k: v for k, v in self.signature.kwargs.items() if isinstance(v, dspy.OutputField)}

        # The last signature field is the answer; everything before it is input.
        inputs, outputs = signature.fields[:-1], signature.fields[-1:]
        assert len(outputs) == 1, "ReAct only supports one output field."

        inputs_ = ', '.join([f"`{field.input_variable}`" for field in inputs])
        outputs_ = ', '.join([f"`{field.output_variable}`" for field in outputs])

        instr = [
            f"You will be given {inputs_} and you will respond with {outputs_}.\n",
            "To do this, you will interleave Thought, Action, and Observation steps.\n",
            "Thought can reason about the current situation, and Action can be the following types:\n",
        ]

        # 'Finish' is a pseudo-tool: it is listed in the instructions but never
        # called; `act` intercepts it before any tool dispatch.
        self.tools['Finish'] = dspy.Example(name="Finish", input_variable=outputs_.strip('`'), desc=f"returns the final {outputs_} and finishes the task")

        for idx, tool in enumerate(self.tools.values(), start=1):
            instr.append(f"({idx}) {tool.name}[{tool.input_variable}], which {tool.desc}")

        instr = '\n'.join(instr)
        self.react = [Predict(dsp.Template(instr, **self._generate_signature(i))) for i in range(1, max_iters + 1)]

    def _generate_signature(self, iters):
        """Return the field dict for the predictor used at step ``iters``.

        The dict holds the input fields followed by ``Thought_j`` / ``Action_j``
        for every ``j`` in ``1..iters`` and ``Observation_j`` for every
        ``j < iters`` (the current step has no observation yet).
        """
        signature_dict = dict(self.input_fields)

        # The list of callable tools does not depend on the step, so build it once.
        tool_list = ' or '.join([f"{tool.name}[{tool.input_variable}]" for tool in self.tools.values() if tool.name != 'Finish'])

        for j in range(1, iters + 1):
            signature_dict[f"Thought_{j}"] = dspy.OutputField(prefix=f"Thought {j}:", desc="next steps to take based on last observation")
            signature_dict[f"Action_{j}"] = dspy.OutputField(prefix=f"Action {j}:", desc=f"always either {tool_list} or, when done, Finish[answer]")

            if j < iters:
                signature_dict[f"Observation_{j}"] = dspy.OutputField(prefix=f"Observation {j}:", desc="observations based on action", format=dsp.passages2text)

        return signature_dict

    def act(self, output, hop):
        """Parse and execute the action generated at zero-based step ``hop``.

        Reads ``output["Action_{hop+1}"]``, expected as ``Name[input]`` on its
        first line.

        Returns:
            The text inside ``Finish[...]`` when the action is ``Finish``
            (possibly an empty string); otherwise ``None``, after storing the
            tool result (or a failure message) in
            ``output["Observation_{hop+1}"]``.
        """
        step = hop + 1
        try:
            action = output[f"Action_{step}"]
            action_name, action_val = action.strip().split('\n')[0].split('[', 1)
            # Tolerate "Name [input]" style spacing before the bracket.
            action_name = action_name.strip()
            action_val = action_val.rsplit(']', 1)[0]

            if action_name == 'Finish':
                return action_val

            # Invoke the tool exactly once; fall back to the raw result when it
            # has no `passages` attribute.
            # TODO: This is a hacky way to handle this. Need to fix this.
            result = self.tools[action_name](action_val)
            output[f"Observation_{step}"] = getattr(result, 'passages', result)

        except Exception:
            # Deliberately best-effort: any parsing or tool failure is reported
            # back as the observation instead of aborting the whole run.
            output[f"Observation_{step}"] = "Failed to parse action. Bad formatting or incorrect action name."

        return None

    def forward(self, **kwargs):
        """Run up to ``max_iters`` rounds and return the final prediction.

        Only keyword arguments that match the signature's input fields are
        forwarded to the predictors. The returned ``dspy.Prediction`` carries
        the first output field, set to the ``Finish`` answer or to ``''`` when
        no ``Finish`` action was produced.
        """
        args = {key: kwargs[key] for key in self.input_fields if key in kwargs}

        # Pre-bind so the return below is valid even when max_iters is 0.
        action_val = None

        for hop in range(self.max_iters):
            output = self.react[hop](**args)

            action_val = self.act(output, hop)
            # `is not None` (not truthiness): an empty Finish[] must still stop.
            if action_val is not None:
                break
            args.update(output)

        # assumes only 1 output field for now - TODO: handling for multiple output fields
        return dspy.Prediction(**{list(self.output_fields.keys())[0]: action_val or ''})