Other
English
minecraft
action prediction
Llamipa / evaluation /evaluation.py
Kqte's picture
Upload evaluation.py
65debfe verified
import os
import numpy as np
import jsonlines
from collections import defaultdict
from sklearn.metrics import classification_report
def get_links(sample_string, sample_index, max_distance=15):
    """
    Parse a discourse-structure string into attachments and typed relations.

    The string is split on single spaces and every item is expected to look
    like ``REL(x,y)``, e.g. ``"ACK(0,1) ELAB(1,2)"``. An item is accepted
    when ``REL`` is one of the known MINECRAFT labels, ``x`` and ``y`` parse
    as integers, and ``y - x`` does not exceed ``max_distance``. Everything
    else is counted as bad; parse failures are also reported on stdout.

    Args:
        sample_string: the space-separated relation string to parse.
        sample_index: index of the sample; only used in the diagnostic
            messages printed for items that cannot be parsed.
        max_distance: largest accepted value of ``y - x``. Pass ``None`` to
            disable the distance cutoff altogether.

    Returns:
        A tuple ``(endpoints, full_list, [good, bad])`` where
        ``endpoints`` is the list of unique ``(x, y)`` pairs in order of
        first appearance, ``full_list`` is the matching list of
        ``(rel, x, y)`` tuples (the first relation seen for an endpoint pair
        wins), and ``good`` / ``bad`` count accepted / rejected items.
        Repeated endpoint pairs still count towards ``good``.
    """
    #MINECRAFT labels
    labels = frozenset([
        'COM', 'CONTR', 'CORR', 'QAP', 'ACK', 'ELAB', 'CLARIFQ', 'COND',
        'CONTIN', 'RES', 'EXPL', 'QELAB', 'ALT', 'NARR', 'CONFQ', 'SEQ',
    ])

    endpoints = []
    full_list = []
    seen = set()  # O(1) duplicate check for endpoint pairs
    good = 0
    bad = 0

    for raw in sample_string.split(' '):
        a = raw.strip()
        s_tuple = None
        rel = None

        parts = a.split('(')
        if len(parts) < 2:
            # no opening parenthesis at all (this includes empty items)
            print('split error at ', sample_index)
        else:
            s = parts[1].split(')')[0].split(',')
            r = parts[0].strip()
            try:
                s_tuple = (int(s[0]), int(s[1]))
            except IndexError:
                print('split error at ', sample_index)
            except ValueError:
                print('value error at ', sample_index)
            if r in labels:
                #make sure the label is well-formed
                rel = r

        in_range = s_tuple is not None and (
            max_distance is None or (s_tuple[1] - s_tuple[0]) <= max_distance
        )
        if rel is None or not in_range:
            bad += 1
            continue

        good += 1
        #re-construct the full list as tuples (rel, x, y)
        #but don't allow doubles!!
        if s_tuple not in seen:
            seen.add(s_tuple)
            endpoints.append(s_tuple)
            full_list.append((rel, s_tuple[0], s_tuple[1]))

    return endpoints, full_list, [good, bad]
current_folder = os.getcwd()

# Input / output locations -- edit these before running the script.
gold_path = '/path/to/jsonl'
pred_path = '/path/to/llamipa_output.txt'
save_results = '/path/to/eval_.txt' #to create

# Predicted structures: keep only the lines carrying a ' ### DS:' marker and
# take whatever follows the marker, stripped of surrounding whitespace.
with open(pred_path, 'r') as txt:
    text = txt.read().split('\n')

pred_outputs = [
    line.split('### DS:')[1].strip()
    for line in text
    if line.startswith(' ### DS:')
]
print(len(pred_outputs))

# Gold structures: one JSON object per line; the 'PS' field of every object
# whose 'sample' does not start a new dialogue is collected.
with jsonlines.open(gold_path) as reader:
    gold_outputs = [
        obj['PS']
        for obj in reader
        if not obj['sample'].startswith('NEW DIALOGUE') #make sure to ignore incremental formatting
    ]
# Per-sample attachment scores (endpoint pairs only, relation type ignored).
att_f1_l = []
att_prec_l = []
att_rec_l = []
# Corpus-level attachment counts; these feed the micro F1 computed later.
total_attach_tp = 0
total_attach_fp = 0
total_attach_fn = 0
# Per-sample scores for attachment + relation type.
type_f1_l = []
type_prec_l = []
type_rec_l = []
total_TP = []
# [gold_label, pred_label] pairs; 'NULL' stands in for a missing side.
matrix_list = []
# Accepted / rejected item counts for the *predicted* strings.
bad_output = 0
good_output = 0

