import re
from table_instruct.eval.scripts.table_utils import evaluate as table_llama_eval
from table_instruct.eval.scripts.metric import *
from rouge_score import rouge_scorer
import numpy as np
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import sacrebleu
from nltk.translate import meteor_score
import time


def extract_bracket_content(text):
    # Extract the content wrapped in <> using a regular expression
    pattern = r'<(.*?)>'
    matches = re.findall(pattern, text)
    # If nothing matches, return the original string
    return matches[0] if matches else text


def split_string(text):
    # Split on newlines and commas
    return [item.strip() for item in re.split(r'[\n,]+', text) if item.strip()]


def eval_hitab_ex(data):
    pred_list = []
    gold_list = []
    for i in range(len(data)):
        if len(data[i]["predict"].strip("</s>").split(">, <")) > 1:
            instance_pred_list = data[i]["predict"].strip("</s>").split(">, <")
            pred_list.append(instance_pred_list)
            gold_list.append(data[i]["output"].strip("</s>").split(">, <"))
        else:
            pred_list.append(data[i]["predict"].strip("</s>"))
            gold_list.append(data[i]["output"].strip("</s>"))
    result = table_llama_eval(gold_list, pred_list)
    return result


def compute_rouge(list1, list2):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = []
    for sent1, sent2 in zip(list1, list2):
        score = scorer.score(sent1, sent2)
        scores.append(score)
    rouge1 = np.mean([score['rouge1'].fmeasure for score in scores])
    rouge2 = np.mean([score['rouge2'].fmeasure for score in scores])
    rougeL = np.mean([score['rougeL'].fmeasure for score in scores])
    return {'rouge1': rouge1, 'rouge2': rouge2, 'rougeL': rougeL}


def compute_bleu(list1, list2):
    bleu_scores = []
    smoothie = SmoothingFunction().method4  # smoothing for sentence-level BLEU scores
    for ref, pred in zip(list1, list2):
        reference = [ref.split()]  # BLEU expects a list of reference token lists
        candidate = pred.split()
        score = sentence_bleu(reference, candidate, smoothing_function=smoothie)
        bleu_scores.append(score)
    bleu_score = np.mean(bleu_scores)
    return bleu_score


def compute_sacrebleu(reference_list, candidate_list):
    individual_scores = []
    for ref, pred in zip(reference_list, candidate_list):
        # Compute a BLEU score for each sentence pair; references must be passed as a list
        score = sacrebleu.sentence_bleu(pred, [ref])
        individual_scores.append(score.score)
    # Average over all sentences
    average_bleu = sum(individual_scores) / len(individual_scores)
    return average_bleu


def compute_meteor(reference_list, candidate_list):
    individual_scores = []
    for ref, pred in zip(reference_list, candidate_list):
        ref_tokens = ref.split()    # tokenize the reference sentence
        pred_tokens = pred.split()  # tokenize the predicted sentence
        # Pass the pre-tokenized lists directly
        score = meteor_score.single_meteor_score(ref_tokens, pred_tokens)
        individual_scores.append(score)
    # Average over all sentences
    average_meteor = sum(individual_scores) / len(individual_scores)
    return average_meteor


def eval_bleu(data):
    test_examples_answer = [x["output"] for x in data]
    test_predictions_pred = [x["predict"].strip("</s>") for x in data]
    predictions = test_predictions_pred
    references = test_examples_answer
    # rouge = evaluate.load('rouge')
    # result_rouge = rouge.compute(predictions=predictions, references=references)
    result_rouge = compute_rouge(references, predictions)
    result_bleu = compute_bleu(references, predictions)
    result_sacrebleu = compute_sacrebleu(references, predictions)
    # result_meteor = compute_meteor(references, predictions)
    result = {
        'rouge': result_rouge,
        'bleu': result_bleu,
        'sacrebleu': result_sacrebleu,
    }
    return result


def eval_ent_link_acc(data):
    # assert len(data) == 2000
    correct_count = 0
    multi_candidates_example_count = 0
    for i in range(len(data)):
        candidate_list = data[i]["candidates_entity_desc_list"]
        ground_truth = data[i]["output"].strip("<>").lower()
        predict = data[i]["predict"].strip("<>").lower()
        if ground_truth.lower() in predict.lower():
            correct_count += 1
        if len(candidate_list) > 1:
            multi_candidates_example_count += 1
    acc = correct_count / len(data)
    result = {
        "correct_count": correct_count,
        "acc": acc
    }
    return result


