Spaces:
Runtime error
Runtime error
| from time import sleep | |
| import ast | |
| import astunparse | |
| import openai | |
| from openai.error import RateLimitError, APIConnectionError | |
| from pygments import highlight | |
| from pygments.lexers import PythonLexer | |
| from pygments.formatters import TerminalFormatter | |
class LMP:
    """Language Model Program: prompt an LLM for code, then execute that code.

    Each call builds a prompt from a base prompt file plus optional session
    history and context, requests a completion from the OpenAI completion
    endpoint, logs the generated code, asks ``lmp_fgen`` to create any
    functions the code calls that are not defined yet, and finally runs the
    code through ``exec_safe``.
    """

    def __init__(self, name, cfg, lmp_fgen, fixed_vars, variable_vars, md_logger):
        """
        Args:
            name: label used in printed and logged output.
            cfg: config mapping. Keys read by this class: 'prompt_path', 'stop',
                'maintain_session', 'query_prefix', 'query_suffix',
                'temperature', 'engine', 'max_tokens', 'include_context',
                'debug_mode', 'has_return', 'return_val_name'.
            lmp_fgen: object providing ``create_new_fs_from_code(code_str)``
                that returns a dict of new functions (presumably an LMPFGen).
            fixed_vars: names always visible to executed code.
            variable_vars: additional names visible to executed code; updated
                in place with newly created functions and, when
                'maintain_session' is set, with variables the code binds.
            md_logger: object exposing ``log_text`` and ``log_code``.
        """
        self._name = name
        self._cfg = cfg
        self._md_logger = md_logger
        with open(self._cfg['prompt_path'], 'r') as f:
            self._base_prompt = f.read()
        self._stop_tokens = list(self._cfg['stop'])
        self._lmp_fgen = lmp_fgen
        self._fixed_vars = fixed_vars
        self._variable_vars = variable_vars
        # Newline-joined history of every code string handed to exec.
        self.exec_hist = ''

    def clear_exec_hist(self):
        """Forget the accumulated execution history."""
        self.exec_hist = ''

    def build_prompt(self, query, context=''):
        """Assemble the full prompt for `query`.

        Returns:
            ``(prompt, use_query)`` where ``use_query`` is the query wrapped in
            cfg['query_prefix'] / cfg['query_suffix'].
        """
        if self._variable_vars:
            variable_vars_imports_str = f"from utils import {', '.join(self._variable_vars.keys())}"
        else:
            variable_vars_imports_str = ''
        prompt = self._base_prompt.replace('{variable_vars_imports}', variable_vars_imports_str)

        if self._cfg['maintain_session']:
            prompt += f'\n{self.exec_hist}'

        if context != '':
            prompt += f'\n{context}'

        use_query = f'{self._cfg["query_prefix"]}{query}{self._cfg["query_suffix"]}'
        prompt += f'\n{use_query}'

        return prompt, use_query

    def _complete(self, prompt):
        """Return the stripped completion text for `prompt`.

        Retries indefinitely, sleeping 10 s between attempts, on OpenAI
        rate-limit and connection errors; any other error propagates.
        """
        while True:
            try:
                return openai.Completion.create(
                    prompt=prompt,
                    stop=self._stop_tokens,
                    temperature=self._cfg['temperature'],
                    engine=self._cfg['engine'],
                    max_tokens=self._cfg['max_tokens'],
                )['choices'][0]['text'].strip()
            except (RateLimitError, APIConnectionError) as e:
                print(f'OpenAI API got err {e}')
                print('Retrying after 10s.')
                sleep(10)

    def __call__(self, query, context='', **kwargs):
        """Generate code for `query`, execute it, and optionally return a value.

        Args:
            query: query text appended to the prompt.
            context: optional text placed before the query in the prompt; when
                cfg['include_context'] is set it is also prepended to the
                executed code.
            **kwargs: initial local variables for the executed code.

        Returns:
            When cfg['has_return'] is set, the value bound to
            cfg['return_val_name'] by the executed code. In debug mode nothing
            is executed, so this is whatever `kwargs` provided under that
            name, or ``None``. Otherwise returns ``None``.
        """
        prompt, use_query = self.build_prompt(query, context=context)
        code_str = self._complete(prompt)

        if self._cfg['include_context'] and context != '':
            to_exec = f'{context}\n{code_str}'
            to_log = f'{context}\n{use_query}\n{code_str}'
        else:
            to_exec = code_str
            to_log = f'{use_query}\n{to_exec}'

        to_log_pretty = highlight(to_log, PythonLexer(), TerminalFormatter())
        print(f'LMP {self._name} generated code:\n{to_log_pretty}')
        self._md_logger.log_text(f'LMP {self._name} Generated Code:')
        self._md_logger.log_code(to_log)

        # Create any functions the generated code calls that do not exist yet.
        new_fs = self._lmp_fgen.create_new_fs_from_code(code_str)
        self._variable_vars.update(new_fs)

        gvars = merge_dicts([self._fixed_vars, self._variable_vars])
        lvars = kwargs

        debug_mode = self._cfg['debug_mode']
        if not debug_mode:
            exec_safe(to_exec, gvars, lvars)

        self.exec_hist += f'\n{to_exec}'

        if self._cfg['maintain_session']:
            # Variables bound by this run become visible to later runs.
            self._variable_vars.update(lvars)

        if self._cfg['has_return']:
            ret_name = self._cfg['return_val_name']
            if debug_mode:
                # The code was not executed, so the return variable may never
                # have been bound; indexing here used to raise KeyError.
                return lvars.get(ret_name)
            return lvars[ret_name]
