import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Type, TypeVar, Union

from pie_datasets import Dataset, DatasetDict
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
from pytorch_ie import Document
from pytorch_ie.annotations import BinaryRelation, Span
from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target

logger = logging.getLogger(__name__)


# TODO: simply use DatasetDict.map() with set_batch_size_to_split_size=True and
#  batched=True instead when https://github.com/ArneBinder/pie-datasets/pull/155 is merged
def apply_func_to_splits(
    dataset: DatasetDict,
    function: Union[str, Callable],
    result_document_type: Type[Document],
    **kwargs,
):
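    """Apply `function` to every split of `dataset` as a single batch (batch_size is set to
    the split size) and return a new DatasetDict with documents of `result_document_type`.
    `function` may also be given as a resolvable target string.

    Temporary workaround until batched per-split mapping is supported directly by
    `DatasetDict.map()` (see the TODO above).
    """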
    resolved_func = resolve_target(function)
    resolved_document_type = resolve_optional_document_type(document_type=result_document_type)
    result_dict = dict()
    split: Dataset
    for split_name, split in dataset.items():
        converted_dataset = split.map(
            function=resolved_func,
            batched=True,
            batch_size=len(split),
            result_document_type=resolved_document_type,
            **kwargs,
        )
        result_dict[split_name] = converted_dataset
    return DatasetDict(result_dict)


S = TypeVar("S", bound=Span)


def shift_span(span: S, offset: int) -> S:
    """Shift the start and end of a span by a given offset."""
    return span.copy(start=span.start + offset, end=span.end + offset)


D = TypeVar("D", bound=TextDocumentWithLabeledSpansAndBinaryRelations)


def add_predicted_semantically_same_relations_to_document(
    document: D,
    doc_id2docs_with_predictions: Dict[
        str, TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
    ],
    relation_label: str,
    argument_label_blacklist: Optional[List[str]] = None,
    verbose: bool = False,
) -> D:
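    """Add the predicted coreference relations from the text-pair documents belonging to
    `document` as binary relations with label `relation_label` to the document's relation
    predictions.

    Predicted head and tail spans live in the coordinate system of the text-pair excerpts;
    they are shifted back into the original document and matched against its labeled spans.
    Predictions with unmatched arguments, or with argument labels in
    `argument_label_blacklist`, are skipped.
    """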

    # lookup from detached span copies to the attached spans (an attached span and its
    # detached copy do not compare equal, even if they cover the same offsets)
    span2span = {span.copy(): span for span in document.labeled_spans}
    for text_pair_doc_with_preds in doc_id2docs_with_predictions.get(document.id, []):
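        # the text-pair document contains excerpts of the original text; use the excerpt
        # start offsets from the metadata to shift predicted spans back into the
        # coordinate system of the original document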
        offset = text_pair_doc_with_preds.metadata["original_doc_span"]["start"]
        offset_pair = text_pair_doc_with_preds.metadata["original_doc_span_pair"]["start"]
        for coref_rel in text_pair_doc_with_preds.binary_coref_relations.predictions:
            head = shift_span(coref_rel.head, offset=offset)
            if head not in span2span:
                if verbose:
                    logger.warning(f"doc_id={document.id}: Head span {head} not found.")
                continue
            tail = shift_span(coref_rel.tail, offset=offset_pair)
            if tail not in span2span:
                if verbose:
                    logger.warning(f"doc_id={document.id}: Tail span {tail} not found.")
                continue
            if argument_label_blacklist is not None and (
                span2span[head].label in argument_label_blacklist
                or span2span[tail].label in argument_label_blacklist
            ):
                continue
            new_rel = BinaryRelation(
                head=span2span[head],
                tail=span2span[tail],
                label=relation_label,
                score=coref_rel.score,
            )
            document.binary_relations.predictions.append(new_rel)
    return document


def integrate_coref_predictions_from_text_pair_documents(
    dataset: DatasetDict, data_dir: str, **kwargs
) -> DatasetDict:
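    """Load the text-pair documents with predicted coreference relations from `data_dir`
    and integrate these predictions into the original documents of `dataset`.

    Cross-document predictions are not supported. Any additional keyword arguments
    (e.g. `relation_label`) are passed on to
    `add_predicted_semantically_same_relations_to_document`.
    """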

    dataset_with_predictions = DatasetDict.from_json(data_dir=data_dir)

    for split_name in dataset.keys():
        ds_with_predictions = dataset_with_predictions[split_name]
        original_doc_id2docs = defaultdict(list)
        for doc in ds_with_predictions:
            original_doc_id = doc.metadata["original_doc_id"]
            if original_doc_id != doc.metadata["original_doc_id_pair"]:
                raise ValueError(
                    f"Original document IDs do not match: "
                    f"{original_doc_id} != {doc.metadata['original_doc_id_pair']}. "
                    f"Cross-document coref is not supported."
                )
            original_doc_id2docs[original_doc_id].append(doc)

        dataset[split_name] = dataset[split_name].map(
            function=add_predicted_semantically_same_relations_to_document,
            fn_kwargs=dict(doc_id2docs_with_predictions=original_doc_id2docs, **kwargs),
        )
    return dataset
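

# Minimal usage sketch (not part of the module's API): the paths and the relation label
# below are hypothetical placeholders, assuming the coref predictions were serialized in a
# format readable by DatasetDict.from_json().
#
#   dataset = DatasetDict.from_json(data_dir="data/my_dataset")  # hypothetical path
#   dataset = integrate_coref_predictions_from_text_pair_documents(
#       dataset,
#       data_dir="predictions/coref",        # hypothetical path to the predictions
#       relation_label="semantically_same",  # hypothetical label for the added relations
#       verbose=True,
#   )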