import os
from collections import defaultdict

import numpy as np
import jsonlines
from sklearn.metrics import classification_report


def get_links(sample_string, sample_index):
    """
    Parse a discourse-structure sample string and return the list of attachment
    tuples, the list of (relation, source, target) triples, and a [good, bad]
    count of well-formed vs. malformed links.
    """
    labels = ['COM', 'CONTR', 'CORR', 'QAP', 'ACK', 'ELAB', 'CLARIFQ', 'COND', 'CONTIN',
              'RES', 'EXPL', 'QELAB', 'ALT', 'NARR', 'CONFQ', 'SEQ']

    split_list = [st.strip() for st in sample_string.split(' ')]

    rel_list = []
    attach_list = []
    bad = 0
    good = 0
    for a in split_list:
        s_tuple = None
        rel = None
        try:
            s = a.split('(')[1].split(')')[0].split(',')
            r = a.split('(')[0].strip()
        except IndexError:
            print('split error at', sample_index)
        else:
            try:
                s_tuple = (int(s[0]), int(s[1]))
            except IndexError:
                print('split error at', sample_index)
            except ValueError:
                print('value error at', sample_index)
            if r in labels:
                rel = r
        # Keep only links with a known relation label, parseable endpoints,
        # and a source-target distance of at most 15.
        if rel is not None and s_tuple is not None and (s_tuple[1] - s_tuple[0]) <= 15:
            attach_list.append(s_tuple)
            rel_list.append(rel)
            good += 1
        else:
            bad += 1

    # Deduplicate attachments, keeping the first relation seen for each endpoint pair.
    full_list = []
    endpoints = []
    for i, r in enumerate(attach_list):
        if r not in endpoints:
            endpoints.append(r)
            full_list.append((rel_list[i], r[0], r[1]))
    return endpoints, full_list, [good, bad]
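
# Example (format assumed from the parsing above): a sample string such as
#   "QAP(12,13) ELAB(13,15)"
# yields endpoints [(12, 13), (13, 15)] and
# full_list [('QAP', 12, 13), ('ELAB', 13, 15)]; tokens with an unknown label,
# unparseable indices, or a source-target distance over 15 count as malformed.
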
current_folder = os.getcwd()

gold_path = '/path/to/jsonl'
pred_path = '/path/to/llamipa_output.txt'
save_results = '/path/to/eval_.txt'

with open(pred_path, 'r') as txt:
    text = txt.read().split('\n')

pred_outputs = []
for t in text:
    if t.startswith(' ### DS:'):
        sample = t.split('### DS:')[1].strip()
        pred_outputs.append(sample)
print(len(pred_outputs))
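
# pred_outputs holds one predicted structure string per sample, taken from the
# generation file's ' ### DS:' lines (assumed to be the marker the LLaMIPA run
# writes before each predicted structure); all other lines are ignored.
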
gold_outputs = []

with jsonlines.open(gold_path) as reader:
    for obj in reader:
        if not obj['sample'].startswith('NEW DIALOGUE'):
            gold_outputs.append(obj['PS'])
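
# The gold .jsonl is assumed to hold one object per prediction, in the same
# order, with the reference structure string in obj['PS']; 'NEW DIALOGUE'
# marker rows carry no structure and are skipped.
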
att_f1_l = []
att_prec_l = []
att_rec_l = []

total_attach_tp = 0
total_attach_fp = 0
total_attach_fn = 0

type_f1_l = []
type_prec_l = []
type_rec_l = []

total_TP = []

matrix_list = []
bad_output = 0
good_output = 0
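
# Per-sample evaluation: attachment-only scores, attachment+relation scores,
# and gold/predicted label pairs (with 'NULL' standing in for missing or
# spurious links) that feed the relation classification report.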

for i, s in enumerate(pred_outputs):

    pred_att, pred_all, pred_malform = get_links(s, i)
    gold_att, gold_all, _ = get_links(gold_outputs[i], i)

    # Track how many links in the predicted output were well-formed vs. malformed.
    bad_output += pred_malform[1]
    good_output += pred_malform[0]

    common = len(set(pred_att).intersection(set(gold_att)))
    expected_nulls = (len(pred_att) - common) + (len(gold_att) - common)
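
    # Attachment scores: an endpoint pair counts as correct if it occurs in
    # both the predicted and the gold structure. Samples where either side is
    # empty contribute zero precision/recall and are left out of the micro counts.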
    if len(gold_att) > 0 and len(pred_att) > 0:
        prec = len([e for e in pred_att if e in gold_att]) / len(pred_att)
        rec = len([e for e in pred_att if e in gold_att]) / len(gold_att)
        total_attach_tp += len([e for e in pred_att if e in gold_att])
        total_attach_fp += len([e for e in pred_att if e not in gold_att])
        total_attach_fn += len([e for e in gold_att if e not in pred_att])
    else:
        prec = 0
        rec = 0
    att_prec_l.append(prec)
    att_rec_l.append(rec)
    if prec + rec == 0:
        att_f1_l.append(0)
    else:
        att_f1_l.append(2 * prec * rec / (prec + rec))
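
    # Attachment + relation scores: both the endpoints and the label must match.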
    if len(gold_all) > 0 and len(pred_all) > 0:
        prec = len([e for e in pred_all if e in gold_all]) / len(pred_all)
        rec = len([e for e in pred_all if e in gold_all]) / len(gold_all)
    else:
        prec = 0
        rec = 0
    type_prec_l.append(prec)
    type_rec_l.append(rec)
    if prec + rec == 0:
        type_f1_l.append(0)
    else:
        type_f1_l.append(2 * prec * rec / (prec + rec))
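
    # Rows for the relation confusion matrix: exact matches give a
    # (label, label) pair; leftover predicted and gold links are grouped by
    # endpoint pair, and the missing side is recorded as 'NULL'.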
    TP = [e for e in pred_all if e in gold_all]
    leftover_pred = [p for p in pred_all if p not in TP]
    leftover_gold = [p for p in gold_all if p not in TP]

    total_TP.extend(TP)

    rem_dict = defaultdict(list)
    for x in TP:
        matrix_list.append([x[0], x[0]])
    for x in leftover_pred:
        rem_dict[(x[1], x[2])].append(('p', x[0]))
    for x in leftover_gold:
        rem_dict[(x[1], x[2])].append(('g', x[0]))

    p_count = 0
    g_count = 0
    null_count = 0
    for k in rem_dict.keys():
        p = 'NULL'
        t = 'NULL'
        for entry in rem_dict[k]:
            if entry[0] == 'p':
                p = entry[1]
                p_count += 1
            elif entry[0] == 'g':
                t = entry[1]
                g_count += 1
        matrix_list.append([t, p])
        if 'NULL' in [t, p]:
            null_count += 1

    # Sanity checks: every predicted and gold link is accounted for, and the
    # number of NULL-padded rows matches the endpoint mismatch count above.
    assert len(TP) + p_count == len(pred_all)
    assert len(TP) + g_count == len(gold_all)
    assert null_count == expected_nulls
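
# Corpus-level reporting: macro averages over samples, micro attachment F1 over
# all links, and a classification report over the NULL-augmented label pairs
# collected in matrix_list.
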
# Fix an order for every label that appears on either side (including 'NULL');
# sorting keeps the label indices and the report layout deterministic.
labels = sorted(set([m[0] for m in matrix_list] + [m[1] for m in matrix_list]))

# Micro attachment F1: TP / (TP + 0.5 * (FP + FN)).
microf1 = total_attach_tp / (total_attach_tp + 0.5 * (total_attach_fp + total_attach_fn))

gold_list = [labels.index(m[0]) for m in matrix_list]
pred_list = [labels.index(m[1]) for m in matrix_list]
f = open(save_results, "w")
print("Attachment F1:", np.mean(att_f1_l), len(att_f1_l), file=f)
print("Attachment Average Precision:", np.mean(att_prec_l), file=f)
print("Attachment Average Recall:", np.mean(att_rec_l), file=f)
print('Micro F1: ', microf1, file=f)
print('--------------------------------', file=f)
print("Attachment + Rel F1:", np.mean(type_f1_l), len(type_f1_l), file=f)
print("Attachment + Rel Average Precision:", np.mean(type_prec_l), file=f)
print("Attachment + Rel Average Recall:", np.mean(type_rec_l), file=f)
print('---------------------------------------', file=f)
print(classification_report(gold_list, pred_list, target_names=labels), file=f)
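
# Weighted averages over the relation labels only, excluding the 'NULL' class
# that stands for missing or spurious links.
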
d = classification_report(gold_list, pred_list, target_names=labels, output_dict=True)
prec = 0
rec = 0
f1 = 0
count = 0

for label in labels:
    if label != "NULL":
        prec += d[label]["precision"] * d[label]["support"]
        rec += d[label]["recall"] * d[label]["support"]
        f1 += d[label]["f1-score"] * d[label]["support"]
        count += d[label]["support"]
print('--------------------------------', file=f)
print("Weighted Average Precision:", prec / count, file=f)
print("Weighted Average Recall:", rec / count, file=f)
print("Weighted Average F1 score:", f1 / count, file=f)

f.close()