"""Answer-overlap metrics for question answering.

Every metric compares a predicted answer string against one or more
reference answer strings after passing both through ``normalize_text``.
The capitalised entry points (``EM``, ``F1``, ``HotPotF1``, ``nF1``) take a
list of references and return the best score over that list; the lower-case
``*_score`` functions compare against a single reference.
"""
import re
import string
import unicodedata
from collections import Counter

try:
    from dsp.utils.utils import print_message
except ImportError:
    # The helper is only used for a rare diagnostic message. Fall back to a
    # minimal stand-in so the (otherwise pure) metrics stay importable when
    # the dsp package is not available.
    def print_message(*args, **kwargs):
        print(*args, flush=True)


# Hoisted so they are built once instead of on every normalize_text call.
_ARTICLES_RE = re.compile(r'\b(a|an|the)\b')
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)

# Normalized answers that hotpot_f1_score only accepts on an exact match.
_HOTPOT_SPECIAL_ANSWERS = ('yes', 'no', 'noanswer')

_EMPTY_TOKENS_MESSAGE = (
    "\n#> {} Metric: Rare edge case of "
    "len(prediction_tokens) == len(ground_truth_tokens) == 0.\n")


def _best_score(score_fn, answers_list):
    """Return the maximum of ``score_fn(answer)`` over ``answers_list``.

    Raises:
        AssertionError: if ``answers_list`` is not a list (or list subclass).
        ValueError: if ``answers_list`` is empty (raised by ``max``).
    """
    # Kept as an assert so callers keep seeing AssertionError; isinstance
    # (rather than an exact type comparison) also admits list subclasses.
    assert isinstance(answers_list, list), "answers_list must be a list"
    return max(score_fn(ans) for ans in answers_list)


def _overlap_stats(prediction_tokens, ground_truth_tokens):
    """Return ``(precision, recall, f1)`` of the bag-of-tokens overlap.

    Tokens are matched with multiplicity (multiset intersection). When no
    token is shared the result is ``(0, 0, 0)``; that early exit also covers
    empty token lists, so no division by zero can occur.
    """
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())

    if num_same == 0:
        return 0, 0, 0

    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)

    return precision, recall, f1


def _report_if_both_empty(metric_name, prediction_tokens, ground_truth_tokens):
    """Emit a diagnostic when both token lists are empty after normalization."""
    if len(prediction_tokens) == len(ground_truth_tokens) == 0:
        # Unlike most tasks, QReCC and SQuAD-2.0 assign 1.0 in this edge case.
        # We don't for uniformity.
        print_message(_EMPTY_TOKENS_MESSAGE.format(metric_name))


def EM(prediction, answers_list):
    """Return True if ``prediction`` exactly matches any answer after normalization."""
    return _best_score(lambda ans: em_score(prediction, ans), answers_list)


def F1(prediction, answers_list):
    """Return the best token-level F1 of ``prediction`` over ``answers_list``."""
    return _best_score(lambda ans: f1_score(prediction, ans), answers_list)


def HotPotF1(prediction, answers_list):
    """Return the best ``hotpot_f1_score`` of ``prediction`` over ``answers_list``."""
    return _best_score(lambda ans: hotpot_f1_score(prediction, ans), answers_list)


def nF1(history, prediction, answers_list, return_recall=False):
    """Return the best ``novel_f1_score`` of ``prediction`` over ``answers_list``.

    ``history`` and ``return_recall`` are forwarded unchanged to
    ``novel_f1_score`` for every reference answer.
    """
    return _best_score(
        lambda ans: novel_f1_score(history, prediction, ans, return_recall=return_recall),
        answers_list)


def normalize_text(s):
    """Normalize ``s`` for comparison.

    Steps, in order: Unicode NFD normalization, lower-casing, removal of
    every character in ``string.punctuation``, replacement of the standalone
    words "a", "an" and "the" with a space, and collapsing of all whitespace
    runs into single spaces (which also strips both ends).
    """
    s = unicodedata.normalize('NFD', s)
    s = s.lower().translate(_PUNCTUATION_TABLE)
    s = _ARTICLES_RE.sub(' ', s)
    return ' '.join(s.split())


