# dsp/utils/metrics.py
import re
import string
import unicodedata
from collections import Counter
from dsp.utils.utils import print_message
def EM(prediction, answers_list):
    """Exact match: True if the prediction equals any gold answer after normalization."""
    assert isinstance(answers_list, list)
    return max(em_score(prediction, ans) for ans in answers_list)


def F1(prediction, answers_list):
    """Token-level F1 against the best-matching gold answer."""
    assert isinstance(answers_list, list)
    return max(f1_score(prediction, ans) for ans in answers_list)


def HotPotF1(prediction, answers_list):
    """HotpotQA-style F1: yes/no/noanswer predictions must match the gold answer exactly."""
    assert isinstance(answers_list, list)
    return max(hotpot_f1_score(prediction, ans) for ans in answers_list)


def nF1(history, prediction, answers_list, return_recall=False):
    """Novel F1: token-level F1 restricted to tokens absent from the history."""
    assert isinstance(answers_list, list)
    return max(novel_f1_score(history, prediction, ans, return_recall=return_recall) for ans in answers_list)
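
# A couple of illustrative calls (the strings are made up for demonstration):
#
#   EM("The Eiffel Tower", ["Eiffel Tower"])                      # True: case and articles normalized away
#   nF1("we discussed paris", "eiffel tower", ["eiffel tower"])   # 1.0: both novel tokens recovered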
def normalize_text(s):
    """SQuAD-style normalization: lowercase, strip punctuation and articles, collapse whitespace."""
    s = unicodedata.normalize('NFD', s)

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
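
# For example, normalize_text("The Empire State Building!") == "empire state building":
# the text is lowercased, punctuation is stripped, articles are dropped, and the
# leftover whitespace is collapsed, in that order.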
def em_score(prediction, ground_truth):
    return normalize_text(prediction) == normalize_text(ground_truth)


# See: https://github.com/hotpotqa/hotpot/blob/master/hotpot_evaluate_v1.py
# See: https://rajpurkar.github.io/SQuAD-explorer/ under Evaluation Script
# See: QReCC's evaluation script
def f1_score(prediction, ground_truth):
    """Bag-of-words F1 between the normalized prediction and ground truth."""
    prediction_tokens = normalize_text(prediction).split()
    ground_truth_tokens = normalize_text(ground_truth).split()

    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())

    if len(prediction_tokens) == len(ground_truth_tokens) == 0:
        # Unlike most tasks, QReCC and SQuAD-2.0 assign 1.0 in this edge case. We don't for uniformity.
        print_message(
            "\n#> F1 Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n")

    if num_same == 0:
        return 0

    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)

    return f1
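
# Worked example: f1_score("the cat sat", "cat sat down") tokenizes to
# ["cat", "sat"] vs. ["cat", "sat", "down"] after normalization, so
# num_same = 2, precision = 2/2, recall = 2/3, and F1 = 0.8.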
def hotpot_f1_score(prediction, ground_truth):
    """HotpotQA F1: identical to f1_score, except yes/no/noanswer answers are all-or-nothing."""
    normalized_prediction = normalize_text(prediction)
    normalized_ground_truth = normalize_text(ground_truth)

    if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
        return 0

    if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
        return 0

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()

    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())

    if num_same == 0:
        return 0

    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)

    return f1
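
# The yes/no guard is what distinguishes this from plain f1_score: e.g.,
# hotpot_f1_score("yes", "yes it was") returns 0, whereas f1_score would
# give it partial credit (F1 = 0.5) for the overlapping "yes" token.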
def precision_score(prediction, ground_truth):
    """Bag-of-words precision of the normalized prediction against the ground truth."""
    prediction_tokens = normalize_text(prediction).split()
    ground_truth_tokens = normalize_text(ground_truth).split()

    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())

    if len(prediction_tokens) == len(ground_truth_tokens) == 0:
        # Unlike most tasks, QReCC and SQuAD-2.0 assign 1.0 in this edge case. We don't for uniformity.
        print_message(
            "\n#> Precision Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n")

    if num_same == 0:
        return 0

    precision = 1.0 * num_same / len(prediction_tokens)

    return precision
# Source: https://gist.github.com/sebleier/554280
stopwords = ["i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself",
"yourselves", "he", "him", "his", "himself", "she", "her", "hers", "herself", "it", "its", "itself",
"they", "them", "their", "theirs", "themselves", "what", "which", "who", "whom", "this", "that", "these",
"those", "am", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "having", "do",
"does", "did", "doing", "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", "while",
"of", "at", "by", "for", "with", "about", "against", "between", "into", "through", "during", "before",
"after", "above", "below", "to", "from", "up", "down", "in", "out", "on", "off", "over", "under", "again",
"further", "then", "once", "here", "there", "when", "where", "why", "how", "all", "any", "both", "each",
"few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than",
"too", "very", "s", "t", "can", "will", "just", "don", "should", "now"]
def novel_f1_score(history, prediction, ground_truth, return_recall=False):
    """F1 over "novel" tokens only: tokens appearing in the history or the stopword
    list are dropped from both the prediction and the ground truth before scoring."""
    history_tokens = normalize_text(history).split()
    prediction_tokens = normalize_text(prediction).split()
    ground_truth_tokens = normalize_text(ground_truth).split()

    history_tokens = set(history_tokens + stopwords)
    prediction_tokens = [t for t in prediction_tokens if t not in history_tokens]
    ground_truth_tokens = [t for t in ground_truth_tokens if t not in history_tokens]

    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())

    if num_same == 0:
        return 0

    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)

    if return_recall:
        return recall

    return f1
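
# A minimal smoke test for the metrics above; a sketch, runnable directly.
# The expected values follow from the worked examples in the comments above.
if __name__ == "__main__":
    assert EM("The Eiffel Tower", ["Eiffel Tower"])
    assert abs(F1("the cat sat", ["cat sat down"]) - 0.8) < 1e-9
    assert HotPotF1("yes", ["yes it was"]) == 0
    assert abs(nF1("we discussed paris", "eiffel tower", ["eiffel tower"]) - 1.0) < 1e-9
    print_message("#> metrics.py smoke test passed.")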