class LMPFGen:
    """Generate missing Python functions with an LLM from their call sites.

    ``create_new_fs_from_code`` scans code for calls to bare names. For each
    name that is not already defined, it asks the OpenAI completion endpoint
    for an implementation and executes that definition with ``exec_safe``. It
    then recurses into the new function's body so that helpers it calls are
    generated too.
    """

    def __init__(self, cfg, fixed_vars, variable_vars, md_logger):
        """
        Args:
            cfg: config mapping. Keys read by this class: 'stop', 'prompt_path',
                'query_prefix', 'query_suffix', 'temperature', 'engine',
                'max_tokens'.
            fixed_vars: names visible to every generated function.
            variable_vars: further names visible to generated functions
                (only read by this class).
            md_logger: object exposing ``log_text`` and ``log_code``.
        """
        self._cfg = cfg
        self._stop_tokens = list(self._cfg['stop'])
        self._fixed_vars = fixed_vars
        self._variable_vars = variable_vars
        self._md_logger = md_logger
        with open(self._cfg['prompt_path'], 'r') as f:
            self._base_prompt = f.read()

    @staticmethod
    def _with_retry(request):
        """Return ``request()``, retrying indefinitely (10 s apart) on OpenAI
        rate-limit and connection errors; any other error propagates."""
        while True:
            try:
                return request()
            except (RateLimitError, APIConnectionError) as e:
                print(f'OpenAI API got err {e}')
                print('Retrying after 10s.')
                sleep(10)

    def create_f_from_sig(self, f_name, f_sig, other_vars=None, fix_bugs=False, return_src=False):
        """Ask the LLM to implement `f_sig` and return the resulting function.

        Args:
            f_name: name the generated source is expected to define.
            f_sig: call signature (or assignment using the call) put in the query.
            other_vars: extra names visible as globals to the generated function.
            fix_bugs: if True, pass the source through the OpenAI edit model.
            return_src: if True, return ``(f, f_src)`` instead of ``f``.

        Raises:
            KeyError: if executing the generated source does not bind `f_name`.
        """
        print(f'Creating function: {f_sig}')

        use_query = f'{self._cfg["query_prefix"]}{f_sig}{self._cfg["query_suffix"]}'
        prompt = f'{self._base_prompt}\n{use_query}'

        f_src = self._with_retry(lambda: openai.Completion.create(
            prompt=prompt,
            stop=self._stop_tokens,
            temperature=self._cfg['temperature'],
            engine=self._cfg['engine'],
            max_tokens=self._cfg['max_tokens'],
        ))['choices'][0]['text'].strip()

        if fix_bugs:
            # NOTE(review): '# ' turns the first source line into a comment;
            # presumably the edit model is expected to restore it - confirm.
            edit_input = '# ' + f_src
            f_src = self._with_retry(lambda: openai.Edit.create(
                model='code-davinci-edit-001',
                input=edit_input,
                temperature=0,
                instruction='Fix the bug if there is one. Improve readability. Keep same inputs and outputs. Only small changes. No comments.',
            ))['choices'][0]['text'].strip()

        if other_vars is None:
            other_vars = {}
        gvars = merge_dicts([self._fixed_vars, self._variable_vars, other_vars])
        lvars = {}

        exec_safe(f_src, gvars, lvars)

        f = lvars[f_name]

        to_print = f'{use_query}\n{f_src}'
        to_print_pretty = highlight(to_print, PythonLexer(), TerminalFormatter())
        print(f'LMPFGen generated code:\n{to_print_pretty}')
        self._md_logger.log_text('Generated Function:')
        self._md_logger.log_code(to_print)

        if return_src:
            return f, f_src
        return f

    def create_new_fs_from_code(self, code_str, other_vars=None, fix_bugs=False, return_src=False):
        """Define every function called in `code_str` that does not exist yet.

        Args:
            code_str: Python source to scan for calls to bare names.
            other_vars: extra names treated as already defined; they are also
                visible to the generated functions.
            fix_bugs: forwarded to ``create_f_from_sig``.
            return_src: if True, also return the generated sources.

        Returns:
            ``new_fs`` (name -> function), or ``(new_fs, srcs)`` when
            `return_src` is True. Both include functions created recursively
            for the bodies of new functions.
        """
        fs, f_assigns = {}, {}
        f_parser = FunctionParser(fs, f_assigns)
        f_parser.visit(ast.parse(code_str))
        # Prefer `x = f(...)` over bare `f(...)` as the signature, so the
        # query also shows how the result is used.
        for f_name, f_assign in f_assigns.items():
            if f_name in fs:
                fs[f_name] = f_assign

        if other_vars is None:
            other_vars = {}

        new_fs = {}
        srcs = {}
        for f_name, f_sig in fs.items():
            all_vars = merge_dicts([self._fixed_vars, self._variable_vars, new_fs, other_vars])
            if not var_exists(f_name, all_vars):
                # Pass the caller's `other_vars` as well as the siblings made
                # so far. Previously only `new_fs` was passed, so a function
                # whose helpers already existed higher up was compiled without
                # them in its globals, causing a NameError at call time.
                gen_vars = merge_dicts([new_fs, other_vars])
                f, f_src = self.create_f_from_sig(f_name, f_sig, gen_vars, fix_bugs=fix_bugs, return_src=True)

                # recursively define child_fs in the function body if needed
                f_def_body = astunparse.unparse(ast.parse(f_src).body[0].body)
                child_fs, child_f_srcs = self.create_new_fs_from_code(
                    f_def_body, other_vars=all_vars, fix_bugs=fix_bugs, return_src=True
                )

                if child_fs:
                    new_fs.update(child_fs)
                    srcs.update(child_f_srcs)

                    # redefine parent f so newly created child_fs are in scope
                    # (exec_safe snapshots globals at definition time)
                    gvars = merge_dicts([self._fixed_vars, self._variable_vars, new_fs, other_vars])
                    lvars = {}
                    exec_safe(f_src, gvars, lvars)
                    f = lvars[f_name]

                new_fs[f_name], srcs[f_name] = f, f_src

        if return_src:
            return new_fs, srcs
        return new_fs
