File size: 1,407 Bytes
d868d2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from typing import Callable, Hashable, Optional, Tuple
from pie_modules.metrics import F1Metric
from pytorch_ie import Annotation, Document
class F1WithThresholdMetric(F1Metric):
    """F1 metric that only counts predictions whose score reaches a threshold.

    Behaves like ``F1Metric``, except that a *predicted* annotation is counted
    only if its ``score`` attribute is ``>= threshold``. Gold annotations are
    never filtered by score. They carry no prediction confidence, and dropping
    them would hide missed gold annotations instead of counting them as false
    negatives, which would inflate recall.
    """

    def __init__(self, *args, threshold: float = 0.0, **kwargs):
        """Create the metric.

        Args:
            *args: Positional arguments forwarded unchanged to ``F1Metric``.
            threshold: Minimum ``score`` a predicted annotation needs in order
                to be counted. Predictions without a ``score`` attribute are
                treated as having score ``0.0``.
            **kwargs: Keyword arguments forwarded unchanged to ``F1Metric``.
        """
        super().__init__(*args, **kwargs)
        self.threshold = threshold

    def calculate_counts(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> Tuple[int, int, int]:
        """Count true positives, false positives and false negatives for one document.

        Predictions are taken from ``document[self.layer].predictions`` and
        gold annotations from ``document[self.layer]`` itself.

        Args:
            document: The document to evaluate.
            annotation_filter: Optional predicate. Annotations for which it
                returns ``False`` are ignored, on both the gold and the
                predicted side. Defaults to keeping everything.
            annotation_processor: Optional mapping from an annotation to the
                hashable key used for matching. Defaults to the annotation
                itself. Annotations that map to equal keys are counted once,
                because matching is done on sets.

        Returns:
            A ``(tp, fp, fn)`` tuple.
        """
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)

        # Only predictions are thresholded. A missing ``score`` counts as 0.0,
        # so such predictions are kept only when the threshold is <= 0.
        predicted_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer].predictions
            if annotation_filter(ann) and getattr(ann, "score", 0.0) >= self.threshold
        }
        # Gold annotations are deliberately NOT filtered by score: see class docstring.
        gold_annotations = {
            annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
        }

        tp = len(predicted_annotations & gold_annotations)
        fn = len(gold_annotations - predicted_annotations)
        fp = len(predicted_annotations - gold_annotations)
        return tp, fp, fn
|