def em_score(prediction, ground_truth):
    """Return True if both strings are identical after ``normalize_text``."""
    return normalize_text(prediction) == normalize_text(ground_truth)


# See: https://github.com/hotpotqa/hotpot/blob/master/hotpot_evaluate_v1.py
# See: https://rajpurkar.github.io/SQuAD-explorer/ under Evaluation Script
# See: QReCC's


def f1_score(prediction, ground_truth):
    """Return the token-level F1 between ``prediction`` and ``ground_truth``.

    Returns 0 when the normalized strings share no token, including the case
    where both are empty after normalization (a diagnostic is printed then).
    """
    prediction_tokens = normalize_text(prediction).split()
    ground_truth_tokens = normalize_text(ground_truth).split()

    _report_if_both_empty("F1", prediction_tokens, ground_truth_tokens)

    return _overlap_stats(prediction_tokens, ground_truth_tokens)[2]


def hotpot_f1_score(prediction, ground_truth):
    """Return token-level F1, with exact-match-only handling of special answers.

    If either normalized string is "yes", "no" or "noanswer" and the two
    normalized strings differ, the score is 0 regardless of token overlap.
    """
    normalized_prediction = normalize_text(prediction)
    normalized_ground_truth = normalize_text(ground_truth)

    if normalized_prediction != normalized_ground_truth and (
            normalized_prediction in _HOTPOT_SPECIAL_ANSWERS
            or normalized_ground_truth in _HOTPOT_SPECIAL_ANSWERS):
        return 0

    return _overlap_stats(
        normalized_prediction.split(), normalized_ground_truth.split())[2]


def precision_score(prediction, ground_truth):
    """Return the token-level precision of ``prediction`` w.r.t. ``ground_truth``.

    Precision is the shared-token count divided by the number of prediction
    tokens. Returns 0 when no token is shared, including the case where both
    strings are empty after normalization (a diagnostic is printed then).
    """
    prediction_tokens = normalize_text(prediction).split()
    ground_truth_tokens = normalize_text(ground_truth).split()

    # The label used to read "F1" here, a copy-paste slip from f1_score.
    _report_if_both_empty("Precision", prediction_tokens, ground_truth_tokens)

    return _overlap_stats(prediction_tokens, ground_truth_tokens)[0]


# Source: https://gist.github.com/sebleier/554280
stopwords = ["i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself",
             "yourselves", "he", "him", "his", "himself", "she", "her", "hers", "herself", "it", "its", "itself",
             "they", "them", "their", "theirs", "themselves", "what", "which", "who", "whom", "this", "that",
             "these", "those", "am", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had",
             "having", "do", "does", "did", "doing", "a", "an", "the", "and", "but", "if", "or", "because", "as",
             "until", "while", "of", "at", "by", "for", "with", "about", "against", "between", "into", "through",
             "during", "before", "after", "above", "below", "to", "from", "up", "down", "in", "out", "on", "off",
             "over", "under", "again", "further", "then", "once", "here", "there", "when", "where", "why", "how",
             "all", "any", "both", "each", "few", "more", "most", "other", "some", "such", "no", "nor", "not",
             "only", "own", "same", "so", "than", "too", "very", "s", "t", "can", "will", "just", "don", "should",
             "now"]


def novel_f1_score(history, prediction, ground_truth, return_recall=False):
    """Return token-level F1 (or recall) over tokens not seen in ``history``.

    Tokens that occur in the normalized ``history`` or in ``stopwords`` are
    dropped from both the prediction and the ground truth before the overlap
    is computed.

    Args:
        history: text whose tokens are excluded from scoring.
        prediction: predicted answer string.
        ground_truth: reference answer string.
        return_recall: if True, return recall instead of F1.

    Returns:
        The F1 (or recall) of the remaining tokens, or 0 if none are shared.
    """
    # stopwords is read at call time so edits to the public list are honoured.
    excluded_tokens = set(normalize_text(history).split())
    excluded_tokens.update(stopwords)

    prediction_tokens = [
        t for t in normalize_text(prediction).split() if t not in excluded_tokens]
    ground_truth_tokens = [
        t for t in normalize_text(ground_truth).split() if t not in excluded_tokens]

    _, recall, f1 = _overlap_stats(prediction_tokens, ground_truth_tokens)

    return recall if return_recall else f1