|
import re |
|
import string |
|
import unicodedata |
|
|
|
from collections import Counter |
|
from dsp.utils.utils import print_message |
|
|
|
|
|
def EM(prediction, answers_list):
    """Return the best exact-match result of ``prediction`` over all answers.

    Args:
        prediction: Predicted answer; forwarded to ``em_score``.
        answers_list: List of acceptable reference answers.

    Returns:
        The maximum of ``em_score(prediction, ans)`` over ``answers_list``.

    Raises:
        AssertionError: If ``answers_list`` is not a list.
        ValueError: If ``answers_list`` is empty (raised by ``max``).
    """
    # isinstance, unlike an exact type comparison, also admits list subclasses.
    assert isinstance(answers_list, list), "answers_list must be a list"

    return max(em_score(prediction, ans) for ans in answers_list)
|
|
|
|
|
def F1(prediction, answers_list):
    """Return the best token-level F1 of ``prediction`` over all answers.

    Args:
        prediction: Predicted answer; forwarded to ``f1_score``.
        answers_list: List of acceptable reference answers.

    Returns:
        The maximum of ``f1_score(prediction, ans)`` over ``answers_list``.

    Raises:
        AssertionError: If ``answers_list`` is not a list.
        ValueError: If ``answers_list`` is empty (raised by ``max``).
    """
    # isinstance, unlike an exact type comparison, also admits list subclasses.
    assert isinstance(answers_list, list), "answers_list must be a list"

    return max(f1_score(prediction, ans) for ans in answers_list)
|
|
|
|
|
def HotPotF1(prediction, answers_list):
    """Return the best HotPotQA-style F1 of ``prediction`` over all answers.

    Args:
        prediction: Predicted answer; forwarded to ``hotpot_f1_score``.
        answers_list: List of acceptable reference answers.

    Returns:
        The maximum of ``hotpot_f1_score(prediction, ans)`` over
        ``answers_list``.

    Raises:
        AssertionError: If ``answers_list`` is not a list.
        ValueError: If ``answers_list`` is empty (raised by ``max``).
    """
    # isinstance, unlike an exact type comparison, also admits list subclasses.
    assert isinstance(answers_list, list), "answers_list must be a list"

    return max(hotpot_f1_score(prediction, ans) for ans in answers_list)
|
|
|
|
|
def nF1(history, prediction, answers_list, return_recall=False):
    """Return the best novelty-aware F1 (or recall) over all answers.

    Args:
        history: Text whose tokens are excluded from scoring; forwarded to
            ``novel_f1_score``.
        prediction: Predicted answer; forwarded to ``novel_f1_score``.
        answers_list: List of acceptable reference answers.
        return_recall: Forwarded to ``novel_f1_score``; when true that
            function returns recall instead of F1.

    Returns:
        The maximum of ``novel_f1_score(history, prediction, ans, ...)`` over
        ``answers_list``.

    Raises:
        AssertionError: If ``answers_list`` is not a list.
        ValueError: If ``answers_list`` is empty (raised by ``max``).
    """
    # isinstance, unlike an exact type comparison, also admits list subclasses.
    assert isinstance(answers_list, list), "answers_list must be a list"

    return max(
        novel_f1_score(history, prediction, ans, return_recall=return_recall)
        for ans in answers_list
    )
|
|
|
|
|
def normalize_text(s):
    """Canonicalize a string for answer comparison.

    The steps run in this order: Unicode NFD normalization, lowercasing,
    removal of every character in ``string.punctuation``, replacement of the
    standalone words "a", "an" and "the" with a space, and finally collapsing
    all whitespace runs to single spaces (which also trims both ends).

    Args:
        s: The text to normalize.

    Returns:
        The normalized string.
    """
    text = unicodedata.normalize('NFD', s).lower()

    # Punctuation is deleted (not replaced by a space) before articles are
    # matched, so e.g. "a.n" collapses to "an" and is then removed as an article.
    text = text.translate(str.maketrans('', '', string.punctuation))

    text = re.sub(r'\b(a|an|the)\b', ' ', text)

    return ' '.join(text.split())
|
|
|
|
|
def em_score(prediction, ground_truth):
    """Report whether two strings are equal after ``normalize_text``.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        ``True`` when the normalized forms are identical, else ``False``.
    """
    normalized_prediction = normalize_text(prediction)
    normalized_truth = normalize_text(ground_truth)
    return normalized_prediction == normalized_truth
