File size: 5,933 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import re
import string
import unicodedata
from collections import Counter
from dsp.utils.utils import print_message
def EM(prediction, answers_list):
    """Exact-match score of `prediction` against a list of gold answers.

    Returns the best (max) em_score over all candidate answers, i.e. True
    iff the normalized prediction equals at least one normalized answer.
    """
    # isinstance is the idiomatic type check and also accepts list subclasses.
    assert isinstance(answers_list, list)
    return max(em_score(prediction, ans) for ans in answers_list)
def F1(prediction, answers_list):
    """Token-level F1 of `prediction` against a list of gold answers.

    Returns the best (max) f1_score over all candidate answers.
    """
    # isinstance is the idiomatic type check and also accepts list subclasses.
    assert isinstance(answers_list, list)
    return max(f1_score(prediction, ans) for ans in answers_list)
def HotPotF1(prediction, answers_list):
    """HotPotQA-style token F1 of `prediction` against a list of gold answers.

    Returns the best (max) hotpot_f1_score over all candidate answers.
    """
    # isinstance is the idiomatic type check and also accepts list subclasses.
    assert isinstance(answers_list, list)
    return max(hotpot_f1_score(prediction, ans) for ans in answers_list)
def nF1(history, prediction, answers_list, return_recall=False):
    """Novel-token F1 of `prediction` against a list of gold answers.

    Tokens already present in `history` (or in the stopword list) are ignored;
    returns the best (max) novel_f1_score over all candidate answers. When
    `return_recall` is True the underlying scores are recalls instead of F1.
    """
    # isinstance is the idiomatic type check and also accepts list subclasses.
    assert isinstance(answers_list, list)
    return max(novel_f1_score(history, prediction, ans, return_recall=return_recall) for ans in answers_list)
def normalize_text(s):
    """Normalize an answer string for comparison.

    Pipeline: Unicode NFD decomposition, lowercasing, punctuation removal,
    article ('a'/'an'/'the') removal, and whitespace collapsing.
    """
    text = unicodedata.normalize('NFD', s).lower()
    # Strip ASCII punctuation characters (articles are removed afterwards,
    # so e.g. "the," still becomes removable once the comma is gone).
    text = ''.join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    # Collapse all runs of whitespace into single spaces and trim the ends.
    return ' '.join(text.split())
def em_score(prediction, ground_truth):
    """Exact match: True iff the two strings are identical after normalization."""
    return normalize_text(ground_truth) == normalize_text(prediction)
# See: https://github.com/hotpotqa/hotpot/blob/master/hotpot_evaluate_v1.py
# See: https://rajpurkar.github.io/SQuAD-explorer/ under Evaluation Script
# See: QReCC's
def f1_score(prediction, ground_truth):
    """Token-level F1 between a prediction and a single ground-truth answer.

    Both strings are normalized and whitespace-tokenized; overlap is counted
    with multiplicity via Counter intersection. Returns 0 when there is no
    overlap (including when both token lists are empty).
    """
    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(ground_truth).split()
    overlap = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if not pred_tokens and not truth_tokens:
        # Unlike most tasks, QReCC and SQuAD-2.0 assign 1.0 in this edge case. We don't for uniformity.
        print_message(
            "\n#> F1 Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n")
    if overlap == 0:
        return 0
    precision = 1.0 * overlap / len(pred_tokens)
    recall = 1.0 * overlap / len(truth_tokens)
    return (2 * precision * recall) / (precision + recall)
def hotpot_f1_score(prediction, ground_truth):
    """HotPotQA token F1: like f1_score, but yes/no/noanswer must match exactly.

    If either normalized string is one of the special answers 'yes', 'no',
    'noanswer' and the two normalized strings differ, the score is 0.
    """
    norm_pred = normalize_text(prediction)
    norm_truth = normalize_text(ground_truth)
    special = ['yes', 'no', 'noanswer']
    # Special answers only score via exact match, never partial token overlap.
    if norm_pred != norm_truth and (norm_pred in special or norm_truth in special):
        return 0
    pred_tokens = norm_pred.split()
    truth_tokens = norm_truth.split()
    overlap = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if overlap == 0:
        return 0
    precision = 1.0 * overlap / len(pred_tokens)
    recall = 1.0 * overlap / len(truth_tokens)
    return (2 * precision * recall) / (precision + recall)
def precision_score(prediction, ground_truth):
    """Token-level precision between a prediction and a single gold answer.

    Both strings are normalized and whitespace-tokenized; overlap is counted
    with multiplicity via Counter intersection. Returns 0 when there is no
    overlap (including when both token lists are empty).
    """
    prediction_tokens = normalize_text(prediction).split()
    ground_truth_tokens = normalize_text(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if len(prediction_tokens) == len(ground_truth_tokens) == 0:
        # Unlike most tasks, QReCC and SQuAD-2.0 assign 1.0 in this edge case. We don't for uniformity.
        # Fixed copy-paste bug: the message previously claimed "F1 Metric".
        print_message(
            "\n#> Precision Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n")
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    return precision
# Source: https://gist.github.com/sebleier/554280
# English stopword list (NLTK-derived); used by novel_f1_score to discount
# tokens that carry little answer-specific information.
stopwords = ["i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself",
             "yourselves", "he", "him", "his", "himself", "she", "her", "hers", "herself", "it", "its", "itself",
             "they", "them", "their", "theirs", "themselves", "what", "which", "who", "whom", "this", "that", "these",
             "those", "am", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "having", "do",
             "does", "did", "doing", "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", "while",
             "of", "at", "by", "for", "with", "about", "against", "between", "into", "through", "during", "before",
             "after", "above", "below", "to", "from", "up", "down", "in", "out", "on", "off", "over", "under", "again",
             "further", "then", "once", "here", "there", "when", "where", "why", "how", "all", "any", "both", "each",
             "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than",
             "too", "very", "s", "t", "can", "will", "just", "don", "should", "now"]
def novel_f1_score(history, prediction, ground_truth, return_recall=False):
    """Token F1 restricted to 'novel' tokens.

    Tokens that already appear in the (normalized) `history`, or are stopwords,
    are dropped from both the prediction and the ground truth before computing
    the usual token-overlap F1. With `return_recall=True`, returns the recall
    component instead of F1. Returns 0 when there is no overlap.
    """
    # Anything seen in the dialogue history (plus stopwords) is not novel.
    seen = set(normalize_text(history).split() + stopwords)
    pred_tokens = [t for t in normalize_text(prediction).split() if t not in seen]
    truth_tokens = [t for t in normalize_text(ground_truth).split() if t not in seen]
    overlap = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if overlap == 0:
        return 0
    precision = 1.0 * overlap / len(pred_tokens)
    recall = 1.0 * overlap / len(truth_tokens)
    if return_recall:
        return recall
    return (2 * precision * recall) / (precision + recall)
|