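"""Execution-accuracy evaluation for text-to-SQL predictions (BIRD-style).

Each predicted SQL query is executed against its SQLite database and its
result set is compared with that of the ground-truth query. Accuracy is
reported per difficulty level (simple / moderate / challenging) and overall.
"""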
import os
import sys
import json
import argparse
import sqlite3
from tqdm import tqdm
from joblib import Parallel, delayed
from func_timeout import func_timeout, FunctionTimedOut
def load_json(path):  # renamed from "dir", which shadowed the builtin
    with open(path, 'r') as j:
        contents = json.loads(j.read())
    return contents
# Unused leftover: nothing in the joblib-based execution path below registers
# this callback, and exec_result is not defined at module level.
def result_callback(result):
    exec_result.append(result)
def execute_sql(predicted_sql, ground_truth, db_path):
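    """Run both queries against the SQLite database at db_path and compare
    their result sets; returns a dict with res=1 on a match, else res=0."""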
    # Connect to the database
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute(predicted_sql)
    predicted_res = cursor.fetchall()
    cursor.execute(ground_truth)
    ground_truth_res = cursor.fetchall()
    conn.close()
    # Execution accuracy: correct iff both queries return the same set of
    # rows (order-insensitive, duplicates ignored).
    res = 1 if set(predicted_res) == set(ground_truth_res) else 0
    res_dict = {
        "res": res,
        "predicted_res": list(set(predicted_res)),
        "ground_truth_res": list(set(ground_truth_res)),
    }
    return res_dict
def execute_model(sql_pair, db_place, idx, meta_time_out):
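    """Execute one (predicted, ground-truth) pair with a wall-clock timeout;
    timeouts and execution errors both score 0."""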
    predicted_sql, ground_truth = sql_pair
    try:
        res_dict = func_timeout(meta_time_out, execute_sql,
                                args=(predicted_sql, ground_truth, db_place))
    except KeyboardInterrupt:
        sys.exit(0)
    except FunctionTimedOut:
        res_dict = {"res": 0, "exec_detail": "timeout"}
    except Exception:
        # possibly len(query) > 512 or not executable
        res_dict = {"res": 0, "exec_detail": "error"}
    result = {'sql_idx': idx, 'res': res_dict["res"], "detail": res_dict}
    return result
def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
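    """Load SQL queries and the per-query database paths.

    mode='gpt' reads a JSON dict of predictions where each string value embeds
    the database name after a '\\t----- bird -----\\t' separator; mode='gt'
    reads a text file with one 'sql<TAB>db_name' line per query. data_mode is
    accepted for compatibility but currently unused."""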
clean_sqls = []
db_path_list = []
    if mode == 'gpt':
        with open(sql_path, 'r') as f:
            sql_data = json.load(f)
        for idx, sql_str in sql_data.items():
            if isinstance(sql_str, str):
                sql, db_name = sql_str.split('\t----- bird -----\t')
            else:
                # fall back to a blank query against a default database
                sql, db_name = " ", "financial"
            clean_sqls.append(sql)
            db_path_list.append(os.path.join(db_root_path, db_name, f"{db_name}.sqlite"))
    elif mode == 'gt':
        with open(sql_path) as f:
            sql_txt = f.readlines()
        for idx, sql_str in enumerate(sql_txt):
            sql, db_name = sql_str.strip().split('\t')
            clean_sqls.append(sql)
            db_path_list.append(os.path.join(db_root_path, db_name, f"{db_name}.sqlite"))
return clean_sqls, db_path_list
def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
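    """Execute every (predicted, ground-truth) pair, in parallel when
    num_cpus > 1; returns one result dict per pair. TOKENIZERS_PARALLELISM is
    disabled to silence the HuggingFace tokenizers fork warning."""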
if num_cpus > 1:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
results = Parallel(n_jobs=num_cpus)(
delayed(execute_model)(sqls[i], db_places[i], i, meta_time_out) for i in tqdm(range(len(sqls)), desc="exec"))
return results
def sort_results(list_of_dicts):
return sorted(list_of_dicts, key=lambda x: x['sql_idx'])
def compute_acc_by_diff(exec_results, contents):
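    """Bucket execution results by question difficulty and return accuracy
    percentages (simple, moderate, challenging, overall) plus bucket counts."""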
num_queries = len(exec_results)
results = [res['res'] for res in exec_results]
simple_results, moderate_results, challenging_results = [], [], []
    for i, content in enumerate(contents):
        if i >= len(exec_results):
            continue
        if content['difficulty'] == 'simple':
            simple_results.append(exec_results[i])
        elif content['difficulty'] == 'moderate':
            moderate_results.append(exec_results[i])
        elif content['difficulty'] == 'challenging':
            challenging_results.append(exec_results[i])
    def _acc(bucket):
        # mean of the per-query res flags (1 = correct); 0 for an empty bucket
        return sum(r['res'] for r in bucket) / len(bucket) if bucket else 0

    simple_acc = _acc(simple_results)
    moderate_acc = _acc(moderate_results)
    challenging_acc = _acc(challenging_results)
all_acc = sum(results)/num_queries
count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries]
return simple_acc * 100, moderate_acc * 100, challenging_acc * 100, all_acc * 100, count_lists
def print_data(score_lists, count_lists):
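    """Print a fixed-width accuracy table by difficulty level."""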
print('====================================== ACCURACY =====================================')
levels = ['simple', 'moderate', 'challenging', 'total']
print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists))
print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('accuracy', *score_lists))
def evaluation_main(args, eval_datas, predicted_sql_path):
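    """End-to-end evaluation: load predicted and ground-truth SQL, execute
    them, dump per-query results next to the prediction file, and print the
    accuracy table."""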
    pred_queries, db_paths = package_sqls(predicted_sql_path, args.db_root_path, mode='gpt',
                                          data_mode=args.mode)
# generate gt sqls:
gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode='gt',
data_mode=args.mode)
query_pairs = list(zip(pred_queries, gt_queries))
exec_result = run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out)
exec_result = sort_results(exec_result)
# save_result
res = []
for sql_pair, exec_res, data in zip(query_pairs, exec_result, eval_datas):
predicted_sql, ground_truth = sql_pair
exec_res["ground_truth"] = ground_truth
exec_res["predicted_sql"] = predicted_sql
exec_res["question"] = data["question"]
exec_res["difficulty"] = data["difficulty"]
res.append(exec_res)
    output_path = predicted_sql_path.replace(".json", "_exec.json")
    with open(output_path, 'w') as f:
        json.dump(res, f, indent=4)
    print('Start computing accuracy')
simple_acc, moderate_acc, challenging_acc, acc, count_lists = \
compute_acc_by_diff(exec_result, eval_datas)
score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
    print_data(score_lists, count_lists)
print('===========================================================================================')
print("Finished evaluation")