File size: 7,640 Bytes
d868d2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import logging
from collections import defaultdict
from functools import partial
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    Hashable,
    List,
    Optional,
    Tuple,
    TypeAlias,
    Union,
)

from pytorch_ie.core import Annotation, Document, DocumentMetric
from pytorch_ie.utils.hydra import resolve_target

from src.document.types import RelatedRelation

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
    """Tell whether the value of ``label_field`` on ``ann`` is contained in ``labels``.

    Args:
        ann: The annotation to inspect.
        label_field: Name of the attribute of ``ann`` that holds the label.
        labels: The accepted labels.

    Returns:
        The result of the membership test of the attribute value in ``labels``.
    """
    current_label = getattr(ann, label_field)
    return current_label in labels


def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
    """Tell whether the value of ``label_field`` on ``ann`` equals ``label``.

    Args:
        ann: The annotation to inspect.
        label_field: Name of the attribute of ``ann`` that holds the label.
        label: The label to compare against.

    Returns:
        The result of comparing the attribute value with ``label`` for equality.
    """
    current_label = getattr(ann, label_field)
    return current_label == label


# A single instance: a document paired with one (possibly processed) annotation from it.
InstanceType: TypeAlias = Tuple[Document, Annotation]
# Three lists of instances, in the order (true positives, false positives, false negatives).
InstancesType: TypeAlias = Tuple[List[InstanceType], List[InstanceType], List[InstanceType]]


