EureCA/dspy/predict/program_of_thought.py
tonneli's picture
Delete history
f5776d3
import dsp
import dspy
from ..primitives.program import Module
from ..primitives.python_interpreter import CodePrompt, PythonInterpreter
import re
class ProgramOfThought(Module):
    """Answer a query by having the LM write Python code, then running it.

    The module chains three ``dspy.ChainOfThought`` predictors:

    * ``code_generate``   -- produces ``generated_code`` from the inputs;
    * ``code_regenerate`` -- produces new ``generated_code`` from the inputs
      plus the previous code and the error it raised;
    * ``generate_answer`` -- produces the signature's single output field
      from the inputs, the final code and the output of executing it.
    """

    def __init__(self, signature, max_iters=3):
        """Build the three predictors for ``signature``.

        Args:
            signature: Anything accepted by ``dspy.Predict``; it must declare
                exactly one output field.
            max_iters: Maximum number of regeneration attempts ``forward``
                makes after the first generated program fails.
        """
        super().__init__()
        # Round-trip through dspy.Predict and keep the signature it stores.
        self.signature = signature = dspy.Predict(signature).signature
        self.max_iters = max_iters

        self.input_fields = signature.input_fields()
        self.output_fields = signature.output_fields()

        assert len(self.output_fields) == 1, "PoT only supports one output field."

        self.code_generate = dspy.ChainOfThought(
            dsp.Template(self._generate_instruction('generate'), **self._generate_signature('generate'))
        )
        self.code_regenerate = dspy.ChainOfThought(
            dsp.Template(self._generate_instruction('regenerate'), **self._generate_signature('regenerate'))
        )
        self.generate_answer = dspy.ChainOfThought(
            dsp.Template(self._generate_instruction('answer'), **self._generate_signature('answer'))
        )

    def _generate_signature(self, mode):
        """Return the field mapping for ``mode``.

        The mapping starts with the signature's own input fields and is
        extended with the fields specific to ``mode`` (``'generate'``,
        ``'regenerate'`` or ``'answer'``).

        Raises:
            KeyError: If ``mode`` is not one of the three known modes.
        """
        signature_dict = dict(self.input_fields)
        # Use the declared output field's own name rather than assuming it is
        # called "answer"; __init__ only requires that there is exactly one.
        output_name = next(iter(self.output_fields))
        fields_for_mode = {
            'generate': {
                'generated_code': dspy.OutputField(prefix="Code:", desc="python code that answers the question", format=str),
            },
            'regenerate': {
                'previous_code': dspy.InputField(prefix="Previous Code:", desc="previously-generated python code that errored", format=str),
                'error': dspy.InputField(prefix="Error:", desc="error message from previously-generated python code"),
                'generated_code': dspy.OutputField(prefix="Code:", desc="python code that answers the question", format=str),
            },
            'answer': {
                'final_generated_code': dspy.InputField(prefix="Code:", desc="python code that answers the question", format=str),
                'code_output': dspy.InputField(prefix="Code Output:", desc="output of previously-generated python code"),
                output_name: self.signature.kwargs[output_name],
            },
        }
        signature_dict.update(fields_for_mode[mode])
        return signature_dict

    def _generate_instruction(self, mode):
        """Return the instruction text for ``mode``.

        The text names the mode's input and output fields (back-quoted and
        comma-separated). Any mode other than ``'generate'`` or
        ``'regenerate'`` gets the answer-mode wording.
        """
        # Build the field mapping once instead of once per field.
        mode_signature = self._generate_signature(mode)
        mode_inputs = ', '.join(
            f"`{field_name}`"
            for field_name, field in mode_signature.items()
            if isinstance(field, dspy.InputField)
        )
        mode_outputs = ', '.join(
            f"`{field_name}`"
            for field_name, field in mode_signature.items()
            if isinstance(field, dspy.OutputField)
        )
        if mode == 'generate':
            instr = [
                f"You will be given {mode_inputs} and you will respond with {mode_outputs}.",
                f"Generating executable Python code that programmatically computes the correct {mode_outputs}.",
                f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {mode_outputs}.",
            ]
        elif mode == 'regenerate':
            instr = [
                f"You are given {mode_inputs} due to an error in previous code.",
                f"Your task is to correct the error and provide the new {mode_outputs}.",
            ]
        else:  # mode == 'answer'
            instr = [
                f"Given the final code {mode_inputs}, provide the final {mode_outputs}.",
            ]
        return '\n'.join(instr)

    def parse_code(self, code_data):
        """Extract a runnable code string from a code-generation result.

        Args:
            code_data: Object with a dict-style ``get``; its
                ``'generated_code'`` entry is the raw LM completion.

        Returns:
            ``(code, error)``. On success ``code`` is the cleaned program and
            ``error`` is ``None``. On failure ``code`` is the truncated raw
            completion and ``error`` is a message string.
        """
        # Keep only the text before the first '---' and before the first run
        # of two blank lines.
        code = code_data.get('generated_code', '').split('---', 1)[0].split('\n\n\n', 1)[0]
        # Prefer the contents of a ```python fence when one is present.
        code_match = re.search(r'```python[ \n](.*?)[ \n]```?', code, re.DOTALL)
        # Turn literal backslash-n sequences into real newlines.
        code_block = (code_match.group(1) if code_match else code).replace('\\n', '\n')
        if not code_block:
            return code, "Error: Empty code after parsing."
        if "\n" not in code_block and code_block.count('=') > 1:
            return code, "Error: Code format is not correct."
        lines = code_block.split('\n')
        last_line_match = re.match(r'^(\w+)\s*=', lines[-1].strip())
        if last_line_match and len(lines) > 1:
            # The program ends with an assignment: append the assigned name so
            # the last line evaluates to that value.
            code_block += '\n' + last_line_match.group(1)
        else:
            # Insert newlines before each further assignment and before a
            # trailing bare identifier that follows an assignment.
            code_block = re.sub(r'([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)', r'\1\n', code_block)
            code_block = re.sub(r'([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$', r'\1\n\2', code_block)
        return code_block, None

    def execute_code(self, code):
        """Run ``code`` in a ``PythonInterpreter``.

        Returns:
            ``(code, output, error)``. ``output`` is the stringified first
            element of the execution result and ``error`` is ``None`` on
            success; on failure ``output`` is ``None`` and ``error`` is the
            exception text (or a fixed message when ``code`` is empty).
        """
        if not code:
            return code, None, 'Error: Empty code before execution.'
        code_prompt = CodePrompt(code, code_type="python")
        interpreter = PythonInterpreter(action_space={"print": print})
        try:
            output = str(code_prompt.execute(interpreter=interpreter)[0])
            return code, output, None
        except Exception as e:
            # Any failure in the generated program is reported back as text so
            # it can be fed to the regeneration prompt.
            return code, None, str(e)

    def forward(self, **kwargs):
        """Generate code, execute it (retrying on errors) and produce the answer.

        Args:
            **kwargs: Values for the signature's input fields. Declared input
                fields that are present are forwarded to every predictor;
                other keyword arguments are ignored.

        Returns:
            The result of ``generate_answer``, or ``None`` when the code
            still fails after ``max_iters`` regeneration attempts.
        """
        input_kwargs = {name: kwargs[name] for name in self.input_fields if name in kwargs}

        code_data = self.code_generate(**input_kwargs)
        # The parse status is advisory only: the code is executed regardless
        # and the execution result decides whether a retry is needed.
        parsed_code, _ = self.parse_code(code_data)
        code, output, error = self.execute_code(parsed_code)

        hop = 0
        while error and hop < self.max_iters:
            print('Error in code execution')
            code_data = self.code_regenerate(**input_kwargs, previous_code=code, error=error)
            parsed_code, _ = self.parse_code(code_data)
            # Re-run the regenerated program so `code`, `output` and `error`
            # describe the latest attempt rather than the first one.
            code, output, error = self.execute_code(parsed_code)
            hop += 1

        # Give up only if the last attempt failed; a retry that succeeded on
        # the final allowed hop is still a valid result.
        if error:
            print('Max hops reached. Error persists.')
            return None

        return self.generate_answer(**input_kwargs, final_generated_code=code, code_output=output)