|
|
|
|
|
|
|
|
|
|
|
|
|
def f1_score(prediction, ground_truth):
    """Compute token-level F1 between a prediction and a reference answer.

    Both strings are passed through ``normalize_text`` and split on
    whitespace; overlap is counted as a multiset intersection of tokens.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        ``0`` when no tokens overlap, otherwise the harmonic mean of token
        precision and recall as a float.
    """
    pred_toks = normalize_text(prediction).split()
    gold_toks = normalize_text(ground_truth).split()

    if not pred_toks and not gold_toks:
        # Both sides normalized to nothing: log it, then fall through to the
        # zero-overlap return below.
        print_message(
            "\n#> F1 Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n")

    overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if overlap == 0:
        return 0

    # A positive overlap implies both token lists are non-empty.
    precision = float(overlap) / len(pred_toks)
    recall = float(overlap) / len(gold_toks)

    return (2 * precision * recall) / (precision + recall)
|
|
|
|
|
def hotpot_f1_score(prediction, ground_truth):
    """Compute token-level F1 with special handling of yes/no/noanswer.

    If either normalized string is one of "yes", "no" or "noanswer", the
    score is ``0`` unless the two normalized strings are identical. Otherwise
    the score is the usual multiset-overlap F1 over whitespace tokens.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        ``0`` for a special-answer mismatch or zero token overlap, otherwise
        the F1 value as a float.
    """
    pred_norm = normalize_text(prediction)
    gold_norm = normalize_text(ground_truth)

    special_answers = ('yes', 'no', 'noanswer')
    if pred_norm != gold_norm and (
            pred_norm in special_answers or gold_norm in special_answers):
        return 0

    pred_toks = pred_norm.split()
    gold_toks = gold_norm.split()

    overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if overlap == 0:
        return 0

    # A positive overlap implies both token lists are non-empty.
    precision = float(overlap) / len(pred_toks)
    recall = float(overlap) / len(gold_toks)

    return (2 * precision * recall) / (precision + recall)
|
|
|
|
|
def precision_score(prediction, ground_truth):
    """Compute token-level precision of a prediction against a reference.

    Both strings are passed through ``normalize_text`` and split on
    whitespace; overlap is counted as a multiset intersection of tokens.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        ``0`` when no tokens overlap, otherwise the overlap count divided by
        the number of prediction tokens, as a float.
    """
    pred_toks = normalize_text(prediction).split()
    gold_toks = normalize_text(ground_truth).split()

    if not pred_toks and not gold_toks:
        # NOTE(review): the message is labelled "F1 Metric" even though this
        # is the precision function; the text is kept verbatim here.
        print_message(
            "\n#> F1 Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n")

    overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if overlap == 0:
        return 0

    # A positive overlap implies the prediction has at least one token.
    return float(overlap) / len(pred_toks)
|
|
|
|
|
|
|
# Lowercase English stopwords. novel_f1_score adds these to the history
# tokens, so they are dropped from both the prediction and the reference.
stopwords = (
    "i me my myself we our ours ourselves you your yours yourself "
    "yourselves he him his himself she her hers herself it its itself "
    "they them their theirs themselves what which who whom this that these "
    "those am is are was were be been being have has had having do "
    "does did doing a an the and but if or because as until while "
    "of at by for with about against between into through during before "
    "after above below to from up down in out on off over under again "
    "further then once here there when where why how all any both each "
    "few more most other some such no nor not only own same so than "
    "too very s t can will just don should now"
).split()
|
|
|
|
|
def novel_f1_score(history, prediction, ground_truth, return_recall=False):
    """Compute F1 (or recall) over tokens not already present in ``history``.

    All three strings are passed through ``normalize_text`` and split on
    whitespace. Tokens that occur in ``history`` or in the module-level
    ``stopwords`` list are removed from both the prediction and the reference
    before the multiset overlap is counted.

    Args:
        history: Text whose tokens are excluded from scoring.
        prediction: Predicted answer text.
        ground_truth: Reference answer text.
        return_recall: When true, return recall instead of F1.

    Returns:
        ``0`` when the filtered token lists share nothing; otherwise recall
        if ``return_recall`` is true, else F1, as a float.
    """
    seen_toks = normalize_text(history).split()
    pred_toks = normalize_text(prediction).split()
    gold_toks = normalize_text(ground_truth).split()

    ignored = set(seen_toks + stopwords)

    novel_pred = [tok for tok in pred_toks if tok not in ignored]
    novel_gold = [tok for tok in gold_toks if tok not in ignored]

    overlap = sum((Counter(novel_pred) & Counter(novel_gold)).values())
    if overlap == 0:
        return 0

    # A positive overlap implies both filtered lists are non-empty.
    precision = float(overlap) / len(novel_pred)
    recall = float(overlap) / len(novel_gold)

    if return_recall:
        return recall

    return (2 * precision * recall) / (precision + recall)
|
|