class FunctionParser(ast.NodeTransformer):
    """AST pass that records function calls and call-assignments.

    The tree is left unchanged. Results are written into the two dicts given
    to the constructor:

    * ``fs[name]`` maps each directly called bare name to the unparsed call
      expression, e.g. ``'foo(a, b)'``.
    * ``f_assigns[name]`` maps the callee of every ``target = <call>``
      statement to the unparsed assignment, e.g. ``'x = foo(a)'``.

    When a name appears more than once, the last occurrence visited wins.
    """

    def __init__(self, fs, f_assigns):
        super().__init__()
        self._fs = fs
        self._f_assigns = f_assigns

    def visit_Call(self, node):
        # Children first, so calls nested in the arguments are recorded too.
        self.generic_visit(node)
        callee = node.func
        if not isinstance(callee, ast.Name):
            # Attribute / subscript calls such as obj.method() are ignored.
            return node
        name = astunparse.unparse(callee).strip()
        self._fs[name] = astunparse.unparse(node).strip()
        return node

    def visit_Assign(self, node):
        self.generic_visit(node)
        rhs = node.value
        if isinstance(rhs, ast.Call):
            name = astunparse.unparse(rhs.func).strip()
            self._f_assigns[name] = astunparse.unparse(node).strip()
        return node
def var_exists(name, all_vars):
    """Return True if ``eval(name, all_vars)`` succeeds, else False.

    `all_vars` is used as the globals of ``eval``, so builtins such as ``len``
    also count as existing. As a side effect, ``eval`` inserts a
    ``__builtins__`` entry into `all_vars` when it is missing.
    """
    try:
        eval(name, all_vars)
    except Exception:
        # Any evaluation failure (NameError, AttributeError, SyntaxError, ...)
        # means "not defined". The former bare `except:` also swallowed
        # KeyboardInterrupt and SystemExit; those now propagate.
        return False
    return True
def merge_dicts(dicts):
    """Combine an iterable of dicts into a single new dict.

    Later dicts override earlier ones on key collisions. Key order follows
    first appearance. None of the inputs is modified.
    """
    merged = {}
    for mapping in dicts:
        for key, value in mapping.items():
            merged[key] = value
    return merged
def exec_safe(code_str, gvars=None, lvars=None):
    """Execute `code_str` after a phrase blacklist check, with exec/eval disabled.

    Args:
        code_str: Python source to execute.
        gvars: globals for the execution. A copy is used, so the caller's dict
            is not modified.
        lvars: locals dict that receives every name the executed code binds.

    Raises:
        AssertionError: if `code_str` contains 'import' or '__'.

    NOTE(review): this is not a sandbox. No ``__builtins__`` is supplied, so
    ``exec`` injects the real builtins (``open``, ``getattr``, ...). The
    generated code should still be treated as untrusted.
    """
    banned_phrases = ['import', '__']
    for phrase in banned_phrases:
        # Explicit raise instead of `assert`: `python -O` strips assert
        # statements, which would silently disable this check. The exception
        # type is kept so existing callers behave the same.
        if phrase in code_str:
            raise AssertionError(f'code contains banned phrase {phrase!r}')
    if gvars is None:
        gvars = {}
    if lvars is None:
        lvars = {}

    def empty_fn(*args, **kwargs):
        return None

    # Shadow exec/eval so the executed code cannot run nested dynamic code.
    custom_gvars = {**gvars, 'exec': empty_fn, 'eval': empty_fn}
    exec(code_str, custom_gvars, lvars)