class TPFFPFNMetric(DocumentMetric):
    """Collects the lists of True Positive, False Positive, and False Negative
    annotations for a given layer. If labels are provided, it also collects
    the lists for each label separately.

    The lists that are not split by label are stored under the key "MICRO".
    If an explicit collection of labels is given, only annotations with one of
    these labels contribute to "MICRO".

    Works only with `RelatedRelation` annotations for now: the computed result
    is built with `format_annotation`, which raises `NotImplementedError` for
    any other value (this includes values returned by `annotation_processor`
    that are not `RelatedRelation` instances).

    Args:
        layer: The layer to compute the metrics for.
        labels: If provided, calculate metrics for each label. Either a
            non-empty collection of strings (must not contain "MICRO" or
            "MACRO") or the string "INFERRED" to collect the labels from the
            gold and predicted annotations of the documents passed to the
            metric.
        label_field: The field to use for the label. Defaults to "label".
        annotation_processor: Optional callable that is applied to each
            annotation before gold and predicted annotations are compared.
            A string is resolved to a callable via `resolve_target`.

    Raises:
        ValueError: If `labels` is a string other than "INFERRED", is an empty
            collection, contains non-string entries, contains "MICRO" or
            "MACRO", or is neither a string nor a collection.
    """

    def __init__(
        self,
        layer: str,
        labels: Optional[Union[Collection[str], str]] = None,
        label_field: str = "label",
        annotation_processor: Optional[Union[Callable[[Annotation], Hashable], str]] = None,
    ):
        super().__init__()
        self.layer = layer
        self.label_field = label_field
        self.annotation_processor: Optional[Callable[[Annotation], Hashable]]
        if isinstance(annotation_processor, str):
            self.annotation_processor = resolve_target(annotation_processor)
        else:
            self.annotation_processor = annotation_processor

        self.per_label = labels is not None
        self.infer_labels = False
        # Always defined, so the attribute can be read even without per-label evaluation.
        self.labels: List[str] = []
        if self.per_label:
            # A str is itself a Collection, so it has to be checked first.
            if isinstance(labels, str):
                if labels != "INFERRED":
                    raise ValueError(
                        "labels can only be 'INFERRED' if per_label is True and labels is a string"
                    )
                self.infer_labels = True
            elif isinstance(labels, Collection):
                if not all(isinstance(label, str) for label in labels):
                    raise ValueError("labels must be a collection of strings")
                if "MICRO" in labels or "MACRO" in labels:
                    raise ValueError(
                        "labels cannot contain 'MICRO' or 'MACRO' because they are used to capture aggregated metrics"
                    )
                if not labels:
                    raise ValueError("labels cannot be empty")
                self.labels = list(labels)
            else:
                raise ValueError("labels must be a string or a collection of strings")

    def reset(self):
        """Drop all collected instances."""
        # Maps a label (or "MICRO") to its (tp, fp, fn) instance lists.
        self.tp_fp_fn: Dict[str, InstancesType] = defaultdict(lambda: ([], [], []))

    @staticmethod
    def _unique_processed(
        annotations: Any,
        annotation_filter: Callable[[Annotation], bool],
        annotation_processor: Callable[[Annotation], Hashable],
    ) -> Dict[Hashable, None]:
        """Filter and process `annotations`, de-duplicated in order of first appearance."""
        return dict.fromkeys(
            annotation_processor(ann) for ann in annotations if annotation_filter(ann)
        )

    def get_tp_fp_fn(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> InstancesType:
        """Split the annotations of `self.layer` of one document into TP, FP, and FN.

        Gold and predicted annotations are filtered, processed, and de-duplicated
        before they are compared via equality / hashing.

        Args:
            document: The document whose layer `self.layer` is evaluated.
            annotation_filter: If provided, only annotations for which it returns
                a true value are considered.
            annotation_processor: If provided, it is applied to each annotation
                that passed the filter; the results are compared instead of the
                annotations themselves.

        Returns:
            A tuple `(tp, fp, fn)` of lists of `(document, annotation)` pairs.
            `tp` and `fp` follow the order of the predictions, `fn` follows the
            order of the gold annotations.
        """
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)
        layer = document[self.layer]
        # Dicts are used as insertion-ordered sets to get a reproducible output order.
        predicted = self._unique_processed(
            layer.predictions, annotation_filter, annotation_processor
        )
        gold = self._unique_processed(layer, annotation_filter, annotation_processor)
        # True positives are always taken from the predictions. A plain set
        # intersection does not guarantee which of two equal objects is kept, so
        # fields that do not take part in the comparison (presumably e.g. scores;
        # TODO confirm) would otherwise come from either side.
        tp = [(document, ann) for ann in predicted if ann in gold]
        fp = [(document, ann) for ann in predicted if ann not in gold]
        fn = [(document, ann) for ann in gold if ann not in predicted]
        return tp, fp, fn

    def add_annotations(self, annotations: InstancesType, label: str):
        """Append the `(tp, fp, fn)` instances in `annotations` to the lists stored for `label`."""
        tp, fp, fn = self.tp_fp_fn[label]
        # Extend in place instead of re-creating the lists for every document.
        tp.extend(annotations[0])
        fp.extend(annotations[1])
        fn.extend(annotations[2])

    def _update(self, document: Document):
        """Collect the instances of one document for "MICRO" and, if requested, per label."""
        micro_filter = None
        if self.per_label and not self.infer_labels:
            # With explicit labels, only annotations with one of them are evaluated.
            micro_filter = partial(
                has_one_of_the_labels, label_field=self.label_field, labels=self.labels
            )
        micro_tp_fp_fn = self.get_tp_fp_fn(
            document=document,
            annotation_filter=micro_filter,
            annotation_processor=self.annotation_processor,
        )
        self.add_annotations(micro_tp_fp_fn, label="MICRO")

        if self.infer_labels:
            layer = document[self.layer]
            # collect labels from gold data and predictions
            # NOTE(review): an inferred label named "MICRO" would be merged into the
            # aggregated entry; explicit labels reject that name in __init__.
            for annotations in (layer, layer.predictions):
                for ann in annotations:
                    label = getattr(ann, self.label_field)
                    if label not in self.labels:
                        self.labels.append(label)

        if self.per_label:
            for label in self.labels:
                label_tp_fp_fn = self.get_tp_fp_fn(
                    document=document,
                    annotation_filter=partial(
                        has_this_label, label_field=self.label_field, label=label
                    ),
                    annotation_processor=self.annotation_processor,
                )
                self.add_annotations(label_tp_fp_fn, label=label)

    def format_texts(self, texts: List[str]) -> str:
        """Join `texts` into a single string, separated by "<SEP>"."""
        return "<SEP>".join(texts)

    def format_annotation(self, ann: Annotation) -> Dict[str, Any]:
        """Convert a `RelatedRelation` into a flat dictionary of labels, texts, and scores.

        The head is reported as "query", the tail as "rec", and the reference
        span as "ref". All scores are rounded to three decimal places.

        Raises:
            NotImplementedError: If `ann` is not a `RelatedRelation`.
        """
        if not isinstance(ann, RelatedRelation):
            # Other annotation types are not supported yet (a generic
            # `ann.resolve()` fallback is not enabled).
            raise NotImplementedError(
                f"formatting is only implemented for RelatedRelation, got {type(ann).__name__}"
            )

        head_resolved = ann.head.resolve()
        tail_resolved = ann.tail.resolve()
        ref_resolved = ann.reference_span.resolve()
        # The resolved values are used as (label, texts) pairs.
        return {
            "related_label": ann.label,
            "related_score": round(ann.score, 3),
            "query_label": head_resolved[0],
            "query_texts": self.format_texts(head_resolved[1]),
            "query_score": round(ann.head.score, 3),
            "ref_label": ref_resolved[0],
            "ref_texts": self.format_texts(ref_resolved[1]),
            "ref_score": round(ann.reference_span.score, 3),
            "rec_label": tail_resolved[0],
            "rec_texts": self.format_texts(tail_resolved[1]),
            "rec_score": round(ann.tail.score, 3),
        }

    def format_instance(self, instance: InstanceType) -> Dict[str, Any]:
        """Format the annotation of `instance` and add the document id, if there is one."""
        document, annotation = instance
        result = self.format_annotation(annotation)
        if getattr(document, "id", None) is not None:
            result["document_id"] = document.id
        return result

    def _compute(self) -> Dict[str, Dict[str, list]]:
        """Return the formatted "tp", "fp", and "fn" instances per collected label."""
        return {
            key: {
                "tp": [self.format_instance(instance) for instance in tp],
                "fp": [self.format_instance(instance) for instance in fp],
                "fn": [self.format_instance(instance) for instance in fn],
            }
            for key, (tp, fp, fn) in self.tp_fp_fn.items()
        }