# NOTE: predictions and gold samples are paired purely by position.
for i, s in enumerate(pred_outputs):
    pred_att, pred_all, malform = get_links(s, i)
    # The gold parse statistics are discarded on purpose: binding them to
    # `malform` as well would overwrite the prediction counts tallied below.
    gold_att, gold_all, _ = get_links(gold_outputs[i], i)
    good_output += malform[0]
    bad_output += malform[1]

    pred_att_set = set(pred_att)
    gold_att_set = set(gold_att)

    #calculate number of nulls there should be -- will use to check null count below
    common = len(pred_att_set & gold_att_set)
    expected_nulls = (len(pred_att) - common) + (len(gold_att) - common)

    # Corpus-level counts are accumulated for EVERY sample. Skipping samples
    # where one side is empty would drop their false positives / false
    # negatives and inflate the micro F1.
    att_tp = sum(1 for e in pred_att if e in gold_att_set)
    total_attach_tp += att_tp
    total_attach_fp += len(pred_att) - att_tp
    total_attach_fn += sum(1 for e in gold_att if e not in pred_att_set)

    #calculate the precision, recall, and f1 for the sample FOR ATTACHMENTS
    if gold_att and pred_att:
        prec = att_tp / len(pred_att)
        rec = att_tp / len(gold_att)
    else:
        prec = 0
        rec = 0
    att_prec_l.append(prec)
    att_rec_l.append(rec)
    if prec + rec == 0:
        att_f1_l.append(0)
    else:
        att_f1_l.append(2 * prec * rec / (prec + rec))

    #create the relation comparisons by type
    gold_all_set = set(gold_all)
    TP = [e for e in pred_all if e in gold_all_set]
    tp_set = set(TP)
    leftover_pred = [e for e in pred_all if e not in tp_set]
    leftover_gold = [e for e in gold_all if e not in tp_set]

    #calculate the precision, recall, and f1 for the sample FOR ATTACHMENTS+RELATION TYPE
    if gold_all and pred_all:
        prec = len(TP) / len(pred_all)
        rec = len(TP) / len(gold_all)
    else:
        prec = 0
        rec = 0
    type_prec_l.append(prec)
    type_rec_l.append(rec)
    if prec + rec == 0:
        type_f1_l.append(0)
    else:
        type_f1_l.append(2 * prec * rec / (prec + rec))

    #then process the TP, FP, FN for matrix
    total_TP.extend(TP)
    for x in TP:
        matrix_list.append([x[0], x[0]])

    # Group the unmatched relations by endpoint pair, so that a pair present
    # on both sides with different labels becomes one [gold, pred] confusion
    # entry and a pair present on one side only is matched against 'NULL'.
    rem_dict = defaultdict(list)
    for x in leftover_pred:
        rem_dict[(x[1], x[2])].append(('p', x[0]))
    for x in leftover_gold:
        rem_dict[(x[1], x[2])].append(('g', x[0]))

    p_count = 0
    g_count = 0
    null_count = 0
    for k in rem_dict:
        p = 'NULL'
        t = 'NULL'
        for side, rel_type in rem_dict[k]:
            if side == 'p':
                p = rel_type
                p_count += 1
            elif side == 'g':
                t = rel_type
                g_count += 1
        matrix_list.append([t, p])
        if 'NULL' in [t, p]:
            null_count += 1

    # Internal consistency checks on the bookkeeping above.
    assert len(TP) + p_count == len(pred_all)
    assert len(TP) + g_count == len(gold_all)
    assert null_count == expected_nulls
#compute labels in gold and pred
gold = [m[0] for m in matrix_list]
pred = [m[1] for m in matrix_list]
# sorted() keeps the label -> index mapping, and therefore the row order of
# the classification report, identical from one run to the next.
labels = sorted(set(gold) | set(pred))
label_index = {name: idx for idx, name in enumerate(labels)}

gold_list = [label_index[g] for g in gold]
pred_list = [label_index[p] for p in pred]

# Micro F1 over all attachments; reported as 0.0 when there is nothing to score.
attach_denominator = total_attach_tp + 0.5*(total_attach_fp + total_attach_fn)
microf1 = total_attach_tp/attach_denominator if attach_denominator else 0.0

# The context manager closes the results file even if report generation fails.
with open(save_results, "w") as f:
    print("Attachment F1:",np.mean(att_f1_l),len(att_f1_l), file=f)
    print("Attachment Average Precision:",np.mean(att_prec_l), file=f)
    print("Attachment Average Recall:",np.mean(att_rec_l), file=f)
    print('Micro F1: ', microf1, file=f)
    print('--------------------------------', file=f)
    # NOTE(review): the next four lines go to stdout, not to the results file,
    # exactly as before -- confirm whether they should be saved as well.
    print("Attachment + Rel F1:",np.mean(type_f1_l),len(type_f1_l))
    print("Attachment + Rel Average Precision:",np.mean(type_prec_l))
    print("Attachment + Rel Average Recall:",np.mean(type_rec_l))
    print('---------------------------------------')
    print(classification_report(gold_list,pred_list,target_names=labels), file=f)

    # The F1-scores for the relation types displayed in the above table are correct.
    # That is, while calculating F1 for label l, all the ["NULL", l] entries count
    # towards false-positive for label l and all the [l, "NULL"] entries count
    # towards false-negative for label l.
    # So, the "NULL" type is affecting the precision/recall/F1 for label l (as it should).
    # Now, for the overall weighted average precision/recall/f1-score,
    # we want the average to be over the actual relation labels set
    # (i.e. excluding the "NULL" class). For that, we do this:
    d = classification_report(gold_list,pred_list,target_names=labels,output_dict=True)
    prec = 0
    rec = 0
    f1 = 0
    count = 0
    for label in labels:
        if label == "NULL":
            continue
        support = d[label]["support"]
        prec += d[label]["precision"]*support
        rec += d[label]["recall"]*support
        f1 += d[label]["f1-score"]*support
        count += support

    print('--------------------------------', file=f)
    # count is 0 when no real (non-NULL) relation occurs in the gold data;
    # report 0.0 in that case rather than dividing by zero.
    print("Weighted Average Precision:", prec/count if count else 0.0, file=f)
    print("Weighted Average Recall:", rec/count if count else 0.0, file=f)
    print("Weighted Average F1 score:", f1/count if count else 0.0, file=f)