def eval_col_pop_map(data):
    rs = []
    recall = []
    for i in range(len(data)):
        ground_truth = data[i]["target"].strip(".")
        pred = data[i]["predict"].strip(".")
        if "</s>" in pred:
            end_tok_ix = pred.rfind("</s>")
            pred = pred[:end_tok_ix]
        ground_truth_list = ground_truth.split(", ")
        pred_list = split_string(pred)
        pred_list = [extract_bracket_content(p) for p in pred_list]
        for k in range(len(pred_list)):
            pred_list[k] = pred_list[k].strip("<>")
        # Deduplicate while preserving the original generation order
        new_pred_list = list(set(pred_list))
        new_pred_list.sort(key=pred_list.index)
        r = [1 if z in ground_truth_list else 0 for z in new_pred_list]
        ap = average_precision(r)
        # print("ap:", ap)
        rs.append(r)
        recall.append(sum(r) / len(ground_truth_list))
    map = mean_average_precision(rs)
    m_recall = sum(recall) / len(data)
    if map + m_recall == 0:
        f1 = 0
    else:
        f1 = 2 * map * m_recall / (map + m_recall)
    result = {
        "mean_average_precision": map,
        "mean_average_recall": m_recall,
        "f1": f1
    }
    return result


def eval_col_type_f1(data):
    # The relation-extraction (rel_ex) task reuses this routine
    ground_truth_list = []
    pred_list = []
    for i in range(len(data)):
        item = data[i]
        ground_truth = item["ground_truth"]
        # pred = item["predict"].strip("</s>").split(",")
        pred = item["predict"].split("</s>")[0].split(", ")
        ground_truth_list.append(ground_truth)
        pred_list.append(pred)
    total_ground_truth_col_types = 0
    total_pred_col_types = 0
    joint_items_list = []
    for i in range(len(ground_truth_list)):
        total_ground_truth_col_types += len(ground_truth_list[i])
        total_pred_col_types += len(pred_list[i])
        # joint_items = [item for item in pred_list[i] if item in ground_truth_list[i]]
        # Count a prediction as correct if it contains a gold type (case-insensitive substring match)
        for g in ground_truth_list[i]:
            for p in pred_list[i]:
                if g.lower() in p.lower():
                    joint_items_list.append(p)
    # import pdb
    # pdb.set_trace()
    gt_entire_col_type = {}
    for i in range(len(ground_truth_list)):
        gt = list(set(ground_truth_list[i]))
        for k in range(len(gt)):
            if gt[k] not in gt_entire_col_type.keys():
                gt_entire_col_type[gt[k]] = 1
            else:
                gt_entire_col_type[gt[k]] += 1
    # print(len(gt_entire_col_type.keys()))
    pd_entire_col_type = {}
    for i in range(len(pred_list)):
        pd = list(set(pred_list[i]))
        for k in range(len(pd)):
            if pd[k] not in pd_entire_col_type.keys():
                pd_entire_col_type[pd[k]] = 1
            else:
                pd_entire_col_type[pd[k]] += 1
    # print(len(pd_entire_col_type.keys()))
    joint_entire_col_type = {}
    for i in range(len(joint_items_list)):
        if joint_items_list[i] not in joint_entire_col_type.keys():
            joint_entire_col_type[joint_items_list[i]] = 1
        else:
            joint_entire_col_type[joint_items_list[i]] += 1
    # print(len(joint_entire_col_type.keys()))
    precision = len(joint_items_list) / total_pred_col_types
    recall = len(joint_items_list) / total_ground_truth_col_types
    if precision + recall == 0:
        f1 = 0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    sorted_gt = sorted(gt_entire_col_type.items(), key=lambda x: x[1], reverse=True)
    result = {
        "precision": precision,
        "recall": recall,
        "f1": f1
    }
    return result


def eval_tabfact_acc(data):
    correct = 0
    remove_count = 0
    for i in range(len(data)):
        ground_truth = data[i]["output"]
        prediction = data[i]["predict"]
        # if prediction.find(ground_truth) == 0:
        if ground_truth.lower() in prediction.lower():
            correct += 1
        if prediction.find("</s>") == 0:
            remove_count += 1
    acc = correct / (len(data) - remove_count)
result={ "correct":correct, "accuracy":acc } return result def eval_row_pop_map(data): rs = [] recall = [] ap_list = [] for i in range(len(data)): pred = data[i]["predict"].strip(".") if "" in pred: end_tok_ix = pred.rfind("") pred = pred[:end_tok_ix] ground_truth_list = data[i]["target"] pred_list_tmp = split_string(pred) try: pred_list = [extract_bracket_content(p) for p in pred_list_tmp] except: print(pred_list_tmp) for k in range(len(pred_list)): pred_list[k] = pred_list[k].strip("<>") # add to remove repeated generated item new_pred_list = list(set(pred_list)) new_pred_list.sort(key=pred_list.index) # r = [1 if z in ground_truth_list else 0 for z in pred_list] r = [1 if z in ground_truth_list else 0 for z in new_pred_list] # ap = average_precision(r) ap = row_pop_average_precision(r, ground_truth_list) # print("ap:", ap) ap_list.append(ap) map = sum(ap_list) / len(data) m_recall = sum(recall) / len(data) if map + m_recall == 0: f1 = 0 else: f1 = 2 * map * m_recall / (map + m_recall) # print(data_name, len(data)) # print("mean_average_precision:", map) result = { "mean_average_precision": map, "mean_average_recall": m_recall, "f1": f1 } return result