import logging
from collections import defaultdict
from functools import partial
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    Hashable,
    List,
    Optional,
    Tuple,
    TypeAlias,
    Union,
)

from pytorch_ie.core import Annotation, Document, DocumentMetric
from pytorch_ie.utils.hydra import resolve_target

from src.document.types import RelatedRelation

logger = logging.getLogger(__name__)


def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
    return getattr(ann, label_field) in labels


def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
    return getattr(ann, label_field) == label
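
# An InstanceType pairs the source document with one (possibly processed) annotation;
# an InstancesType groups the (true positive, false positive, false negative) instance lists.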
InstanceType: TypeAlias = Tuple[Document, Annotation]
InstancesType: TypeAlias = Tuple[List[InstanceType], List[InstanceType], List[InstanceType]]


class TPFFPFNMetric(DocumentMetric):
    """Computes the lists of true positive, false positive, and false negative
    annotations for a given layer. If labels are provided, the instances are also
    collected for each label separately.

    Works only with `RelatedRelation` annotations for now.

    Args:
        layer: The annotation layer to compute the metric for.
        labels: If provided, also collect instances per label. Either a collection of label
            strings or the string "INFERRED" to collect the labels from the processed documents.
        label_field: The annotation field that holds the label. Defaults to "label".
        annotation_processor: Optional callable (or a resolvable dotted path to one) that maps
            each annotation to a hashable key before gold and predicted annotations are compared.
    """
    def __init__(
        self,
        layer: str,
        labels: Optional[Union[Collection[str], str]] = None,
        label_field: str = "label",
        annotation_processor: Optional[Union[Callable[[Annotation], Hashable], str]] = None,
    ):
        super().__init__()
        self.layer = layer
        self.label_field = label_field
        self.annotation_processor: Optional[Callable[[Annotation], Hashable]]
        if isinstance(annotation_processor, str):
            # allow passing the processor as a dotted path, e.g. from a Hydra config
            self.annotation_processor = resolve_target(annotation_processor)
        else:
            self.annotation_processor = annotation_processor

        self.per_label = labels is not None
        self.infer_labels = False
        if self.per_label:
            if isinstance(labels, str):
                if labels != "INFERRED":
                    raise ValueError(
                        "labels can only be 'INFERRED' if per_label is True and labels is a string"
                    )
                # labels will be collected from the documents during _update()
                self.labels = []
                self.infer_labels = True
            elif isinstance(labels, Collection):
                if not all(isinstance(label, str) for label in labels):
                    raise ValueError("labels must be a collection of strings")
                if "MICRO" in labels or "MACRO" in labels:
                    raise ValueError(
                        "labels cannot contain 'MICRO' or 'MACRO' because they are used to capture aggregated metrics"
                    )
                if len(labels) == 0:
                    raise ValueError("labels cannot be empty")
                self.labels = list(labels)
            else:
                raise ValueError("labels must be a string or a collection of strings")

    def reset(self):
        self.tp_fp_fn = defaultdict(lambda: (list(), list(), list()))

    def get_tp_fp_fn(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> InstancesType:
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)
        # matching is done via set operations, so the (processed) annotations need to be
        # hashable and comparable across the gold and predicted layers
        predicted_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer].predictions
            if annotation_filter(ann)
        }
        gold_annotations = {
            annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
        }
        tp = [(document, ann) for ann in predicted_annotations & gold_annotations]
        fn = [(document, ann) for ann in gold_annotations - predicted_annotations]
        fp = [(document, ann) for ann in predicted_annotations - gold_annotations]
        return tp, fp, fn

    def add_annotations(self, annotations: InstancesType, label: str):
        self.tp_fp_fn[label] = (
            self.tp_fp_fn[label][0] + annotations[0],
            self.tp_fp_fn[label][1] + annotations[1],
            self.tp_fp_fn[label][2] + annotations[2],
        )
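
    # _update() first collects instances under the aggregate "MICRO" key and then, if
    # per-label evaluation is enabled, re-runs the matching once per label with a
    # label-specific filter.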
    def _update(self, document: Document):
        new_tp_fp_fn = self.get_tp_fp_fn(
            document=document,
            annotation_filter=(
                partial(has_one_of_the_labels, label_field=self.label_field, labels=self.labels)
                if self.per_label and not self.infer_labels
                else None
            ),
            annotation_processor=self.annotation_processor,
        )
        self.add_annotations(new_tp_fp_fn, label="MICRO")
        if self.infer_labels:
            layer = document[self.layer]
            # collect labels from gold data and predictions
            for ann in list(layer) + list(layer.predictions):
                label = getattr(ann, self.label_field)
                if label not in self.labels:
                    self.labels.append(label)
        if self.per_label:
            for label in self.labels:
                new_tp_fp_fn = self.get_tp_fp_fn(
                    document=document,
                    annotation_filter=partial(
                        has_this_label, label_field=self.label_field, label=label
                    ),
                    annotation_processor=self.annotation_processor,
                )
                self.add_annotations(new_tp_fp_fn, label=label)
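
    # The formatting helpers below flatten (document, annotation) instances into plain
    # dictionaries so that _compute() can return easily serializable error lists per label.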
    def format_texts(self, texts: List[str]) -> str:
        return "<SEP>".join(texts)

    def format_annotation(self, ann: Annotation) -> Dict[str, Any]:
        if isinstance(ann, RelatedRelation):
            head_resolved = ann.head.resolve()
            tail_resolved = ann.tail.resolve()
            ref_resolved = ann.reference_span.resolve()
            return {
                "related_label": ann.label,
                "related_score": round(ann.score, 3),
                "query_label": head_resolved[0],
                "query_texts": self.format_texts(head_resolved[1]),
                "query_score": round(ann.head.score, 3),
                "ref_label": ref_resolved[0],
                "ref_texts": self.format_texts(ref_resolved[1]),
                "ref_score": round(ann.reference_span.score, 3),
                "rec_label": tail_resolved[0],
                "rec_texts": self.format_texts(tail_resolved[1]),
                "rec_score": round(ann.tail.score, 3),
            }
        else:
            raise NotImplementedError
            # return ann.resolve()

    def format_instance(self, instance: InstanceType) -> Dict[str, Any]:
        document, annotation = instance
        result = self.format_annotation(annotation)
        if getattr(document, "id", None) is not None:
            result["document_id"] = document.id
        return result

    def _compute(self) -> Dict[str, Dict[str, list]]:
        res = dict()
        for k, instances in self.tp_fp_fn.items():
            res[k] = {
                "tp": [self.format_instance(instance) for instance in instances[0]],
                "fp": [self.format_instance(instance) for instance in instances[1]],
                "fn": [self.format_instance(instance) for instance in instances[2]],
            }
        # if self.show_as_markdown:
        #     logger.info(f"\n{self.layer}:\n{pd.DataFrame(res).round(3).T.to_markdown()}")
        return res
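

# Usage sketch (illustrative only): assuming the usual pytorch-ie `DocumentMetric` flow,
# where calling the metric on a document dispatches to `_update()` and `compute()` returns
# the result of `_compute()`. The layer name below is a placeholder.
#
#   metric = TPFFPFNMetric(layer="related_relations", labels="INFERRED")
#   for doc in documents:  # documents carrying gold annotations and predictions
#       metric(doc)
#   errors = metric.compute()
#   errors["MICRO"]["fp"]  # list of dicts as produced by format_instance()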