|
import os |
|
import re |
|
import json |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
from typing import Any |
|
from utils import timeout |
|
from table_bench_eval.custom_python_tool import CustomPythonTool, sanitize_input |
|
from langchain_experimental.tools.python.tool import PythonAstREPLTool |
|
|
|
# Preamble that this module prepends to model-generated code before running it
# (see build_chart_eval_code and parse_code_then_exec): it imports the
# plotting/data libraries, silences warnings and selects a CJK-capable font.
# NOTE(review): the " |" markers and doubled line spacing inside this literal
# look like export residue rather than intended content -- confirm against the
# upstream file before relying on the exact text.
CODE_PREFIX = """import matplotlib.pyplot as plt |
|
from mplfonts import use_font |
|
import pandas as pd |
|
import numpy as np |
|
import seaborn as sns |
|
import warnings |
|

|
warnings.filterwarnings("ignore") |
|
# Fixing Chinese font issues |
|
use_font("Noto Serif CJK SC")\n""" |
|
|
|
def valid_path(path):
    """
    Make sure the directory that will contain ``path`` exists.

    :param path: path of a file that is about to be written
    :return: None; missing parent directories are created as a side effect
    """
    parent_dir = os.path.dirname(path)
    # A bare file name has no directory component. ``os.makedirs("")`` would
    # raise FileNotFoundError, and the current directory exists already.
    if parent_dir:
        # exist_ok closes the race between an existence check and creation.
        os.makedirs(parent_dir, exist_ok=True)
|
|
|
def pre_save_table_to_csv(table):
    """
    Write a table to ``table.csv`` in the current working directory.

    :param table: mapping with a ``'columns'`` list of header names and a
        ``'data'`` list of rows; every row is indexed positionally, one cell
        per column
    :return: None; the CSV is written without the index column
    """
    columns = table['columns']
    records = [
        {column: row[i] for i, column in enumerate(columns)}
        for row in table['data']
    ]
    # Repeated header names collapse to a single key in the row dicts, so the
    # explicit header is de-duplicated the same way (first-seen order kept).
    header = list(dict.fromkeys(columns))
    # Passing the header explicitly keeps it in the CSV even when the table
    # has no rows; without it an empty table would be written with no header.
    df = pd.DataFrame(records, columns=header)
    df.to_csv('table.csv', index=False)
|
|
|
def extract_final_answer(text):
    """
    Pull the answer that follows the first ``Final Answer:`` marker.

    The marker is matched case-sensitively. Whitespace after it (including
    line breaks) is skipped and the capture runs to the end of that line.

    :param text: model output to scan
    :return: the captured text with surrounding whitespace removed, or ``""``
        when the marker is absent
    """
    found = re.search(r'Final Answer:\s*(.*)', text)
    return found.group(1).strip() if found else ""
|
|
|
def parse_final_answer_prediction(prediction):
    """
    Return the text after a ``Final Answer: `` marker (case-insensitive).

    The capture is the rest of the marker's line, returned without stripping.

    :param prediction: model output to scan
    :return: the captured text, or ``''`` when there is no match or the
        input cannot be searched (e.g. it is not a string)
    """
    try:
        hit = re.search(r"Final Answer: (.+)", prediction, re.IGNORECASE)
    except Exception:
        return ''
    if hit is None:
        return ''
    return hit.group(1)
|
|
|
def read_json_file(path, filter_func=None):
    """
    Load a JSON file, falling back to JSON Lines when that fails.

    :param path: file to read (UTF-8)
    :param filter_func: optional predicate; only items for which it returns
        a truthy value are kept
    :return: the parsed document (a list when ``filter_func`` is given), a
        list of per-line records for JSON Lines input, or ``None`` when
        ``path`` does not exist
    """
    if not os.path.exists(path):
        return None

    with open(path, 'r', encoding='utf-8') as f:
        try:
            json_data = json.load(f)
            if filter_func is not None:
                json_data = list(filter(filter_func, json_data))
            return json_data
        except Exception:
            # Deliberately broad: besides a decode error, the filter itself
            # may fail on a whole-document parse (e.g. a one-line JSONL file
            # loads as a single dict); both cases retry line by line.
            f.seek(0)
            records = []
            for raw_line in f:
                line = raw_line.strip()
                # Blank lines are not records; json.loads("") would raise.
                if not line:
                    continue
                # Parse once and reuse the value for filter and result.
                record = json.loads(line)
                if filter_func is None or filter_func(record):
                    records.append(record)
            return records
