qianxiao1111's picture
upgrade: add benchmarks eval
2a26d3b
raw
history blame
6.71 kB
import re
import ast
from contextlib import redirect_stdout
from io import StringIO
from langchain_experimental.tools.python.tool import PythonAstREPLTool
from typing import Optional
from langchain_core.callbacks.manager import CallbackManagerForToolRun
# Source code of the chart-evaluation helpers. sanitize_input() prepends this
# text to every query, so the helpers are (re)defined inside the REPL
# namespace before the user code runs. It provides:
#   * compare(list1, list2): order-insensitive, NaN-aware list equality.
#   * std_digit(list_nums): rounds every entry to 2 decimal places.
#   * compute_general_chart_metric / compute_pie_chart_metric: flatten one
#     level of nesting, round, then compare references with predictions
#     (the pie variant first converts references to fractions of their sum).
#   * get_*_y_predictions(plt): read the plotted values back from the current
#     axes (plt.gca()) for each supported chart type.
# compare() sorts copies with NaN ranked last: NaN is unordered, so a plain
# in-place sort leaves it at an input-dependent position and equal multisets
# could be reported as different.
extra_functions = """
import numpy as np

def compare(list1, list2):
    # compare as multisets without mutating the caller's lists
    if len(list1) != len(list2):
        return False
    def nan_last(value):
        # NaN is unordered; rank it after every real number so that the
        # sorted order no longer depends on the input order
        return (1, 0) if np.isnan(value) else (0, value)
    sorted1 = sorted(list1, key=nan_last)
    sorted2 = sorted(list2, key=nan_last)
    for item1, item2 in zip(sorted1, sorted2):
        if np.isnan(item1):
            if not np.isnan(item2):
                return False
        elif item1 != item2:
            return False
    return True

def std_digit(list_nums):
    new_list = []
    for i in range(len(list_nums)):
        new_list.append(round(list_nums[i], 2))
    return new_list

def compute_general_chart_metric(references, predictions):
    processed_references = []
    processed_predictions = []
    for reference in references:
        if isinstance(reference, list):
            processed_references.extend(reference)
        else:
            processed_references.append(reference)
    for prediction in predictions:
        if isinstance(prediction, list):
            processed_predictions.extend(prediction)
        else:
            processed_predictions.append(prediction)
    processed_references = std_digit(processed_references)
    processed_predictions = std_digit(processed_predictions)
    return compare(processed_references, processed_predictions)

def compute_pie_chart_metric(references, predictions):
    processed_references = []
    processed_predictions = []
    for reference in references:
        if isinstance(reference, list):
            processed_references.extend(reference)
        else:
            processed_references.append(reference)
    references = processed_references
    processed_references = []
    total = 0
    for reference in references:
        total += reference
    for reference in references:
        processed_references.append(round(reference / total, 2))
    for prediction in predictions:
        if isinstance(prediction, list):
            processed_predictions.extend(prediction)
        else:
            processed_predictions.append(prediction)
    processed_references = std_digit(processed_references)
    processed_predictions = std_digit(processed_predictions)
    return compare(processed_references, processed_predictions)

def get_line_y_predictions(plt):
    line_y_predctions = []
    lines = plt.gca().get_lines()
    line_y_predctions = [list(line.get_ydata()) for line in lines]
    return line_y_predctions

def get_bar_y_predictions(plt):
    bar_y_predctions = []
    patches = plt.gca().patches
    bar_y_predctions = [patch.get_height() for patch in patches]
    return bar_y_predctions

def get_hbar_y_predictions(plt):
    hbar_y_predctions = []
    patches = plt.gca().patches
    hbar_y_predctions = [patch.get_width() for patch in patches]
    return hbar_y_predctions

def get_pie_y_predictions(plt):
    pie_y_predctions = []
    patches = plt.gca().patches
    for patch in patches:
        theta1, theta2 = patch.theta1, patch.theta2
        value = round((theta2 - theta1) / 360.0, 2)
        pie_y_predctions.append(value)
    return pie_y_predctions

def get_area_y_predictions(plt):
    area_y_predctions = []
    area_collections = plt.gca().collections
    for area_collection in area_collections:
        area_items = []
        for item in area_collection.get_paths()[0].vertices[:, 1]:
            if item != 0:
                area_items.append(item)
        area_y_predctions.append(area_items)
    return list(area_y_predctions)

def get_radar_y_predictions(plt):
    radar_y_predctions = []
    radar_lines = plt.gca().get_lines()
    radar_y_predctions = [list(line.get_ydata()) for line in radar_lines]
    for i in range(len(radar_y_predctions)):
        radar_y_predctions[i] = radar_y_predctions[i][:-1]
    return radar_y_predctions

def get_scatter_y_predictions(plt):
    scatter_y_predctions = []
    scatter_collections = plt.gca().collections
    for scatter_collection in scatter_collections:
        scatter_items = []
        for item in scatter_collection.get_offsets():
            scatter_items.append(item[1])
        scatter_y_predctions.append(scatter_items)
    return scatter_y_predctions

def get_waterfall_y_predictions(plt):
    waterfall_y_predctions = []
    patches = plt.gca().patches
    waterfall_y_predctions = [patch.get_height() for patch in patches]
    return waterfall_y_predctions
"""
def sanitize_input(query: str) -> str:
    """Sanitize input to the python REPL and prepend the helper functions.

    Removes leading/trailing whitespace and backticks, plus a leading
    ``python`` tag (if the LLM mistakes the python console for a terminal),
    then puts the ``extra_functions`` source in front of the cleaned query.

    Args:
        query: The query to sanitize

    Returns:
        str: ``extra_functions``, a newline, and the sanitized query
    """
    # Strip leading whitespace/backticks and an optional "python" tag. The \b
    # makes the tag match only as a whole word, so code that merely starts
    # with an identifier such as `python_version` is left untouched.
    query = re.sub(r"^(\s|`)*(?:(?i:python)\b)?\s*", "", query)
    # Strip trailing whitespace and backticks.
    query = re.sub(r"(\s|`)*$", "", query)
    # Helpers go first so the query can call them.
    return "\n".join([extra_functions, query])
