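"""Utilities to merge coreference relations that were predicted on text-pair documents back
into the original documents as binary relation predictions."""
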
import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Type, TypeVar, Union

from pie_datasets import Dataset, DatasetDict
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
from pytorch_ie import Document
from pytorch_ie.annotations import BinaryRelation, Span
from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target

logger = logging.getLogger(__name__)


# TODO: simply use DatasetDict.map() with set_batch_size_to_split_size=True and
#  batched=True instead when https://github.com/ArneBinder/pie-datasets/pull/155 is merged
def apply_func_to_splits(
    dataset: DatasetDict,
    function: Union[str, Callable],
    result_document_type: Type[Document],
    **kwargs,
) -> DatasetDict:
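    """Apply a function to each split of the dataset, processing the whole split as a single
    batch, and construct the results as documents of result_document_type. The function can be
    passed as a callable or as a resolvable target string.
    """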
    resolved_func = resolve_target(function)
    resolved_document_type = resolve_optional_document_type(document_type=result_document_type)
    result_dict = {}
    split: Dataset
    for split_name, split in dataset.items():
        converted_dataset = split.map(
            function=resolved_func,
            batched=True,
            # process the whole split as a single batch
            batch_size=len(split),
            result_document_type=resolved_document_type,
            **kwargs,
        )
        result_dict[split_name] = converted_dataset
    return DatasetDict(result_dict)


S = TypeVar("S", bound=Span)


def shift_span(span: S, offset: int) -> S:
    """Shift the start and end of a span by a given offset."""
    return span.copy(start=span.start + offset, end=span.end + offset)


D = TypeVar("D", bound=TextDocumentWithLabeledSpansAndBinaryRelations)


def add_predicted_semantically_same_relations_to_document(
    document: D,
    doc_id2docs_with_predictions: Dict[
        str, List[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]
    ],
    relation_label: str,
    argument_label_blacklist: Optional[List[str]] = None,
    verbose: bool = False,
) -> D:
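    """Integrate the predicted coreference relations from the text-pair documents belonging to
    this document as binary relation predictions with the given label. Predictions are skipped
    if a (back-shifted) head or tail span cannot be matched to a labeled span of the document,
    or if the label of either argument span is in argument_label_blacklist.
    """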
    # create a lookup from detached copies of the spans to the attached originals
    # (an attached span != its detached copy, even if their content is identical)
    span2span = {span.copy(): span for span in document.labeled_spans}
    for text_pair_doc_with_preds in doc_id2docs_with_predictions.get(document.id, []):
        # the metadata holds the character offsets of both text-pair slices within the
        # original document text; use them to shift the predicted spans back
        offset = text_pair_doc_with_preds.metadata["original_doc_span"]["start"]
        offset_pair = text_pair_doc_with_preds.metadata["original_doc_span_pair"]["start"]
        for coref_rel in text_pair_doc_with_preds.binary_coref_relations.predictions:
            head = shift_span(coref_rel.head, offset=offset)
            if head not in span2span:
                if verbose:
                    logger.warning(f"doc_id={document.id}: Head span {head} not found.")
                continue
            tail = shift_span(coref_rel.tail, offset=offset_pair)
            if tail not in span2span:
                if verbose:
                    logger.warning(f"doc_id={document.id}: Tail span {tail} not found.")
                continue
            if argument_label_blacklist is not None and (
                span2span[head].label in argument_label_blacklist
                or span2span[tail].label in argument_label_blacklist
            ):
                continue
            new_rel = BinaryRelation(
                head=span2span[head],
                tail=span2span[tail],
                label=relation_label,
                score=coref_rel.score,
            )
            document.binary_relations.predictions.append(new_rel)
    return document


def integrate_coref_predictions_from_text_pair_documents(
    dataset: DatasetDict, data_dir: str, **kwargs
) -> DatasetDict:
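    """Load text-pair documents with predicted coreference relations from data_dir and merge
    the predictions into the matching original documents of the dataset. All remaining kwargs
    (e.g. relation_label) are passed through to
    add_predicted_semantically_same_relations_to_document. The input DatasetDict is updated in
    place and returned.
    """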
    dataset_with_predictions = DatasetDict.from_json(data_dir=data_dir)
    for split_name in dataset.keys():
        ds_with_predictions = dataset_with_predictions[split_name]
        # group the text-pair documents by the id of the original document they were built from
        original_doc_id2docs = defaultdict(list)
        for doc in ds_with_predictions:
            original_doc_id = doc.metadata["original_doc_id"]
            if original_doc_id != doc.metadata["original_doc_id_pair"]:
                raise ValueError(
                    f"Original document IDs do not match: "
                    f"{original_doc_id} != {doc.metadata['original_doc_id_pair']}. "
                    f"Cross-document coref is not supported."
                )
            original_doc_id2docs[original_doc_id].append(doc)
        dataset[split_name] = dataset[split_name].map(
            function=add_predicted_semantically_same_relations_to_document,
            fn_kwargs=dict(doc_id2docs_with_predictions=original_doc_id2docs, **kwargs),
        )
    return dataset
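

# Minimal usage sketch (the paths and the relation label below are assumptions for
# illustration, not part of this module):
#
#   dataset = DatasetDict.from_json(data_dir="path/to/original/dataset")  # hypothetical path
#   dataset = integrate_coref_predictions_from_text_pair_documents(
#       dataset=dataset,
#       data_dir="path/to/text_pair_predictions",  # hypothetical path
#       relation_label="semantically_same",  # assumed label
#   )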