|
|
|
|
|
def write_json_to_file(path: str, data: dict, is_json_line: bool = False) -> None:
    """
    Serialize ``data`` to ``path`` as UTF-8 JSON.

    ``valid_path`` is called first so the parent directory exists.

    :param path: destination file; it is opened in write mode
    :param data: object to serialize; with ``is_json_line`` it is iterated
        and each item becomes one line
    :param is_json_line: write one compact JSON document per line instead of
        a single document indented by four spaces
    """
    valid_path(path)
    with open(path, 'w', encoding='utf-8') as out:
        if not is_json_line:
            out.write(json.dumps(data, ensure_ascii=False, indent=4))
            return
        for record in data:
            out.write(json.dumps(record, ensure_ascii=False) + '\n')
|
|
|
def parse_python_code(prediction):
    """
    Extract the Python snippet from a model response.

    The body of the last fenced python block is preferred. Without one, the
    text captured after the first ``Action:`` marker is used; that capture
    must be followed by a line break.

    :param prediction: raw model output
    :return: the extracted code, or ``""`` when neither form is present
    """
    fenced_blocks = re.findall(r"```python\n(.*?)```", prediction, flags=re.S)
    if fenced_blocks:
        return fenced_blocks[-1]

    action = re.search(r'Action:\s*(.*)\n', prediction)
    if action is None:
        return ""
    return action.group(1)
|
|
|
|
|
def get_tool(df: Any, df_names=None):
    """
    Define python code execute tool

    :param df: List[pd.DataFrame] or pd.DataFrame
    :param df_names: optional variable names, one per dataframe, in the same
        order; without it a single frame is bound to ``df`` and a sequence
        to ``df1``, ``df2``, ...
    :return: a PythonAstREPLTool whose locals and globals are the same
        namespace holding the dataframe(s)
    """
    tool = PythonAstREPLTool()
    if df_names is None:
        if isinstance(df, pd.DataFrame):
            namespace = {"df": df}
        else:
            namespace = {
                f"df{i}": frame for i, frame in enumerate(df, start=1)
            }
    else:
        # Iterating a DataFrame yields its column labels, so a lone frame is
        # wrapped in a list before being paired with its name.
        frames = [df] if isinstance(df, pd.DataFrame) else df
        namespace = {df_names[i]: frame for i, frame in enumerate(frames)}
    tool.locals = namespace
    # Globals alias the same dict as locals.
    tool.globals = tool.locals
    return tool
|
|
|
def ensure_last_line_print(code):
    """
    Make sure the value of the snippet's final expression gets printed.

    When the code parses, the last top-level statement decides the outcome:
    an expression statement that does not already start with ``print`` is
    wrapped as a whole in ``print(...)`` (multi-line expressions and trailing
    comments included); any other statement (assignment, loop, ...) is left
    alone because wrapping it would produce invalid code. When the code does
    not parse, or is empty, the last physical line is wrapped instead.

    :param code: Python source to adjust
    :return: the stripped source, with the final expression wrapped in
        ``print(...)`` where applicable
    """
    import ast  # local import: only this helper needs it

    stripped = code.strip()
    lines = stripped.split('\n')

    try:
        body = ast.parse(stripped).body
    except (SyntaxError, ValueError):
        body = []

    if body:
        last_stmt = body[-1]
        if not isinstance(last_stmt, ast.Expr):
            # e.g. ``x = 1`` would become ``print(x = 1)``; leave it as is.
            return stripped
        expr_src = ast.get_source_segment(stripped, last_stmt)
        if (
            expr_src is not None
            and last_stmt.col_offset == 0  # statement starts its own line
            and last_stmt.end_lineno <= len(lines)
        ):
            if expr_src.startswith('print'):
                return stripped
            head = lines[:last_stmt.lineno - 1]
            # Comment lines after the expression are carried over unchanged.
            tail = lines[last_stmt.end_lineno:]
            return '\n'.join(head + [f'print({expr_src})'] + tail)

    # Fallback for unparsable/empty input: wrap the last physical line.
    last_line = lines[-1].strip()
    if not last_line.startswith('print'):
        lines[-1] = f'print({last_line})'
    return '\n'.join(lines)
