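"""Reject-evaluation runner for a table-encoder model: builds multimodal messages
from the test set, runs inference, and scores reject/non-reject predictions
against ground truth."""
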
import argparse
import warnings
import os

import pandas as pd

from utils import filter_code, save_json
from reject_eval.run_eval import contains_independent_no, load_json
from reject_eval.eval_metrics import evaluation
from inference_encoder import inference_with_encoder, format_encoder_tables, read_df_head, build_encoder_table_part_content
from reject_eval.prompt import (eval_instruction, eval_system,
                                output_content_classify_instruct,
                                output_content_classify_system)


def format_encoder_inputs(test_datas: list[dict]) -> list[list[dict]]:
    """Format test samples into the message structure required by the encoder model."""
    format_message_datas = []
    for test_dt in test_datas:
        query = test_dt["query"]
        df_info_str = test_dt["df_info"]
        table_paths = test_dt["table_paths"]
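        # Table paths in the test set are relative; prefix them with the benchmark directory.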
        table_paths = [os.path.join("table_related_benchmarks", table_path) for table_path in table_paths]
        df_names = test_dt["df_names"]

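        # Build the encoder-specific table content to splice into the user message.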
        content_msg = build_encoder_table_part_content(df_names, table_paths)

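        # Only single-table samples are supported here.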
        if len(table_paths) != 1:
            raise ValueError("Multi-table samples are not supported")

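        # eval_instruction embeds df_info; splitting the formatted instruction on
        # df_info_str lets the table content be inserted at exactly that position.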
        format_instruction = eval_instruction.format(df_info=df_info_str, input=query)
        format_instruction_list = format_instruction.split(df_info_str)

        format_system = eval_system
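        # User message content, in order: instruction prefix, df_info text,
        # encoder table content, then the rest of the instruction.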
        messages = [
            {
                "role": "system",
                "content": format_system,
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": format_instruction_list[0]},
                    {"type": "text", "text": df_info_str},
                    *content_msg,
                    {"type": "text", "text": format_instruction_list[1]},
                ],
            },
        ]
        format_message_datas.append(messages)

    return format_message_datas


def eval_encoder_outputs(
    model_outputs: list[dict], test_file_path: str, save_path: str = ""
) -> None:
    """Calculate the reject-evaluation metric (binary classification) from model outputs."""
    test_datas = load_json(test_file_path)

    output_texts = model_outputs
    processed_data = []
    for idx, test_dt in enumerate(test_datas):
        llm_output = output_texts[idx]

        test_dt["llm_output"] = llm_output
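        # Treat the sample as rejected when no code was extracted or the output
        # contains an independent "no" (see contains_independent_no).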
        _, pure_code = filter_code(llm_output)
        test_dt["is_reject"] = bool(pure_code == "" or contains_independent_no(pure_code))

        processed_data.append(test_dt)

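    # By default, results are written alongside the test file.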
    parent_path = os.path.dirname(test_file_path)
    if not save_path:
        save_path = os.path.join(parent_path, "llm_output_data.json")
    ground_truth_path = os.path.join(parent_path, "ground_truth.json")
    ground_truth_datas = load_json(ground_truth_path)
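    # Attach the ground-truth label to each prediction and flag whether it matches.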
    for i in range(len(ground_truth_datas)):
        processed_data[i]["true_result"] = ground_truth_datas[i]["is_reject"]
        processed_data[i]["flag"] = processed_data[i]["true_result"] == processed_data[i]["is_reject"]

    save_json(save_path, processed_data)
    print(f"Per-sample model outputs and evaluation results saved to: {save_path}")
    evaluation(ground_truth_path, save_path)


def main(args):
    warnings.filterwarnings('ignore')
    test_path = args.test_path

    test_datas = load_json(test_path)

    format_message_datas = format_encoder_inputs(test_datas)
    print("Generating eval answers now..")
    model_outputs_text = inference_with_encoder(args, format_message_datas)
    print("model_outputs_text", len(model_outputs_text))
    print("Generating answers finished..")

    eval_encoder_outputs(model_outputs_text, test_path, args.save_path)


if __name__ == "__main__":
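    # Example invocation (script name and paths are illustrative):
    #   python run_eval_encoder.py --model_path /path/to/encoder_model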
    parser = argparse.ArgumentParser(description="eval reject")
    parser.add_argument(
        "--gpus_num", type=int, default=1, help="The number of GPUs you want to use."
    )
    parser.add_argument(
        "--temperature", type=float, default=0.01, help="Temperature setting"
    )
    parser.add_argument(
        "--model_path", type=str, required=True, help="Path to the model"
    )
    parser.add_argument(
        "--model_type",
        choices=["base_model", "chat_model"],
        default="chat_model",
        help="Base model or Chat model",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=1024,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--max_model_len", type=int, default=15000, help="Max model length"
    )
    parser.add_argument(
        "--template",
        type=str,
        choices=[None, "llama3", "baichuan", "chatglm"],
        default=None,
        help="The template must be specified if not present in the config file",
    )
    parser.add_argument(
        "--test_path",
        type=str,
        default="table_related_benchmarks/evalset/reject_test/test_query.json",
        help="Test file path",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="output/result_reject.json",
        help="LLM output samples save path",
    )

    args = parser.parse_args()
    main(args)