import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Type, TypeVar, Union

from pie_datasets import Dataset, DatasetDict
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
from pytorch_ie import Document
from pytorch_ie.annotations import BinaryRelation, Span
from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target

logger = logging.getLogger(__name__)


# TODO: simply use DatasetDict.map() with set_batch_size_to_split_size=True and
#  batched=True instead when https://github.com/ArneBinder/pie-datasets/pull/155 is merged
def apply_func_to_splits(
    dataset: DatasetDict,
    function: Union[str, Callable],
    result_document_type: Type[Document],
    **kwargs,
):
    """Apply a document conversion function to each split of the dataset, processing
    every split as a single batch, and return the result as a new DatasetDict."""
    resolved_func = resolve_target(function)
    resolved_document_type = resolve_optional_document_type(document_type=result_document_type)
    result_dict = dict()
    split: Dataset
    for split_name, split in dataset.items():
        converted_dataset = split.map(
            function=resolved_func,
            batched=True,
            batch_size=len(split),
            result_document_type=resolved_document_type,
            **kwargs,
        )
        result_dict[split_name] = converted_dataset
    return DatasetDict(result_dict)


S = TypeVar("S", bound=Span)


def shift_span(span: S, offset: int) -> S:
    """Shift the start and end of a span by a given offset."""
    return span.copy(start=span.start + offset, end=span.end + offset)


D = TypeVar("D", bound=TextDocumentWithLabeledSpansAndBinaryRelations)


def add_predicted_semantically_same_relations_to_document(
    document: D,
    doc_id2docs_with_predictions: Dict[
        str, List[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]
    ],
    relation_label: str,
    argument_label_blacklist: Optional[List[str]] = None,
    verbose: bool = False,
) -> D:
    """Add predicted coref relations from the text-pair documents belonging to this
    document as predicted binary relations with the given label. Span offsets are
    shifted back into the coordinates of the original document text."""
    # create lookup for detached versions of the spans (attached span != detached span
    # even if they are the same)
    span2span = {span.copy(): span for span in document.labeled_spans}
    for text_pair_doc_with_preds in doc_id2docs_with_predictions.get(document.id, []):
        offset = text_pair_doc_with_preds.metadata["original_doc_span"]["start"]
        offset_pair = text_pair_doc_with_preds.metadata["original_doc_span_pair"]["start"]
        for coref_rel in text_pair_doc_with_preds.binary_coref_relations.predictions:
            head = shift_span(coref_rel.head, offset=offset)
            if head not in span2span:
                if verbose:
                    logger.warning(f"doc_id={document.id}: Head span {head} not found.")
                continue
            tail = shift_span(coref_rel.tail, offset=offset_pair)
            if tail not in span2span:
                if verbose:
                    logger.warning(f"doc_id={document.id}: Tail span {tail} not found.")
                continue
            if argument_label_blacklist is not None and (
                span2span[head].label in argument_label_blacklist
                or span2span[tail].label in argument_label_blacklist
            ):
                continue
            new_rel = BinaryRelation(
                head=span2span[head],
                tail=span2span[tail],
                label=relation_label,
                score=coref_rel.score,
            )
            document.binary_relations.predictions.append(new_rel)
    return document


def integrate_coref_predictions_from_text_pair_documents(
    dataset: DatasetDict, data_dir: str, **kwargs
) -> DatasetDict:
    """Load text-pair documents with coref predictions from data_dir and merge their
    predictions into the matching splits of the dataset. Only within-document coref
    is supported."""
    dataset_with_predictions = DatasetDict.from_json(data_dir=data_dir)
    for split_name in dataset.keys():
        ds_with_predictions = dataset_with_predictions[split_name]
        original_doc_id2docs = defaultdict(list)
        for doc in ds_with_predictions:
            original_doc_id = doc.metadata["original_doc_id"]
            if original_doc_id != doc.metadata["original_doc_id_pair"]:
                raise ValueError(
                    f"Original document IDs do not match: "
                    f"{original_doc_id} != {doc.metadata['original_doc_id_pair']}. "
                    f"Cross-document coref is not supported."
                )
            original_doc_id2docs[original_doc_id].append(doc)
        dataset[split_name] = dataset[split_name].map(
            function=add_predicted_semantically_same_relations_to_document,
            fn_kwargs=dict(doc_id2docs_with_predictions=original_doc_id2docs, **kwargs),
        )
    return dataset
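

# Usage sketch of how these functions are expected to fit together. The directory
# paths, the relation label, and the verbose flag below are illustrative
# assumptions, not values taken from this module.
if __name__ == "__main__":
    # Load the original documents and merge in coref predictions from a matching
    # set of text-pair documents (hypothetical paths).
    dataset = DatasetDict.from_json(data_dir="data/original")  # assumed path
    dataset = integrate_coref_predictions_from_text_pair_documents(
        dataset,
        data_dir="data/coref_predictions",  # assumed path
        relation_label="semantically_same",  # assumed label
        argument_label_blacklist=None,
        verbose=True,
    )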