File size: 6,369 Bytes
2a26d3b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import os
import sys
import json
import argparse
import sqlite3
from tqdm import tqdm
from joblib import Parallel, delayed
from func_timeout import func_timeout, FunctionTimedOut
def load_json(dir):
    """Load and return the parsed JSON content stored at path *dir*.

    NOTE(review): the parameter name shadows the builtin ``dir``; kept
    unchanged for backward compatibility with existing callers.
    """
    with open(dir, 'r') as j:
        # json.load reads the stream directly; no need for loads(read()).
        return json.load(j)
def result_callback(result):
    # Callback in the style used by multiprocessing's apply_async APIs.
    # NOTE(review): `exec_result` is not defined at module level in this
    # file — `evaluation_main` only has a *local* of the same name — so
    # calling this function raises NameError. It looks like dead code left
    # over from an earlier pool-based implementation; confirm and remove.
    exec_result.append(result)
def execute_sql(predicted_sql, ground_truth, db_path):
    """Run both SQL statements against the SQLite database at *db_path*
    and compare their result sets.

    Returns a dict with:
      res              -- 1 if the deduplicated row sets match, else 0
      predicted_res    -- unique rows returned by the predicted query
      ground_truth_res -- unique rows returned by the ground-truth query
    """
    # Connect to the database
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(predicted_sql)
        predicted_res = cursor.fetchall()
        cursor.execute(ground_truth)
        ground_truth_res = cursor.fetchall()
    finally:
        # The original leaked the connection; always release it, even when
        # one of the queries raises.
        conn.close()
    # Order-insensitive and duplicate-insensitive comparison of row sets.
    res = 1 if set(predicted_res) == set(ground_truth_res) else 0
    return {"res": res,
            "predicted_res": list(set(predicted_res)),
            "ground_truth_res": list(set(ground_truth_res))}
def execute_model(sql_pair, db_place, idx, meta_time_out):
    """Execute one (predicted, ground-truth) SQL pair under a wall-clock
    timeout and return a per-query result record.

    sql_pair      -- (predicted_sql, ground_truth_sql) strings
    db_place      -- path to the sqlite database file
    idx           -- index of this query, used to re-sort results later
    meta_time_out -- per-query timeout in seconds for func_timeout

    Returns {'sql_idx': idx, 'res': 0|1, 'detail': <execute_sql dict or
    a stub with 'exec_detail' set to 'timeout'/'error'>}.
    """
    predicted_sql, ground_truth = sql_pair
    try:
        res_dict = func_timeout(meta_time_out, execute_sql,
                                args=(predicted_sql, ground_truth, db_place))
    except KeyboardInterrupt:
        # Let Ctrl-C abort the whole evaluation, not just this query.
        sys.exit(0)
    except FunctionTimedOut:
        res_dict = {"res": 0, "exec_detail": "timeout"}
    except Exception:
        # possibly len(query) > 512 or not executable
        res_dict = {"res": 0, "exec_detail": "error"}
    return {'sql_idx': idx, 'res': res_dict["res"], "detail": res_dict}
def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
    """Load SQL queries and the matching sqlite database paths.

    mode='gpt': *sql_path* is a JSON dict of idx -> "SQL\\t----- bird -----\\tdb_name"
                (non-string entries fall back to a blank query on 'financial').
    mode='gt' : *sql_path* is a text file with one "SQL\\tdb_name" line each.
    data_mode is accepted for interface compatibility but unused here.

    Returns (clean_sqls, db_path_list), aligned by index.
    """
    clean_sqls = []
    db_path_list = []
    if mode == 'gpt':
        with open(sql_path, 'r') as f:
            sql_data = json.load(f)
        for sql_str in sql_data.values():
            if isinstance(sql_str, str):
                sql, db_name = sql_str.split('\t----- bird -----\t')
            else:
                # Placeholder for malformed entries (behavior kept from original).
                sql, db_name = " ", "financial"
            clean_sqls.append(sql)
            db_path_list.append(os.path.join(db_root_path, db_name, f"{db_name}.sqlite"))
    elif mode == 'gt':
        # Original leaked this file handle; close it deterministically.
        with open(sql_path) as f:
            sql_txt = f.readlines()
        for sql_str in sql_txt:
            sql, db_name = sql_str.strip().split('\t')
            clean_sqls.append(sql)
            db_path_list.append(os.path.join(db_root_path, db_name, f"{db_name}.sqlite"))
    return clean_sqls, db_path_list
def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
    """Evaluate every (pred, gt) pair via execute_model, using joblib to
    fan out across *num_cpus* workers; returns the per-query result dicts."""
    if num_cpus > 1:
        # Silence the HuggingFace tokenizers fork warning when joblib
        # spawns worker processes.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
    progress = tqdm(range(len(sqls)), desc="exec")
    jobs = (delayed(execute_model)(sqls[i], db_places[i], i, meta_time_out)
            for i in progress)
    return Parallel(n_jobs=num_cpus)(jobs)
