Upload metrics.py with huggingface_hub

metrics.py  CHANGED  (+356, -49 lines)

@@ -1,4 +1,3 @@
-import logging
 import re
 import string
 import uuid
@@ -14,6 +13,7 @@ from scipy.stats import bootstrap
 
 from .artifact import Artifact
 from .dataclass import InternalField, OptionalField
+from .logging_utils import get_logger
 from .operator import (
     MultiStreamOperator,
     SingleStreamOperator,
@@ -23,7 +23,9 @@ from .operator import (
 from .operators import CopyFields
 from .random_utils import get_seed
 from .stream import MultiStream, Stream
+from .type_utils import isoftype
 
+logger = get_logger()
 # The default number of resamples used to estimate the confidence intervals
 # global and instances metrics. Use None to disable confidence interval computation by default.
 _N_RESAMPLES_DEFAULT_FOR_INSTANCE_METRICS = 1000
@@ -61,6 +63,7 @@ class MetricWithConfidenceInterval(Metric):
     # Use None to disable confidence interval computation.
     n_resamples: int = None
     confidence_level: float = 0.95
+    ci_scores: List[str] = None
 
     @staticmethod
     def new_random_generator():
@@ -79,7 +82,7 @@ class MetricWithConfidenceInterval(Metric):
             and num_predictions > 1
         )
 
-    def score_based_confidence_interval(self,
+    def score_based_confidence_interval(self, instances):
         """Compute confidence intervals based on existing scores, already computed on the input instances.
 
         score_names: List[str]
@@ -94,6 +97,10 @@ class MetricWithConfidenceInterval(Metric):
         if not self._can_compute_confidence_intervals(num_predictions=len(instances)):
            return result
 
+        score_names = (
+            self.ci_scores if self.ci_scores is not None else [self.main_score]
+        )
+
        for score_name in score_names:
            scores = [
                instance["score"]["instance"][score_name] for instance in instances
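
Note (not part of the patch): the new ci_scores attribute lets a metric request bootstrap confidence intervals for several of its per-instance scores instead of only main_score, which is the fallback when ci_scores is left as None. A minimal illustrative sketch of a subclass using it, mirroring how TokenOverlap and BertScore are configured later in this diff (the class name here is hypothetical):

class MyOverlapMetric(InstanceMetric):
    main_score = "f1"
    reduction_map = {"mean": ["f1", "precision", "recall"]}
    ci_scores = ["f1", "precision", "recall"]  # confidence intervals for all three scores
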
@@ -131,7 +138,7 @@ class MetricWithConfidenceInterval(Metric):
             except Exception as e:
                 # this happens in edge cases, for example, when the sampling creates a
                 # sample where all strings are empty and this fails bleu.
-
+                logger.info(f"Warning in {self.__class__.__name__}: {e}")
                 return np.nan
 
         scores = numpy.apply_along_axis(
@@ -341,7 +348,7 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
         global_score["score_name"] = self.main_score
 
         confidence_interval = self.score_based_confidence_interval(
-
+            instances=instances
         )
         global_score.update(confidence_interval)
 
@@ -411,7 +418,7 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
         global_score["score_name"] = self.main_score
 
         confidence_interval = self.score_based_confidence_interval(
-
+            instances=instances
         )
         global_score.update(confidence_interval)
 
@@ -473,6 +480,23 @@ class Accuracy(InstanceMetric):
         return result
 
 
+class StringContainment(InstanceMetric):
+    reduction_map = {"mean": ["string_containment"]}
+    main_score = "string_containment"
+
+    def compute(
+        self, references: List[Any], prediction: Any, additional_inputs: List[Dict]
+    ) -> dict:
+        result = {
+            self.main_score: float(
+                any(str(reference) in prediction for reference in references)
+            )
+        }
+        result["score"] = result[self.main_score]
+        result["score_name"] = self.main_score
+        return result
+
+
 class MetricPipeline(MultiStreamOperator, Metric):
     main_score: str = None
     preprocess_steps: Optional[List[StreamingOperator]] = field(default_factory=list)
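
Note (not part of the patch): StringContainment scores an instance 1.0 when any reference string appears verbatim inside the prediction, and 0.0 otherwise. A minimal sketch of the per-instance computation, calling compute directly instead of going through the streaming pipeline, and assuming the class can be instantiated with no arguments:

metric = StringContainment()
print(metric.compute(references=["world", "planet"], prediction="hello world", additional_inputs=[]))
# -> {"string_containment": 1.0, "score": 1.0, "score_name": "string_containment"}
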
@@ -512,9 +536,29 @@ class HuggingfaceMetric(GlobalMetric):
 
     scale: float = 1.0  # optional scaling of main results
     scaled_fields: list = None
+    # These are fixed arguments passed to the compute method
     hf_compute_args: Dict[str, Any] = OptionalField(default_factory=dict)
+    # These are additional input fields passed to the HF compute method (a list with one value per instance)
+    hf_additional_input_fields: List = OptionalField(default_factory=list)
+    # These are additional input fields that are passed as one value
+    hf_additional_input_fields_pass_one_value: List = OptionalField(
+        default_factory=list
+    )
+
     experiment_id: str = OptionalField(default_factory=lambda: str(uuid.uuid4()))
 
+    def verify(self):
+        assert (
+            self.hf_additional_input_fields is None
+            or isoftype(self.hf_additional_input_fields, List[str])
+        ), f"Argument hf_additional_input_fields should be either None or List[str]. It is now: {self.hf_additional_input_fields}."
+        assert (
+            self.hf_additional_input_fields_pass_one_value is None
+            or isoftype(self.hf_additional_input_fields_pass_one_value, List[str])
+        ), f"Argument hf_additional_input_fields_pass_one_value should be either None or List[str]. It is now: {self.hf_additional_input_fields_pass_one_value}."
+
+        return super().verify()
+
     def prepare(self):
         super().prepare()
         self.metric = evaluate.load(
@@ -527,8 +571,36 @@ class HuggingfaceMetric(GlobalMetric):
         predictions: List[Any],
         additional_inputs: List[Dict],
     ) -> dict:
+        passed_additional_inputs = {}
+        for additional_input_field in self.hf_additional_input_fields:
+            assert (
+                additional_input_field in additional_inputs[0]
+            ), f"'{additional_input_field}' field required by {__class__.__name__} is not in the passed additional inputs: {additional_inputs[0]}"
+            passed_additional_inputs[additional_input_field] = [
+                additional_input[additional_input_field]
+                for additional_input in additional_inputs
+            ]
+        for additional_input_field in self.hf_additional_input_fields_pass_one_value:
+            assert (
+                additional_input_field in additional_inputs[0]
+            ), f"'{additional_input_field}' field required by {__class__.__name__} is not in the passed additional inputs: {additional_inputs[0]}"
+
+            values = {
+                additional_input[additional_input_field]
+                for additional_input in additional_inputs
+            }
+            assert (
+                len(values) == 1
+            ), f"Values of '{additional_input_field}' field required by {__class__.__name__} should all be the same, but have multiple values {values}"
+
+            passed_additional_inputs[additional_input_field] = next(iter(values))
+
+        # TODO: add a check that all fields required by self.metric are in passed_additional_inputs
         result = self.metric.compute(
-            predictions=predictions,
+            predictions=predictions,
+            references=references,
+            **passed_additional_inputs,
+            **self.hf_compute_args,
         )
         if self.hf_main_score:
             result[self.main_score] = result[self.hf_main_score]
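
Note (not part of the patch): the two new attributes control how per-instance metadata reaches the wrapped HF metric. hf_additional_input_fields forwards a list with one value per instance, while hf_additional_input_fields_pass_one_value forwards a single shared value. A hedged sketch of the resulting call; the field names are illustrative, not taken from the patch:

# Assume hf_additional_input_fields=["lang"] and
# hf_additional_input_fields_pass_one_value=["dataset_name"].
additional_inputs = [
    {"lang": "en", "dataset_name": "d1"},
    {"lang": "fr", "dataset_name": "d1"},
]
# compute() then effectively calls:
# self.metric.compute(
#     predictions=predictions,
#     references=references,
#     lang=["en", "fr"],   # one value per instance
#     dataset_name="d1",   # must be identical across instances, else the assert fires
#     **self.hf_compute_args,
# )
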
@@ -559,6 +631,7 @@ class HuggingfaceBulkMetric(BulkInstanceMetric):
 
     hf_metric_fields: List[str]
     hf_compute_args: dict = {}
+    hf_additional_input_fields: List = OptionalField(default_factory=list)
 
     def prepare(self):
         super().prepare()
@@ -570,8 +643,23 @@ class HuggingfaceBulkMetric(BulkInstanceMetric):
         predictions: List[str],
         additional_inputs: List[Any],
     ) -> List[Dict[str, Any]]:
+        passed_additional_inputs = {}
+        passed_additional_inputs = {}
+        for additional_input_field in self.hf_additional_input_fields:
+            assert (
+                additional_input_field in additional_inputs[0]
+            ), f"'{additional_input_field}' field required by {__class__.__name__} is not in the passed additional inputs: {additional_inputs[0]}"
+            passed_additional_inputs[additional_input_field] = [
+                additional_input[additional_input_field]
+                for additional_input in additional_inputs
+            ]
+        # TODO: add a check that all fields required by self.metric are in passed_additional_inputs
+
         scores = self.metric.compute(
-            predictions=predictions,
+            predictions=predictions,
+            references=references,
+            **passed_additional_inputs,
+            **self.hf_compute_args,
         )
 
         # convert dict of lists to a list of dicts
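
Note (not part of the patch): the bulk metric receives one dict of score lists from the HF metric and, as the trailing comment says, reshapes it into one dict per instance. A small illustration of that conversion with made-up values:

scores = {"f1": [0.7, 0.9], "precision": [0.8, 1.0]}
# after the conversion:
# [{"f1": 0.7, "precision": 0.8}, {"f1": 0.9, "precision": 1.0}]
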
@@ -656,10 +744,11 @@ class F1MultiLabel(GlobalMetric):
     main_score = "f1_macro"
     average = None  # Report per class then aggregate by mean
     classes_to_ignore = ["none"]
+    metric = "f1"
 
     def prepare(self):
         super().prepare()
-        self._metric = evaluate.load(
+        self._metric = evaluate.load(self.metric, "multilabel")
 
     def add_str_to_id(self, str):
         if str not in self.str_to_id:
@@ -683,22 +772,10 @@ class F1MultiLabel(GlobalMetric):
     ) -> dict:
         self.str_to_id = {}
         self.id_to_str = {}
-        assert all(
-            len(reference) == 1 for reference in references
-        ), "Only a single reference per prediction is allowed in F1 multi label metric"
 
+        self._validate_references_and_prediction(references, predictions)
         references = [reference[0] for reference in references]
 
-        for reference in references:
-            assert isinstance(
-                references, list
-            ), f"Each reference is expected to list of strings in F1 multi label metric. Received reference: {reference}"
-
-        for prediction in predictions:
-            assert isinstance(
-                prediction, list
-            ), f"Each prediction is expected to list of strings in F1 multi label metric. Received prediction: {prediction}"
-
         labels = [
             lbl
             for lbl in {label for reference in references for label in reference}
@@ -732,19 +809,60 @@ class F1MultiLabel(GlobalMetric):
             average=self.average,
             labels=labels_param,
         )
-        if isinstance(result[
+        if isinstance(result[self.metric], numpy.ndarray):
             from statistics import mean
 
-            assert
-                labels
-            ), f
-            final_result = {self.main_score: mean(result[
+            assert (
+                len(result[self.metric]) == len(labels)
+            ), f"F1 result ({result[self.metric]}) has more entries than labels ({labels})"
+            final_result = {self.main_score: mean(result[self.metric])}
             for i, label in enumerate(labels):
-                final_result["
+                final_result[self.metric + "_" + label] = result[self.metric][i]
         else:
-            final_result = {self.main_score: result[
+            final_result = {self.main_score: result[self.metric]}
         return final_result
 
+    def _validate_references_and_prediction(self, references, predictions):
+        for reference in references:
+            if not len(reference) == 1:
+                raise ValueError(
+                    f"Only a single reference per prediction is allowed in F1 multi label metric. Received reference: {reference}"
+                )
+            if not isoftype(reference[0], List[str]):
+                raise ValueError(
+                    f"Each reference is expected to be a list of strings in F1 multi label metric. Received reference: '{reference[0]}'"
+                )
+
+        for prediction in predictions:
+            if not isoftype(prediction, List[str]):
+                raise ValueError(
+                    f"Each prediction is expected to be a list of strings in F1 multi label metric. Received prediction: '{prediction}'"
+                )
+
+
+class PrecisionMacroMultiLabel(F1MultiLabel):
+    main_score = "precision_macro"
+    metric = "precision"
+    average = "macro"
+
+
+class PrecisionMicroMultiLabel(F1MultiLabel):
+    main_score = "precision_micro"
+    metric = "precision"
+    average = "micro"
+
+
+class RecallMacroMultiLabel(F1MultiLabel):
+    main_score = "recall_macro"
+    metric = "recall"
+    average = "macro"
+
+
+class RecallMicroMultiLabel(F1MultiLabel):
+    main_score = "recall_micro"
+    metric = "recall"
+    average = "micro"
+
 
 class F1MicroMultiLabel(F1MultiLabel):
     main_score = "f1_micro"
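
Note (not part of the patch): F1MultiLabel now reads its score from the configurable metric field, which is what lets the new Precision*/Recall* subclasses reuse the same compute path. A hedged sketch of the expected input shapes and output keys (label names are illustrative):

# Each prediction is a list of label strings; each reference is a single-element
# list wrapping such a list (enforced by _validate_references_and_prediction).
predictions = [["cat", "dog"], ["dog"]]
references = [[["cat", "dog"]], [["bird"]]]
# With the default average=None the result holds f1_macro plus one entry per observed
# label, e.g. f1_cat, f1_dog, f1_bird; PrecisionMicroMultiLabel instead reports a single
# precision_micro value, computed with metric="precision" and average="micro".
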
@@ -868,27 +986,36 @@ class MatthewsCorrelation(HuggingfaceMetric):
 
 class CustomF1(GlobalMetric):
     main_score = "f1_micro"
-
+    groups = None
     zero_division = 0.0
 
     @abstractmethod
-    def get_element_group(self, element):
+    def get_element_group(self, element, additional_input):
         pass
 
     @abstractmethod
-    def get_element_representation(self, element):
+    def get_element_representation(self, element, additional_input):
         pass
 
-    def
+    def should_ignore_element(self, element, additional_input):
+        return False
+
+    def group_elements(self, elements_list, additional_input):
+        if not isinstance(elements_list, list):
+            elements_list = [elements_list]
         return {
             k: Counter(
                 [
-                    self.get_element_representation(value)
+                    self.get_element_representation(value, additional_input)
                     for value in elements_list
-                    if self.get_element_group(value) == k
+                    if self.get_element_group(value, additional_input) == k
                 ]
             )
-            for k in {
+            for k in {
+                self.get_element_group(e, additional_input)
+                for e in elements_list
+                if not self.should_ignore_element(e, additional_input)
+            }
         }
 
     def calculate_groups_ratio(self, actual_group, total_group):
@@ -910,30 +1037,46 @@ class CustomF1(GlobalMetric):
         except ZeroDivisionError:
             return self.zero_division
 
+    def get_groups(self, elements, additional_inputs):
+        groups = set()
+        for sublist, additional_input in zip(elements, additional_inputs):
+            for e in sublist:
+                if self.should_ignore_element(e, additional_input):
+                    continue
+                groups.add(self.get_element_group(e, additional_input))
+        return groups
+
     def compute(
         self,
-        references: List[Any],
+        references: List[List[Any]],
         predictions: List[Any],
         additional_inputs: List[Dict],
     ) -> dict:
         # in case references are List[List[List[Any]]] and predictions are List[List[Any]]:
-        if
+        if (
+            isinstance(references[0], list)
+            and len(references[0]) > 0
+            and isinstance(references[0][0], list)
+        ):
             references = [element[0] for element in references]
 
         assert len(references) == len(predictions), (
             f"references size ({len(references)})"
             f" doesn't match predictions size ({len(predictions)})."
         )
-
-
-
-        }
+
+        if self.groups is None:
+            groups = self.get_groups(references, additional_inputs)
         else:
-
+            groups = self.groups
         groups_statistics = {}
-        for references_batch, predictions_batch in zip(
-
-
+        for references_batch, predictions_batch, additional_input in zip(
+            references, predictions, additional_inputs
+        ):
+            grouped_references = self.group_elements(references_batch, additional_input)
+            grouped_predictions = self.group_elements(
+                predictions_batch, additional_input
+            )
             all_groups = set(grouped_references.keys()).union(
                 grouped_predictions.keys()
             )
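
Note (not part of the patch): CustomF1 now threads additional_input through grouping, so subclasses such as NER and KPA can decide an element's group per instance. A minimal sketch of group_elements with NER-style (span, type) elements; additional_input is unused by NER, so an empty dict is passed, and it is assumed the class can be instantiated directly for illustration:

ner = NER()
grouped = ner.group_elements(
    [("Paris", "LOC"), ("IBM", "ORG"), ("France", "LOC")], additional_input={}
)
# {"LOC": Counter({"('Paris', 'LOC')": 1, "('France', 'LOC')": 1}),
#  "ORG": Counter({"('IBM', 'ORG')": 1})}
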
@@ -976,7 +1119,7 @@ class CustomF1(GlobalMetric):
                 rn_total + rn,
                 rd_total + rd,
             )
-            if group in
+            if group in groups:
                 f1_result[f"f1_{group}"] = self.f1(pn, pd, rn, rd)
                 recall_result[f"recall_{group}"] = self.recall(pn, pd, rn, rd)
                 precision_result[f"precision_{group}"] = self.precision(pn, pd, rn, rd)
@@ -995,7 +1138,7 @@ class CustomF1(GlobalMetric):
         except ZeroDivisionError:
             result["f1_macro"] = self.zero_division
             result["recall_macro"] = self.zero_division
-            result["
+            result["precision_macro"] = self.zero_division
 
         amount_of_predictions = pd_total
         if amount_of_predictions == 0:
@@ -1013,10 +1156,10 @@ class CustomF1(GlobalMetric):
 
 
 class NER(CustomF1):
-    def get_element_group(self, element):
+    def get_element_group(self, element, additional_input):
         return element[1]
 
-    def get_element_representation(self, element):
+    def get_element_representation(self, element, additional_input):
         return str(element)
 
 
@@ -1042,6 +1185,7 @@ def normalize_answer(s):
 class TokenOverlap(InstanceMetric):
     reduction_map = {"mean": ["f1", "precision", "recall"]}
     main_score = "f1"
+    ci_scores = ["f1", "precision", "recall"]
 
     def compute(
         self, references: List[Any], prediction: Any, additional_inputs: List[Dict]
@@ -1075,6 +1219,7 @@ class BertScore(HuggingfaceBulkMetric):
     main_score = "f1"
     reduction_map = {"mean": ["f1", "precision", "recall"]}
     hf_metric_fields = ["f1", "precision", "recall"]
+    ci_scores = ["f1", "precision", "recall"]
     model_name: str
 
     def prepare(self):
@@ -1223,3 +1368,165 @@ class NDCG(GlobalMetric):
             ]
             scores.append(self.eval([q_references], [q_predictions]))
         return {self.main_score: mean(scores) if len(scores) > 0 else np.nan}
+
+
+class RetrievalMetric(InstanceMetric):
+    def compute(
+        self, references: List[Any], prediction: Any, additional_inputs: Dict
+    ) -> dict:
+        # digest input
+        pred_ids: List[Any] = prediction
+        ref_ids: List[Any] = list(dict.fromkeys(references))
+
+        # relevance_at_k: 1-based dictionary of indicators (0/1), telling whether
+        # the doc id retrieved at position k (assuming it is 1-based, so k starts
+        # from 1) is in the gold doc ids or not.
+        # For example, assuming that in the retrieved docs we have correct predictions
+        # at positions 2, 4 and 5 (1-based), the dict will look like:
+        # {1: 0, 2: 1, 3: 0, 4: 1, 5: 1, ...}
+        relevance_at_k = {
+            k + 1: 1 if doc_id in ref_ids else 0 for k, doc_id in enumerate(pred_ids)
+        }
+
+        # relevance_sum_at_k: 1-based dictionary of counts, where the value at k determines
+        # how many gold doc ids have been observed up to index k.
+        relevance_sum_at_k = {}
+        for k, value in relevance_at_k.items():
+            relevance_sum_at_k[k] = relevance_sum_at_k.get(k - 1, 0) + value
+
+        # precision_at_k: the precision of the top k retrieved documents. For example,
+        # assuming that only 1 out of the first 4 retrieved documents is correct, the
+        # value at 4 will be 1/4.
+        precision_at_k = {k: value / k for k, value in relevance_sum_at_k.items()}
+
+        # recall_at_k: the recall of the top k retrieved documents. For example,
+        # assuming that only 2 out of the 3 gold documents are in the top 5 results,
+        # the value at 5 will be 2/3.
+        n_refs = len(ref_ids)
+        recall_at_k = {
+            k: value / n_refs if n_refs > 0 else 0
+            for k, value in relevance_sum_at_k.items()
+        }
+
+        # rank - the 1-based index of the first hit of a gold doc id. So 1
+        # means first position.
+        rank = 0
+        for k, relevance in relevance_at_k.items():
+            if relevance == 1:
+                rank = k
+                break
+
+        # match_at_k: whether we have a match at the top k retrieved documents
+        match_at_k = {
+            k: 1.0 if value > 0 else 0.0 for k, value in relevance_sum_at_k.items()
+        }
+
+        return self._compute(
+            relevance_at_k,
+            relevance_sum_at_k,
+            precision_at_k,
+            recall_at_k,
+            match_at_k,
+            rank,
+        )
+
+    @abstractmethod
+    def _compute(
+        self,
+        relevance_at_k,
+        relevance_sum_at_k,
+        precision_at_k,
+        recall_at_k,
+        match_at_k,
+        rank,
+    ) -> dict:
+        pass
+
+
+class MRR(RetrievalMetric):
+    reduction_map = {"mean": ["mrr"]}
+    main_score = "mrr"
+
+    def _compute(
+        self,
+        relevance_at_k,
+        relevance_sum_at_k,
+        precision_at_k,
+        recall_at_k,
+        match_at_k,
+        rank,
+    ) -> dict:
+        return {self.main_score: 1 / rank if rank > 0 else 0}
+
+
+class MAP(RetrievalMetric):
+    reduction_map = {"mean": ["map"]}
+    main_score = "map"
+
+    def _compute(
+        self,
+        relevance_at_k,
+        relevance_sum_at_k,
+        precision_at_k,
+        recall_at_k,
+        match_at_k,
+        rank,
+    ) -> dict:
+        result = 0
+        if len(relevance_at_k) > 0:
+            total = sum(relevance_at_k.values())
+            if total > 0:
+                dot = sum(relevance_at_k[k] * precision_at_k[k] for k in relevance_at_k)
+                result = dot / total
+        return {self.main_score: result}
+
+
+class RetrievalAtK(RetrievalMetric):
+    k_list: List[int]
+    main_score: str = None
+    reduction_map: Dict[str, List[str]] = None
+
+    def prepare(self):
+        super().prepare()
+        self.main_score = self.score_name("match", self.k_list[0])
+        self.ci_scores = [
+            self.score_name(measure, k)
+            for measure in ["precision", "recall", "match"]
+            for k in self.k_list
+        ]
+        self.reduction_map = {"mean": self.ci_scores}
+
+    @staticmethod
+    def score_name(measure: str, k: int):
+        return f"{measure}_at_{k}"
+
+    def _compute(
+        self,
+        relevance_at_k,
+        relevance_sum_at_k,
+        precision_at_k,
+        recall_at_k,
+        match_at_k,
+        rank,
+    ) -> dict:
+        result = {}
+        for measure_array, measure_name in [
+            (precision_at_k, "precision"),
+            (recall_at_k, "recall"),
+            (match_at_k, "match"),
+        ]:
+            max_k = max(measure_array.keys())
+            for k in self.k_list:
+                result[self.score_name(measure_name, k)] = measure_array[min(k, max_k)]
+        return result
+
+
+class KPA(CustomF1):
+    def get_element_group(self, element, additional_input):
+        return additional_input["keypoint"]
+
+    def get_element_representation(self, element, additional_input):
+        return additional_input["keypoint"]
+
+    def should_ignore_element(self, element, additional_input):
+        return element == "none"