import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Type, TypeVar, Union

from pie_datasets import Dataset, DatasetDict
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
from pytorch_ie import Document
from pytorch_ie.annotations import BinaryRelation, Span
from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target

logger = logging.getLogger(__name__)


# TODO: simply use DatasetDict.map() with set_batch_size_to_split_size=True and
#  batched=True instead when https://github.com/ArneBinder/pie-datasets/pull/155 is merged
def apply_func_to_splits(
    dataset: DatasetDict,
    function: Union[str, Callable],
    result_document_type: Type[Document],
    **kwargs,
) -> DatasetDict:
    """Apply a batched function to each split of the dataset, passing the whole split as a
    single batch so that the function can work across all documents of a split at once (and
    may return a different number of documents than it receives)."""
    resolved_func = resolve_target(function)
    resolved_document_type = resolve_optional_document_type(document_type=result_document_type)
    result_dict = dict()
    split: Dataset
    for split_name, split in dataset.items():
        converted_dataset = split.map(
            function=resolved_func,
            batched=True,
            # pass the whole split as a single batch
            batch_size=len(split),
            result_document_type=resolved_document_type,
            **kwargs,
        )
        result_dict[split_name] = converted_dataset
    return DatasetDict(result_dict)
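

# Usage sketch (illustrative): convert every split of a dataset in one call. The
# converter path "src.utils.construct_text_pairs" is a hypothetical placeholder,
# not something defined in this module.
#   converted = apply_func_to_splits(
#       dataset,
#       function="src.utils.construct_text_pairs",
#       result_document_type=TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
#   )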


S = TypeVar("S", bound=Span)


def shift_span(span: S, offset: int) -> S:
    """Shift the start and end of a span by a given offset."""
    return span.copy(start=span.start + offset, end=span.end + offset)
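

# Example (illustrative; detached pytorch_ie spans compare by value):
#   shift_span(Span(start=0, end=4), offset=10) == Span(start=10, end=14)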


D = TypeVar("D", bound=TextDocumentWithLabeledSpansAndBinaryRelations)


def add_predicted_semantically_same_relations_to_document(
    document: D,
    doc_id2docs_with_predictions: Dict[
        str, List[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]
    ],
    relation_label: str,
    argument_label_blacklist: Optional[List[str]] = None,
    verbose: bool = False,
) -> D:
    """Add predicted coreference relations from text pair documents to the original document as
    binary relations with the given label. Relations whose head or tail span has a label in
    argument_label_blacklist are skipped."""
    # create a lookup from detached copies of the spans to the attached spans
    # (an attached span != its detached copy, even if start/end/label are equal)
    span2span = {span.copy(): span for span in document.labeled_spans}
    for text_pair_doc_with_preds in doc_id2docs_with_predictions.get(document.id, []):
        # shift the span offsets back into the coordinate system of the original document
        offset = text_pair_doc_with_preds.metadata["original_doc_span"]["start"]
        offset_pair = text_pair_doc_with_preds.metadata["original_doc_span_pair"]["start"]
        for coref_rel in text_pair_doc_with_preds.binary_coref_relations.predictions:
            head = shift_span(coref_rel.head, offset=offset)
            if head not in span2span:
                if verbose:
                    logger.warning(f"doc_id={document.id}: Head span {head} not found.")
                continue
            tail = shift_span(coref_rel.tail, offset=offset_pair)
            if tail not in span2span:
                if verbose:
                    logger.warning(f"doc_id={document.id}: Tail span {tail} not found.")
                continue
            if argument_label_blacklist is not None and (
                span2span[head].label in argument_label_blacklist
                or span2span[tail].label in argument_label_blacklist
            ):
                continue
            new_rel = BinaryRelation(
                head=span2span[head],
                tail=span2span[tail],
                label=relation_label,
                score=coref_rel.score,
            )
            document.binary_relations.predictions.append(new_rel)
    return document
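

# Usage sketch (illustrative; the relation label and blacklist entry are hypothetical):
#   doc = add_predicted_semantically_same_relations_to_document(
#       document=doc,
#       doc_id2docs_with_predictions={doc.id: [text_pair_doc_with_preds]},
#       relation_label="semantically_same",
#       argument_label_blacklist=["background_claim"],
#   )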


def integrate_coref_predictions_from_text_pair_documents(
    dataset: DatasetDict, data_dir: str, **kwargs
) -> DatasetDict:
    """Load text pair documents with predicted coreference relations from data_dir and integrate
    the predictions into the respective original documents of the dataset. Any additional kwargs
    are passed on to add_predicted_semantically_same_relations_to_document()."""
    dataset_with_predictions = DatasetDict.from_json(data_dir=data_dir)
    for split_name in dataset.keys():
        ds_with_predictions = dataset_with_predictions[split_name]
        # group the text pair documents by the id of the original document they were created from
        original_doc_id2docs = defaultdict(list)
        for doc in ds_with_predictions:
            original_doc_id = doc.metadata["original_doc_id"]
            if original_doc_id != doc.metadata["original_doc_id_pair"]:
                raise ValueError(
                    f"Original document IDs do not match: "
                    f"{original_doc_id} != {doc.metadata['original_doc_id_pair']}. "
                    f"Cross-document coref is not supported."
                )
            original_doc_id2docs[original_doc_id].append(doc)
        dataset[split_name] = dataset[split_name].map(
            function=add_predicted_semantically_same_relations_to_document,
            fn_kwargs=dict(doc_id2docs_with_predictions=original_doc_id2docs, **kwargs),
        )
    return dataset
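

# Usage sketch (illustrative; the data_dir path and relation label are hypothetical).
# Extra kwargs such as relation_label are forwarded via fn_kwargs to
# add_predicted_semantically_same_relations_to_document().
#   dataset = integrate_coref_predictions_from_text_pair_documents(
#       dataset,
#       data_dir="predictions/coref_text_pairs",
#       relation_label="semantically_same",
#   )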