File size: 1,673 Bytes
d868d2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import logging
from collections import Counter
from typing import Dict, List, TypeVar

from pytorch_ie import Annotation, AnnotationLayer, Document, DocumentStatistic
from pytorch_ie.annotations import BinaryRelation

from src.utils.graph_utils import get_connected_components

logger = logging.getLogger(__name__)

A = TypeVar("A")


# TODO: remove when "counts" aggregation function is available in DocumentStatistic
def count_func(values: List[int]) -> Dict[int, int]:
    """Tally how often each distinct value occurs in ``values``.

    Args:
        values: The values to count.

    Returns:
        A dict mapping each distinct value to its number of occurrences,
        with keys inserted in ascending order. Empty input yields an empty dict.
    """
    tally = Counter(values)
    ordered_counts: Dict[int, int] = {}
    # Insert keys in sorted order so the resulting dict iterates ascending.
    for value in sorted(tally):
        ordered_counts[value] = tally[value]
    return ordered_counts


class ConnectedComponentSizes(DocumentStatistic):
    """Document statistic that collects the size of each connected component.

    For every document, the relation layer and its target layer are handed to
    ``get_connected_components``; the length of each returned component is the
    value collected for that document.

    Args:
        relation_layer: Name of the document layer that holds the relations.
        link_relation_label: Value forwarded to ``get_connected_components`` as
            ``link_relation_label`` (presumably the label of relations that
            link elements into one component -- confirm against the helper).
        **kwargs: Forwarded unchanged to ``DocumentStatistic.__init__``.
    """

    # TODO: use "counts" aggregation function when available in DocumentStatistic
    # Until then, the default aggregation is referenced by its dotted import path.
    DEFAULT_AGGREGATION_FUNCTIONS = ["src.metrics.connected_component_sizes.count_func"]

    def __init__(self, relation_layer: str, link_relation_label: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.link_relation_label = link_relation_label
        self.relation_layer = relation_layer

    def _collect(self, document: Document) -> List[int]:
        """Return the sizes of the connected components found in ``document``.

        Args:
            document: The document whose ``relation_layer`` is analysed.

        Returns:
            One integer per component returned by ``get_connected_components``,
            in the order the helper returns them.
        """
        relation_annotations: AnnotationLayer[BinaryRelation] = document[self.relation_layer]
        # The elements to group are the annotations the relations point at.
        target_annotations: AnnotationLayer[Annotation] = relation_annotations.target_layer

        # NOTE(review): add_singletons=True presumably makes unlinked elements
        # show up as components of size one -- confirm in graph_utils.
        components: List[List] = get_connected_components(
            elements=target_annotations,
            relations=relation_annotations,
            link_relation_label=self.link_relation_label,
            add_singletons=True,
        )
        return list(map(len, components))