class CustomPythonTool(PythonAstREPLTool):
    """Python AST REPL tool that returns printed output together with the result.

    All statements but the last are executed; the last one is evaluated when
    it is an expression so its value can be appended to the captured stdout.

    NOTE: this runs arbitrary code taken from ``query`` through ``exec`` and
    ``eval``; only use it where executing untrusted code is acceptable.

    NOTE(review): ``self.globals`` and ``self.locals`` are passed to ``exec``
    as separate namespaces. If they are distinct dicts, top-level imports and
    ``def``s land in ``self.locals`` and are not visible from inside function
    bodies — confirm the tool is constructed with a shared namespace.
    """

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool.

        Args:
            query: Python source to run.
            run_manager: Callback manager supplied by the caller; not used here.

        Returns:
            str: Everything printed during the run, followed by ``str()`` of
            the final expression's value when that value is not ``None``.
            If anything raises, ``"<ExceptionType>: <message>"`` is returned
            instead.
        """
        try:
            # Optional input cleanup.
            if self.sanitize_input:
                query = sanitize_input(query)
            # Parse once and split the program into "everything but the last
            # statement" and "the last statement". Both sources are built
            # before anything runs, so a failure in the leading statements is
            # reported as itself rather than as an unbound-variable error.
            tree = ast.parse(query)
            leading_src = ast.unparse(  # type: ignore
                ast.Module(tree.body[:-1], type_ignores=[])
            )
            final_src = ast.unparse(  # type: ignore
                ast.Module(tree.body[-1:], type_ignores=[])
            )
            # Only an expression statement has a value worth returning;
            # deciding this from the AST avoids an eval-then-exec retry that
            # would run a failing expression twice.
            final_is_expression = bool(tree.body) and isinstance(
                tree.body[-1], ast.Expr
            )
            # Buffer that collects everything printed while the code runs.
            io_buffer = StringIO()
            with redirect_stdout(io_buffer):
                exec(leading_src, self.globals, self.locals)  # type: ignore
                if final_is_expression:
                    ret = eval(final_src, self.globals, self.locals)
                else:
                    exec(final_src, self.globals, self.locals)
                    ret = None
            # Return the captured output alone when there is no value,
            # otherwise the output followed by the value.
            if ret is None:
                return io_buffer.getvalue()
            return io_buffer.getvalue() + str(ret)
        except Exception as e:
            return "{}: {}".format(type(e).__name__, str(e))