Upload metrics.py with huggingface_hub
Browse files- metrics.py +18 -3
metrics.py
CHANGED
|
@@ -220,13 +220,14 @@ class HuggingfaceMetric(GlobalMetric):
|
|
| 220 |
metric_name: str = None
|
| 221 |
main_score: str = None
|
| 222 |
scale: float = 1.0
|
|
|
|
| 223 |
|
| 224 |
def prepare(self):
|
| 225 |
super().prepare()
|
| 226 |
self.metric = evaluate.load(self.metric_name)
|
| 227 |
|
| 228 |
def compute(self, references: List[List[str]], predictions: List[str]) -> dict:
|
| 229 |
-
result = self.metric.compute(predictions=predictions, references=references)
|
| 230 |
if self.scale != 1.0:
|
| 231 |
for key in result:
|
| 232 |
if isinstance(result[key], float):
|
|
@@ -373,7 +374,14 @@ class Rouge(HuggingfaceMetric):
|
|
| 373 |
main_score = "rougeL"
|
| 374 |
scale = 1.0
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
def prepare(self):
|
|
|
|
|
|
|
| 377 |
super().prepare()
|
| 378 |
import nltk
|
| 379 |
|
|
@@ -381,8 +389,9 @@ class Rouge(HuggingfaceMetric):
|
|
| 381 |
self.sent_tokenize = nltk.sent_tokenize
|
| 382 |
|
| 383 |
def compute(self, references, predictions):
|
| 384 |
-
|
| 385 |
-
|
|
|
|
| 386 |
return super().compute(references, predictions)
|
| 387 |
|
| 388 |
|
|
@@ -429,6 +438,12 @@ class Bleu(HuggingfaceMetric):
|
|
| 429 |
scale = 1.0
|
| 430 |
|
| 431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
class MatthewsCorrelation(HuggingfaceMetric):
|
| 433 |
metric_name = "matthews_correlation"
|
| 434 |
main_score = "matthews_correlation"
|
|
|
|
| 220 |
metric_name: str = None
|
| 221 |
main_score: str = None
|
| 222 |
scale: float = 1.0
|
| 223 |
+
hf_compute_args: dict = {}
|
| 224 |
|
| 225 |
def prepare(self):
|
| 226 |
super().prepare()
|
| 227 |
self.metric = evaluate.load(self.metric_name)
|
| 228 |
|
| 229 |
def compute(self, references: List[List[str]], predictions: List[str]) -> dict:
|
| 230 |
+
result = self.metric.compute(predictions=predictions, references=references, **self.hf_compute_args)
|
| 231 |
if self.scale != 1.0:
|
| 232 |
for key in result:
|
| 233 |
if isinstance(result[key], float):
|
|
|
|
| 374 |
main_score = "rougeL"
|
| 375 |
scale = 1.0
|
| 376 |
|
| 377 |
+
use_aggregator: bool = True
|
| 378 |
+
rouge_types: List[str] = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
|
| 379 |
+
|
| 380 |
+
sent_split_newline: bool = True
|
| 381 |
+
|
| 382 |
def prepare(self):
|
| 383 |
+
self.hf_compute_args = {"use_aggregator": self.use_aggregator, "rouge_types": self.rouge_types}
|
| 384 |
+
|
| 385 |
super().prepare()
|
| 386 |
import nltk
|
| 387 |
|
|
|
|
| 389 |
self.sent_tokenize = nltk.sent_tokenize
|
| 390 |
|
| 391 |
def compute(self, references, predictions):
|
| 392 |
+
if self.sent_split_newline:
|
| 393 |
+
predictions = ["\n".join(self.sent_tokenize(prediction.strip())) for prediction in predictions]
|
| 394 |
+
references = [["\n".join(self.sent_tokenize(r.strip())) for r in reference] for reference in references]
|
| 395 |
return super().compute(references, predictions)
|
| 396 |
|
| 397 |
|
|
|
|
| 438 |
scale = 1.0
|
| 439 |
|
| 440 |
|
| 441 |
+
class SacreBleu(HuggingfaceMetric):
|
| 442 |
+
metric_name = "sacrebleu"
|
| 443 |
+
main_score = "score"
|
| 444 |
+
scale = 1.0
|
| 445 |
+
|
| 446 |
+
|
| 447 |
class MatthewsCorrelation(HuggingfaceMetric):
|
| 448 |
metric_name = "matthews_correlation"
|
| 449 |
main_score = "matthews_correlation"
|