def sort_results(list_of_dicts):
    """Return a new list of the result dicts ordered by their 'sql_idx' key."""
    ordered = list(list_of_dicts)
    ordered.sort(key=lambda record: record['sql_idx'])
    return ordered
def compute_acc_by_diff(exec_results, contents):
    """Compute execution accuracy (as percentages) bucketed by difficulty.

    exec_results -- list of dicts with a 0/1 'res' key, index-aligned
                    with *contents*
    contents     -- eval entries, each carrying a 'difficulty' key
                    ('simple' / 'moderate' / 'challenging')

    Returns (simple_acc, moderate_acc, challenging_acc, all_acc, counts)
    where counts is [n_simple, n_moderate, n_challenging, n_total].
    Empty buckets (and empty input) report 0 accuracy instead of raising.
    """
    num_queries = len(exec_results)

    def _bucket_acc(results):
        # Mean of the 0/1 'res' flags; 0 for an empty bucket (avoids /0).
        if not results:
            return 0
        return sum(res['res'] for res in results) / len(results)

    buckets = {'simple': [], 'moderate': [], 'challenging': []}
    for i, content in enumerate(contents):
        # contents may be longer than exec_results; ignore the tail.
        if i >= num_queries:
            continue
        difficulty = content['difficulty']
        if difficulty in buckets:
            buckets[difficulty].append(exec_results[i])

    simple_acc = _bucket_acc(buckets['simple'])
    moderate_acc = _bucket_acc(buckets['moderate'])
    challenging_acc = _bucket_acc(buckets['challenging'])
    # Guard empty input (original raised ZeroDivisionError here).
    all_acc = _bucket_acc(exec_results)
    count_lists = [len(buckets['simple']), len(buckets['moderate']),
                   len(buckets['challenging']), num_queries]
    return (simple_acc * 100, moderate_acc * 100, challenging_acc * 100,
            all_acc * 100, count_lists)
def print_data(score_lists, count_lists):
    """Pretty-print per-difficulty query counts and accuracy percentages."""
    print('====================================== ACCURACY =====================================')
    levels = ['simple', 'moderate', 'challenging', 'total']
    header_fmt = "{:20} {:20} {:20} {:20} {:20}"
    count_fmt = "{:20} {:<20} {:<20} {:<20} {:<20}"
    acc_fmt = "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}"
    print(header_fmt.format("", *levels))
    print(count_fmt.format('count', *count_lists))
    print(acc_fmt.format('accuracy', *score_lists))
def evaluation_main(args, eval_datas, predicted_sql_path):
    """End-to-end evaluation: execute predicted vs ground-truth SQL, save
    per-question results next to the predictions, and print accuracy.

    args               -- namespace with db_root_path, mode, ground_truth_path,
                          num_cpus, meta_time_out
    eval_datas         -- eval entries with 'question' and 'difficulty' keys,
                          index-aligned with the predictions
    predicted_sql_path -- JSON file of predicted SQL (see package_sqls 'gpt' mode)
    """
    pred_queries, db_paths = package_sqls(predicted_sql_path, args.db_root_path, mode='gpt',
                                          data_mode=args.mode)
    # generate gt sqls:
    gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode='gt',
                                           data_mode=args.mode)
    query_pairs = list(zip(pred_queries, gt_queries))
    # Databases are taken from the prediction side; gt paths are unused.
    exec_result = run_sqls_parallel(query_pairs, db_places=db_paths,
                                    num_cpus=args.num_cpus, meta_time_out=args.meta_time_out)
    exec_result = sort_results(exec_result)
    # Attach question metadata to each execution record before saving.
    res = []
    for sql_pair, exec_res, data in zip(query_pairs, exec_result, eval_datas):
        predicted_sql, ground_truth = sql_pair
        exec_res["ground_truth"] = ground_truth
        exec_res["predicted_sql"] = predicted_sql
        exec_res["question"] = data["question"]
        exec_res["difficulty"] = data["difficulty"]
        res.append(exec_res)
    output_path = predicted_sql_path.replace(".json", "_exec.json")
    # Original leaked the output handle; close it deterministically.
    with open(output_path, 'w') as out_f:
        json.dump(res, out_f, indent=4)
    print('start calculate')
    simple_acc, moderate_acc, challenging_acc, acc, count_lists = \
        compute_acc_by_diff(exec_result, eval_datas)
    score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
    print_data(score_lists, count_lists)
    print('===========================================================================================')
    print("Finished evaluation")
|