Upload metrics.py with huggingface_hub
metrics.py CHANGED (+41 −2)
```diff
@@ -8,6 +8,7 @@ import evaluate
 import nltk
 import numpy
 
+from .dataclass import InternalField
 from .operator import (
     MultiStreamOperator,
     SingleStreamOperator,
@@ -60,7 +61,13 @@ class GlobalMetric(SingleStreamOperator, Metric):
 
             refs, pred = instance["references"], instance["prediction"]
 
-            instance_score = self._compute([refs], [pred])
+            try:
+                instance_score = self._compute([refs], [pred])
+            except:
+                instance_score = {"score": None}
+                if isinstance(self.main_score, str) and self.main_score is not None:
+                    instance_score[self.main_score] = None
+
             instance["score"]["instance"].update(instance_score)
 
             references.append(refs)
```
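The try/except added here makes per-instance scoring failure-tolerant: when `_compute` raises, the instance still receives a score entry, just with `None` values. A minimal standalone sketch of the same pattern, with hypothetical `safe_instance_score` and `strict_compute` names that are not part of metrics.py:

```python
# Sketch only: mirrors the failure-tolerant pattern above with made-up names.
def safe_instance_score(compute, refs, pred, main_score="f1"):
    try:
        return compute([refs], [pred])
    except Exception:
        # On failure, still emit a score dict so downstream updates don't break.
        score = {"score": None}
        if isinstance(main_score, str):
            score[main_score] = None
        return score

def strict_compute(references, predictions):
    # A scorer that rejects empty predictions, to trigger the fallback path.
    if any(p == "" for p in predictions):
        raise ValueError("empty prediction")
    return {"score": 1.0, "f1": 1.0}

print(safe_instance_score(strict_compute, ["a"], "a"))  # {'score': 1.0, 'f1': 1.0}
print(safe_instance_score(strict_compute, ["a"], ""))   # {'score': None, 'f1': None}
```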
```diff
@@ -355,8 +362,27 @@ class Bleu(HuggingfaceMetric):
     scale = 1.0
 
 
+class MatthewsCorrelation(HuggingfaceMetric):
+    metric_name = "matthews_correlation"
+    main_score = "matthews_correlation"
+    str_to_id: dict = InternalField(default_factory=dict)
+
+    def get_str_id(self, str):
+        if str not in self.str_to_id:
+            id = len(self.str_to_id)
+            self.str_to_id[str] = id
+        return self.str_to_id[str]
+
+    def compute(self, references: List[List[str]], predictions: List[str]) -> dict:
+        formatted_references = [self.get_str_id(reference[0]) for reference in references]
+        formatted_predictions = [self.get_str_id(prediction) for prediction in predictions]
+        result = self.metric.compute(predictions=formatted_predictions, references=formatted_references)
+        return result
+
+
 class CustomF1(GlobalMetric):
     main_score = "f1_micro"
+    classes = None
 
     @abstractmethod
     def get_element_group(self, element):
```
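`MatthewsCorrelation` assigns each distinct label string an integer id on first sight, since the underlying Hugging Face metric expects integer labels. A rough standalone sketch of that mapping against the `evaluate` library directly (assumes `evaluate` and its scikit-learn dependency are installed):

```python
import evaluate

metric = evaluate.load("matthews_correlation")

str_to_id = {}
def get_str_id(s):
    # Assign ids in order of first appearance; refs and preds share one mapping.
    if s not in str_to_id:
        str_to_id[s] = len(str_to_id)
    return str_to_id[s]

references = [["positive"], ["negative"], ["positive"], ["negative"]]
predictions = ["positive", "positive", "positive", "negative"]

ref_ids = [get_str_id(r[0]) for r in references]  # [0, 1, 0, 1]
pred_ids = [get_str_id(p) for p in predictions]   # [0, 0, 0, 1]
print(metric.compute(predictions=pred_ids, references=ref_ids))
# {'matthews_correlation': 0.577...}
```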
```diff
@@ -391,6 +417,10 @@ class CustomF1(GlobalMetric):
         assert len(references) == len(predictions), (
             f"references size ({len(references)})" f" doesn't match predictions size ({len(predictions)})."
         )
+        if self.classes is None:
+            classes = set([self.get_element_group(e) for sublist in references for e in sublist])
+        else:
+            classes = self.classes
         groups_statistics = dict()
         for references_batch, predictions_batch in zip(references, predictions):
             grouped_references = self.group_elements(references_batch)
```
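With `classes = None`, the set of reporting classes is derived from the references themselves via `get_element_group`. A standalone sketch of that derivation, using a hypothetical colon-prefix grouping rule that is not from metrics.py:

```python
# Hypothetical grouping rule: the entity type is the prefix before the colon.
def get_element_group(element):
    return element.split(":")[0]

references = [["PER: John", "LOC: Paris"], ["ORG: IBM"]]

# Same comprehension shape as the diff: flatten, then group each element.
classes = set(get_element_group(e) for sublist in references for e in sublist)
print(classes)  # {'PER', 'LOC', 'ORG'}
```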
```diff
@@ -418,6 +448,7 @@ class CustomF1(GlobalMetric):
                 groups_statistics[group]["recall_denominator"] += rd
 
         result = {}
+        num_of_unknown_class_predictions = 0
         pn_total = pd_total = rn_total = rd_total = 0
         for group in groups_statistics.keys():
             pn, pd, rn, rd = (
@@ -426,13 +457,21 @@ class CustomF1(GlobalMetric):
                 groups_statistics[group]["recall_numerator"],
                 groups_statistics[group]["recall_denominator"],
             )
-            result[f"f1_{group}"] = self.f1(pn, pd, rn, rd)
             pn_total, pd_total, rn_total, rd_total = pn_total + pn, pd_total + pd, rn_total + rn, rd_total + rd
+            if group in classes:
+                result[f"f1_{group}"] = self.f1(pn, pd, rn, rd)
+            else:
+                num_of_unknown_class_predictions += pd
         try:
             result["f1_macro"] = sum(result.values()) / len(result.keys())
         except ZeroDivisionError:
             result["f1_macro"] = 1.0
 
+        amount_of_predictions = pd_total
+        if amount_of_predictions == 0:
+            result["in_classes_support"] = 1.0
+        else:
+            result["in_classes_support"] = 1.0 - num_of_unknown_class_predictions / amount_of_predictions
         result[f"f1_micro"] = self.f1(pn_total, pd_total, rn_total, rd_total)
         return result
 
```
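The new `in_classes_support` score is the fraction of predictions whose group falls inside the known classes; predictions for unknown groups still count toward `pd_total` but get no per-class F1. A toy recomputation under assumed counts (the numbers here are made up):

```python
classes = {"PER", "LOC"}
# Assumed precision denominators: predictions made per group.
predictions_per_group = {"PER": 3, "LOC": 2, "MISC": 1}

pd_total = sum(predictions_per_group.values())  # 6 predictions overall
unknown = sum(pd for g, pd in predictions_per_group.items() if g not in classes)

in_classes_support = 1.0 if pd_total == 0 else 1.0 - unknown / pd_total
print(in_classes_support)  # 0.8333...: 5 of 6 predictions hit known classes
```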