import dsp
import dspy
from ..primitives.program import Module
from ..primitives.python_interpreter import CodePrompt, PythonInterpreter
import re


class ProgramOfThought(Module):
    """Answer a signature by generating, executing and repairing Python code.

    Three chain-of-thought sub-predictors are built from the signature:

    * ``code_generate`` writes Python code from the signature's inputs;
    * ``code_regenerate`` rewrites code that failed, given the previous code
      and its error message;
    * ``generate_answer`` produces the signature's single output field from the
      final code and its execution output.
    """

    def __init__(self, signature, max_iters=3):
        """
        Args:
            signature: anything ``dspy.Predict`` accepts as a signature; it must
                declare exactly one output field.
            max_iters: maximum number of regeneration attempts after the first
                generated program fails to parse or execute.

        Raises:
            AssertionError: if the signature has more than one output field.
        """
        super().__init__()

        self.signature = signature = dspy.Predict(signature).signature
        self.max_iters = max_iters

        self.input_fields = signature.input_fields()
        self.output_fields = signature.output_fields()

        assert len(self.output_fields) == 1, "PoT only supports one output field."

        self.code_generate = dspy.ChainOfThought(dsp.Template(self._generate_instruction('generate'), **self._generate_signature('generate')))
        self.code_regenerate = dspy.ChainOfThought(dsp.Template(self._generate_instruction('regenerate'), **self._generate_signature('regenerate')))
        self.generate_answer = dspy.ChainOfThought(dsp.Template(self._generate_instruction('answer'), **self._generate_signature('answer')))

    def _generate_signature(self, mode):
        """Build the field mapping for one of the three sub-predictors.

        Args:
            mode: ``'generate'``, ``'regenerate'`` or ``'answer'``.

        Returns:
            A dict of field name -> field: the signature's input fields followed
            by the fields specific to ``mode``.
        """
        signature_dict = dict(self.input_fields)
        # The answer step emits the signature's single output field. Using its
        # real name (instead of hard-coding "answer") supports any output name.
        output_name = next(iter(self.output_fields))
        fields_for_mode = {
            'generate': {
                'generated_code': dspy.OutputField(prefix="Code:", desc="python code that answers the question", format=str),
            },
            'regenerate': {
                'previous_code': dspy.InputField(prefix="Previous Code:", desc="previously-generated python code that errored", format=str),
                'error': dspy.InputField(prefix="Error:", desc="error message from previously-generated python code"),
                'generated_code': dspy.OutputField(prefix="Code:", desc="python code that answers the question", format=str),
            },
            'answer': {
                'final_generated_code': dspy.InputField(prefix="Code:", desc="python code that answers the question", format=str),
                'code_output': dspy.InputField(prefix="Code Output:", desc="output of previously-generated python code"),
                output_name: self.signature.kwargs[output_name],
            },
        }
        signature_dict.update(fields_for_mode[mode])
        return signature_dict

    def _generate_instruction(self, mode):
        """Return the prompt instructions for ``mode``, naming its input/output fields."""
        # Build the field mapping once rather than once per field lookup.
        mode_signature = self._generate_signature(mode)
        mode_inputs = ', '.join([f"`{field_name}`" for field_name, field in mode_signature.items() if isinstance(field, dspy.InputField)])
        mode_outputs = ', '.join([f"`{field_name}`" for field_name, field in mode_signature.items() if isinstance(field, dspy.OutputField)])
        if mode == 'generate':
            instr = [
                f"You will be given {mode_inputs} and you will respond with {mode_outputs}.",
                f"Generating executable Python code that programmatically computes the correct {mode_outputs}.",
                f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {mode_outputs}.",
            ]
        elif mode == 'regenerate':
            instr = [
                f"You are given {mode_inputs} due to an error in previous code.",
                f"Your task is to correct the error and provide the new {mode_outputs}.",
            ]
        else:  # mode == 'answer'
            instr = [
                f"Given the final code {mode_inputs}, provide the final {mode_outputs}.",
            ]
        return '\n'.join(instr)

    def parse_code(self, code_data):
        """Extract runnable Python from a code-generation prediction.

        Keeps only the text before the first ``---`` and before the first blank
        double line, takes the body of a ```` ```python ```` fence if present, and
        turns literal ``\\n`` sequences into real newlines. If the last line is an
        assignment ``name = ...`` (and there is more than one line), ``name`` is
        appended so the program's last line evaluates to that value. Otherwise,
        run-together assignments are heuristically split onto separate lines.

        Returns:
            ``(code, error)``: on success, the cleaned code and ``None``; on
            failure, the truncated raw text and an error message.
        """
        code = code_data.get('generated_code', '').split('---', 1)[0].split('\n\n\n', 1)[0]
        code_match = re.search(r'```python[ \n](.*?)[ \n]```?', code, re.DOTALL)
        code_block = (code_match.group(1) if code_match else code).replace('\\n', '\n')
        if not code_block:
            return code, "Error: Empty code after parsing."
        if "\n" not in code_block and code_block.count('=') > 1:
            return code, "Error: Code format is not correct."
        lines = code_block.split('\n')
        last_line_match = re.match(r'^(\w+)\s*=', lines[-1].strip())
        if last_line_match and len(lines) > 1:
            code_block += '\n' + last_line_match.group(1)
        else:
            code_block = re.sub(r'([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)', r'\1\n', code_block)
            code_block = re.sub(r'([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$', r'\1\n\2', code_block)
        return code_block, None

    def execute_code(self, code):
        """Run ``code`` in a ``PythonInterpreter`` exposing only ``print``.

        Returns:
            ``(code, output, error)``. ``output`` is the stringified first element
            of ``CodePrompt.execute``'s result, or ``None`` on failure. ``error``
            is ``None`` on success, or the exception message on failure.
        """
        if not code:
            return code, None, 'Error: Empty code before execution.'
        code_prompt = CodePrompt(code, code_type="python")
        interpreter = PythonInterpreter(action_space={"print": print})
        try:
            output = str(code_prompt.execute(interpreter=interpreter)[0])
            return code, output, None
        except Exception as e:
            # Deliberately broad: any failure of model-written code is reported
            # back as text so code_regenerate can attempt a fix.
            return code, None, str(e)

    def _parse_and_execute(self, code_data):
        """Parse a code prediction and execute it only if parsing succeeded.

        Returns:
            ``(code, output, error)`` in the same shape as ``execute_code``.
        """
        parsed_code, error = self.parse_code(code_data)
        if error:
            # Unparseable code is not executed; its parse error is fed back instead.
            return parsed_code, None, error
        return self.execute_code(parsed_code)

    def forward(self, **kwargs):
        """Generate code, repair it up to ``max_iters`` times, then answer.

        Args:
            **kwargs: values for the signature's input fields; other keys are
                ignored.

        Returns:
            The ``generate_answer`` prediction, or ``None`` if the code still
            fails after all regeneration attempts.
        """
        input_kwargs = {field_name: kwargs[field_name] for field_name in self.input_fields}

        code_data = self.code_generate(**input_kwargs)
        code, output, error = self._parse_and_execute(code_data)

        hop = 0
        while error and hop < self.max_iters:
            print('Error in code execution')
            code_data = self.code_regenerate(**input_kwargs, previous_code=code, error=error)
            # Re-run the regenerated program so code/output/error reflect it.
            code, output, error = self._parse_and_execute(code_data)
            hop += 1

        if error:
            print('Max hops reached. Error persists.')
            return None

        answer_gen_result = self.generate_answer(**input_kwargs, final_generated_code=code, code_output=output)
        return answer_gen_result