File size: 6,244 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import dsp
import dspy
from ..primitives.program import Module
from ..primitives.python_interpreter import CodePrompt, PythonInterpreter
import re

class ProgramOfThought(Module):
    """Answer a signature by having the LM write Python code, running it, and
    turning the code's output into the final answer.

    Pipeline: generate code -> execute it -> on error, regenerate (up to
    ``max_iters`` times) feeding back the failing code and error message ->
    ask the LM for the final output field given the working code and its
    output. Only signatures with exactly one output field are supported.
    """

    def __init__(self, signature, max_iters=3):
        """Build the three chain-of-thought sub-modules.

        Args:
            signature: anything ``dspy.Predict`` accepts as a signature; it is
                normalised through ``dspy.Predict(signature).signature``.
            max_iters: maximum number of code regenerations attempted after
                the initial generation fails.
        """
        super().__init__()
        self.signature = signature = dspy.Predict(signature).signature
        self.max_iters = max_iters

        self.input_fields = signature.input_fields()
        self.output_fields = signature.output_fields()

        # Must hold before the templates are built: the 'answer' template
        # copies the signature's single output field.
        assert len(self.output_fields) == 1, "PoT only supports one output field."

        self.code_generate = dspy.ChainOfThought(dsp.Template(self._generate_instruction('generate'), **self._generate_signature('generate')))
        self.code_regenerate = dspy.ChainOfThought(dsp.Template(self._generate_instruction('regenerate'), **self._generate_signature('regenerate')))
        self.generate_answer = dspy.ChainOfThought(dsp.Template(self._generate_instruction('answer'), **self._generate_signature('answer')))

    def _generate_signature(self, mode):
        """Return the field-name -> field mapping used by the template for *mode*.

        The mapping always starts with the signature's input fields, followed
        by fields specific to *mode* ('generate', 'regenerate' or 'answer').
        In 'answer' mode it ends with the signature's own output field, so the
        final prediction is produced under the caller's output field name.
        """
        signature_dict = dict(self.input_fields)
        fields_for_mode = {
            'generate': {
                'generated_code': dspy.OutputField(prefix="Code:", desc="python code that answers the question", format=str)
            },
            'regenerate': {
                'previous_code': dspy.InputField(prefix="Previous Code:", desc="previously-generated python code that errored", format=str),
                'error': dspy.InputField(prefix="Error:", desc="error message from previously-generated python code"),
                'generated_code': dspy.OutputField(prefix="Code:", desc="python code that answers the question", format=str)
            },
            'answer': {
                'final_generated_code': dspy.InputField(prefix="Code:", desc="python code that answers the question", format=str),
                'code_output': dspy.InputField(prefix="Code Output:", desc="output of previously-generated python code"),
            }
        }
        signature_dict.update(fields_for_mode[mode])
        if mode == 'answer':
            # Use the signature's real output field rather than assuming it is
            # called "answer"; __init__ guarantees there is exactly one.
            signature_dict.update(self.output_fields)
        return signature_dict

    def _generate_instruction(self, mode):
        """Return the newline-joined instruction text for the template of *mode*.

        The text names the template's input and output fields, as returned by
        ``_generate_signature(mode)``, in backticks.
        """
        fields = self._generate_signature(mode)
        mode_inputs = ', '.join(f"`{name}`" for name, field in fields.items() if isinstance(field, dspy.InputField))
        mode_outputs = ', '.join(f"`{name}`" for name, field in fields.items() if isinstance(field, dspy.OutputField))
        if mode == 'generate':
            instr = [
                f"You will be given {mode_inputs} and you will respond with {mode_outputs}.",
                f"Generating executable Python code that programmatically computes the correct {mode_outputs}.",
                f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {mode_outputs}."
            ]
        elif mode == 'regenerate':
            instr = [
                f"You are given {mode_inputs} due to an error in previous code.",
                f"Your task is to correct the error and provide the new {mode_outputs}."
            ]
        else:  # mode == 'answer'
            instr = [
                f"Given the final code {mode_inputs}, provide the final {mode_outputs}."
            ]

        return '\n'.join(instr)

    def parse_code(self, code_data):
        """Extract a runnable code snippet from an LM prediction.

        Steps:
        - Read ``code_data['generated_code']`` (default ``''``).
        - Truncate it at the first ``---`` and at the first blank-line run
          (three newlines).
        - Take the body of a fenced python block if one is present.
        - Turn literal ``\\n`` sequences into real newlines.

        If the last line of a multi-line snippet is ``name = ...``, ``name`` is
        appended so the snippet ends by evaluating that variable. Otherwise,
        assignments run together on one line are split onto separate lines.

        Returns:
            ``(code, None)`` on success. ``(unparsed_code, error_message)`` when
            the snippet is empty, or is a single line containing more than
            one ``=``.
        """
        code = code_data.get('generated_code', '').split('---', 1)[0].split('\n\n\n', 1)[0]
        code_match = re.search(r'```python[ \n](.*?)[ \n]```?', code, re.DOTALL)
        code_block = (code_match.group(1) if code_match else code).replace('\\n', '\n')
        if not code_block:
            return code, "Error: Empty code after parsing."
        if "\n" not in code_block and code_block.count('=') > 1:
            return code, "Error: Code format is not correct."
        lines = code_block.split('\n')
        last_line_match = re.match(r'^(\w+)\s*=', lines[-1].strip())
        if last_line_match and len(lines) > 1:
            # Make the assigned variable the snippet's final expression.
            code_block += '\n' + last_line_match.group(1)
        else:
            # Heuristically split "a = 1 b = 2"-style run-together assignments.
            code_block = re.sub(r'([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)', r'\1\n', code_block)
            code_block = re.sub(r'([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$', r'\1\n\2', code_block)
        return code_block, None

    def execute_code(self, code):
        """Run *code* in a ``PythonInterpreter`` that only exposes ``print``.

        Returns:
            ``(code, output, None)`` on success, where ``output`` is ``str()`` of
            the first element of the execution result. ``(code, None, message)``
            if *code* is empty or execution raises. Any exception is caught
            deliberately so its text can be fed back to the LM for a retry.
        """
        if not code:
            return code, None, 'Error: Empty code before execution.'
        code_prompt = CodePrompt(code, code_type="python")
        interpreter = PythonInterpreter(action_space={"print": print})
        try:
            output = str(code_prompt.execute(interpreter=interpreter)[0])
            return code, output, None
        except Exception as e:
            return code, None, str(e)

    def _parse_and_execute(self, code_data):
        """Parse a code prediction and execute the result; return ``(code, output, error)``.

        As in the original first attempt, a parse error does not stop
        execution. The execution result alone decides success.
        """
        parsed_code, _parse_error = self.parse_code(code_data)
        return self.execute_code(parsed_code)

    def forward(self, **kwargs):
        """Run generate -> execute -> (regenerate -> execute)* -> answer.

        Args:
            **kwargs: values for every input field of the signature. Extra
                keys are ignored.

        Returns:
            The prediction from the answer step. Returns ``None`` if the code
            still errors after ``max_iters`` regenerations.
        """
        # Forward every declared input field, not just a hard-coded "question".
        input_kwargs = {name: kwargs[name] for name in self.input_fields}

        code_data = self.code_generate(**input_kwargs)
        code, output, error = self._parse_and_execute(code_data)
        hop = 0
        while hop < self.max_iters and error:
            print('Error in code execution')
            code_data = self.code_regenerate(**input_kwargs, previous_code=code, error=error)
            # Each regeneration must be executed too; otherwise the loop
            # would judge success on the parse result alone.
            code, output, error = self._parse_and_execute(code_data)
            hop += 1

        # Only give up if the *last* attempt still failed.
        if error:
            print('Max hops reached. Error persists.')
            return None

        answer_gen_result = self.generate_answer(**input_kwargs, final_generated_code=code, code_output=output)
        return answer_gen_result