|
|
|
def build_chart_eval_code(sample):
    """
    Assemble the chart-drawing code and the code that scores the chart.

    :param sample: mapping with the keys ``'answer'``, ``'chart_type'`` and
        ``'raw_generation'``
    :return: ``(python_code, chart_eval_code)``: the extracted code with
        ``CODE_PREFIX`` prepended, and that code followed by the answer text,
        a ``chart_type`` assignment and the metric snippet; ``('', '')`` when
        no code could be extracted from ``'raw_generation'``
    """
    answer = sample['answer']
    chart_type = sample['chart_type']
    prediction = sample['raw_generation']

    extracted_code = parse_python_code(prediction)
    # The emptiness check has to happen before the prefix is prepended;
    # afterwards the string can never be empty.
    if not extracted_code:
        return '', ''
    python_code = CODE_PREFIX + extracted_code

    # The helper functions named below are not defined in this module; they
    # must exist in the environment that executes the snippet.
    eval_code = '''
if chart_type == 'line':
    y_predictions = get_line_y_predictions(plt)
if chart_type == 'bar':
    y_predictions = get_bar_y_predictions(plt)
if chart_type == 'hbar':
    y_predictions = get_hbar_y_predictions(plt)
if chart_type == 'pie':
    y_predictions = get_pie_y_predictions(plt)
if chart_type == 'area':
    y_predictions = get_area_y_predictions(plt)
if chart_type == 'radar':
    y_predictions = get_radar_y_predictions(plt)
if chart_type == 'scatter':
    y_predictions = get_scatter_y_predictions(plt)
if chart_type == 'waterfall':
    y_predictions = get_waterfall_y_predictions(plt)

if chart_type == 'pie':
    print(compute_pie_chart_metric(y_references, y_predictions))
else:
    print(compute_general_chart_metric(y_references, y_predictions))
'''

    # NOTE(review): the answer text is inserted verbatim as code; presumably
    # it is the statement that defines ``y_references`` -- confirm with the
    # dataset format.
    y_ref_str = f"{answer}"
    chart_type_str = f"chart_type = '{chart_type}'"
    chart_eval_code = "\n".join([python_code, y_ref_str, chart_type_str, eval_code])
    return python_code, chart_eval_code
|
|
|
def parse_code_then_exec(prediction):
    """
    Extract the code from a model response and run it against ``table.csv``.

    The snippet is passed through ``ensure_last_line_print``, prefixed with
    ``CODE_PREFIX``, sanitized, and executed with the table available as
    ``df``.

    :param prediction: raw model output containing the code
    :return: ``(observe, ecr_1)``: ``observe`` is the stripped text of the
        tool output (first rows as markdown when the output is a DataFrame)
        or of the exception message; ``ecr_1`` is True when the run finished
        without raising
    """
    ecr_1 = False
    python_code = parse_python_code(prediction)
    if python_code == "":
        print("raw_prediction:", prediction)
    python_code = ensure_last_line_print(python_code)
    python_code = CODE_PREFIX + python_code
    python_code = sanitize_input(python_code)
    df = pd.read_csv("table.csv")
    exec_tool = get_tool(df)
    try:
        # NOTE(review): 10 is presumably a limit in seconds -- confirm
        # against utils.timeout.
        with timeout(10):
            observe = exec_tool.run(python_code)
            if isinstance(observe, pd.DataFrame):
                observe = observe.head().to_markdown(index=False)
            else:
                observe = str(observe)
        ecr_1 = True
    except Exception as e:
        # Keep the message text, not the exception object: the result is
        # stripped below, and exceptions have no ``strip`` method.
        observe = str(e)
    return observe.strip(), ecr_1
|
|
|
def execution_eval(observe: str) -> bool:
    """
    Test whether the code generated by eval_llm can be executed.

    :param observe: observation text produced by running the code
    :return: False when the observation is empty or mentions "error" or
        "exception" (case-insensitive), True otherwise
    """
    if observe == "":
        return False

    try:
        return re.search(r"error|exception", observe, re.IGNORECASE) is None
    except TypeError:
        # A non-text observation cannot be scanned; it counts as a
        # successful execution, as it did before.
        return True
|
|
|
def parse_chart_code_then_exec(sample):
    """
    Run a sample's chart code, then the chart-metric code, on ``table.csv``.

    Both snippets come from ``build_chart_eval_code`` and are sanitized
    before being executed in the same tool instance.

    :param sample: sample mapping handed to ``build_chart_eval_code``
    :return: ``(observe, ecr_1)``: ``observe`` is the stripped text output
        of the metric run (or its exception message); ``ecr_1`` is True when
        the chart code alone ran without raising
    """
    python_code, chart_eval_code = build_chart_eval_code(sample)
    table_df = pd.read_csv("table.csv")
    python_code = sanitize_input(python_code)
    chart_eval_code = sanitize_input(chart_eval_code)
    exec_tool = get_tool(table_df)

    # First pass: only records whether the chart code itself runs.
    ecr_1 = False
    try:
        with timeout(10):
            exec_tool.run(python_code)
        ecr_1 = True
    except Exception:
        pass

    # Second pass: chart code plus metric snippet; its output is the result.
    try:
        with timeout(10):
            observe = exec_tool.run(chart_eval_code)
            print("Observe:", observe)
            observe = (
                observe.head().to_markdown(index=False)
                if isinstance(observe, pd.DataFrame)
                else str(observe)
            )
    except Exception as err:
        observe = str(err)

    result = observe.strip()
    plt.close("all")
    return result, ecr_1
|
|