# update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
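"""Utilities for loading documents into the demo's span retriever and rendering them.

Documents can come from raw text, uploaded .txt or .pdf files, ACL Anthology venues,
arXiv, or pre-annotated PIE datasets. After optional annotation with an argumentation
model, span embeddings are created and stored in the retriever's docstore.
"""
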
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Iterable, List, Optional, Sequence

import gradio as gr
import pandas as pd
from acl_anthology import Anthology
from pie_datasets import Dataset, IterableDataset, load_dataset
from pytorch_ie import Pipeline
from pytorch_ie.documents import (
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)
from tqdm import tqdm

from src.demo.annotation_utils import create_documents, get_merger
from src.demo.data_utils import load_text_from_arxiv
from src.demo.rendering_utils import (
    RENDER_WITH_DISPLACY,
    RENDER_WITH_PRETTY_TABLE,
    render_displacy,
    render_pretty_table,
)
from src.demo.retriever_utils import get_text_spans_and_relations_from_document
from src.langchain_modules import (
    DocumentAwareSpanRetriever,
    DocumentAwareSpanRetrieverWithRelations,
)
from src.utils.pdf_utils.acl_anthology_utils import XML2RawPapers
from src.utils.pdf_utils.process_pdf import FulltextExtractor, PDFDownloader

logger = logging.getLogger(__name__)


def add_annotated_pie_documents(
    retriever: DocumentAwareSpanRetriever,
    pie_documents: Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
    use_predicted_annotations: bool,
    verbose: bool = False,
) -> None:
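    """Embed the spans of the given PIE documents and add them to the retriever.

    If documents with the same IDs are already in the retriever's docstore, they
    are overwritten and a warning is shown.
    """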
    if verbose:
        gr.Info(f"Creating span embeddings for {len(pie_documents)} documents ...")
    num_docs_before = len(retriever.docstore)
    retriever.add_pie_documents(pie_documents, use_predicted_annotations=use_predicted_annotations)
    # number of documents that were overwritten
    num_overwritten_docs = num_docs_before + len(pie_documents) - len(retriever.docstore)
    # warn if documents were overwritten
    if num_overwritten_docs > 0:
        gr.Warning(f"{num_overwritten_docs} documents were overwritten.")


def process_texts(
    texts: Iterable[str],
    doc_ids: Iterable[str],
    argumentation_model: Optional[Pipeline],
    retriever: DocumentAwareSpanRetriever,
    split_regex_escaped: Optional[str],
    handle_parts_of_same: bool = False,
    verbose: bool = False,
) -> None:
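    """Convert raw texts into PIE documents, annotate them with the argumentation
    model (if available), and add them to the retriever.

    Raises a gr.Error if the document IDs are not unique.
    """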
    # materialize the iterables so they can be traversed more than once
    texts = list(texts)
    doc_ids = list(doc_ids)
    # check that doc_ids are unique
    if len(set(doc_ids)) != len(doc_ids):
        raise gr.Error("Document IDs must be unique.")
    pie_documents = create_documents(
        texts=texts,
        doc_ids=doc_ids,
        split_regex=split_regex_escaped,
    )
    if argumentation_model is not None:
        if verbose:
            gr.Info(f"Annotating {len(pie_documents)} documents ...")
        pie_documents = argumentation_model(pie_documents, inplace=True)
    else:
        gr.Warning(
            "Annotation is disabled (no model was loaded). No annotations will be added to the documents."
        )
    # this needs to be done even if the documents are not annotated because
    # it adjusts the document type
    if handle_parts_of_same:
        merger = get_merger()
        pie_documents = [merger(document) for document in pie_documents]
    add_annotated_pie_documents(
        retriever=retriever,
        pie_documents=pie_documents,
        use_predicted_annotations=True,
        verbose=verbose,
    )


def add_annotated_pie_documents_from_dataset(
    retriever: DocumentAwareSpanRetriever, verbose: bool = False, **load_dataset_kwargs
) -> None:
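    """Load a PIE dataset and add its documents to the retriever.

    All load_dataset_kwargs are passed through to pie_datasets.load_dataset. The
    documents are converted to the multi-span document type (via a registered
    converter or by merging single spans) and their gold annotations are used.
    """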
    try:
        gr.Info(
            "Loading PIE dataset with parameters:\n" + json.dumps(load_dataset_kwargs, indent=2)
        )
        dataset = load_dataset(**load_dataset_kwargs)
        if not isinstance(dataset, (Dataset, IterableDataset)):
            raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.")
        try:
            dataset_converted = dataset.to_document_type(
                TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
            )
        except ValueError:
            gr.Warning(
                "The dataset does not seem to have a registered converter to create multi-spans. "
                "Trying to load as single spans and to convert to multi-spans manually ..."
            )
            dataset_converted_single_span = dataset.to_document_type(
                TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
            )
            merger = get_merger()
            dataset_converted = dataset_converted_single_span.map(
                merger,
                result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
            )

        def _clear_metadata(
            doc: TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
        ) -> TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions:
            result = doc.copy()
            result.metadata = dict()
            return result

        # adding documents with different metadata formats to the retriever breaks it,
        # so we clear the metadata field beforehand
        dataset_converted_without_metadata = dataset_converted.map(
            _clear_metadata,
            result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
        )
        add_annotated_pie_documents(
            retriever=retriever,
            pie_documents=dataset_converted_without_metadata,
            use_predicted_annotations=False,
            verbose=verbose,
        )
    except Exception as e:
        raise gr.Error(f"Failed to load dataset: {e}")


def wrapped_process_text(
    doc_id: str, text: str, retriever: DocumentAwareSpanRetriever, **kwargs
) -> str:
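    """Process a single text via process_texts, wrapping any failure in a gr.Error.

    Returns the document ID so the caller can confirm which document was processed.
    """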
    try:
        process_texts(doc_ids=[doc_id], texts=[text], retriever=retriever, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process text: {e}")
    # return the plain document ID instead of the document itself to avoid serialization issues
    return doc_id


def process_uploaded_files(
    file_names: List[str],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    **kwargs,
) -> pd.DataFrame:
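    """Read the uploaded .txt files, process their contents via process_texts, and
    return an overview of the retriever's docstore. The file base names are used
    as document IDs.
    """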
    try:
        doc_ids = []
        texts = []
        for file_name in file_names:
            if file_name.lower().endswith(".txt"):
                # read the file content
                with open(file_name, "r", encoding="utf-8") as f:
                    text = f.read()
                base_file_name = os.path.basename(file_name)
                doc_ids.append(base_file_name)
                texts.append(text)
            else:
                raise gr.Error(f"Unsupported file format: {file_name}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded files: {e}")
    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def process_uploaded_pdf_files(
    pdf_fulltext_extractor: Optional[FulltextExtractor],
    file_names: List[str],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    **kwargs,
) -> pd.DataFrame:
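    """Extract the fulltext from the uploaded .pdf files, process the texts via
    process_texts, and return an overview of the retriever's docstore.
    """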
    try:
        if pdf_fulltext_extractor is None:
            raise gr.Error("PDF fulltext extractor is not available.")
        doc_ids = []
        texts = []
        for file_name in file_names:
            if file_name.lower().endswith(".pdf"):
                # extract the fulltext from the pdf
                text_and_extraction_data = pdf_fulltext_extractor(file_name)
                if text_and_extraction_data is None:
                    raise gr.Error(f"Failed to extract fulltext from PDF: {file_name}")
                text, _ = text_and_extraction_data
                base_file_name = os.path.basename(file_name)
                doc_ids.append(base_file_name)
                texts.append(text)
            else:
                raise gr.Error(f"Unsupported file format: {file_name}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded files: {e}")
    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def load_acl_anthology_venues(
    venues: List[str],
    pdf_fulltext_extractor: Optional[FulltextExtractor],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    acl_anthology_data_dir: Optional[str],
    pdf_output_dir: Optional[str],
    show_progress: bool = True,
    **kwargs,
) -> pd.DataFrame:
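    """Download the PDFs of all papers in the given ACL Anthology venues, extract
    their fulltext, process the texts via process_texts, and return an overview of
    the retriever's docstore. Document IDs have the format "aclanthology.org/{paper.name}".
    """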
    try:
        if pdf_fulltext_extractor is None:
            raise gr.Error("PDF fulltext extractor is not available.")
        if acl_anthology_data_dir is None:
            raise gr.Error("ACL Anthology data directory is not provided.")
        if pdf_output_dir is None:
            raise gr.Error("PDF output directory is not provided.")
        xml2raw_papers = XML2RawPapers(
            anthology=Anthology(datadir=Path(acl_anthology_data_dir)),
            venue_id_whitelist=venues,
            verbose=False,
        )
        pdf_downloader = PDFDownloader()
        doc_ids = []
        texts = []
        os.makedirs(pdf_output_dir, exist_ok=True)
        papers = xml2raw_papers()
        if show_progress:
            papers_list = list(papers)
            papers = tqdm(papers_list, desc="extracting fulltext")
            gr.Info(
                f"Downloading and extracting fulltext from {len(papers_list)} papers in venues: {venues}"
            )
        for paper in papers:
            if paper.url is not None:
                pdf_save_path = pdf_downloader.download(
                    paper.url, opath=Path(pdf_output_dir) / f"{paper.name}.pdf"
                )
                fulltext_extraction_output = pdf_fulltext_extractor(pdf_save_path)
                if fulltext_extraction_output:
                    text, _ = fulltext_extraction_output
                    doc_id = f"aclanthology.org/{paper.name}"
                    doc_ids.append(doc_id)
                    texts.append(text)
                else:
                    gr.Warning(f"Failed to extract fulltext from PDF: {paper.url}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to load ACL Anthology venues: {e}")
    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def wrapped_add_annotated_pie_documents_from_dataset(
    retriever: DocumentAwareSpanRetriever, verbose: bool, layer_captions: dict[str, str], **kwargs
) -> pd.DataFrame:
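    """Wrap add_annotated_pie_documents_from_dataset and return an overview of the
    retriever's docstore.
    """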
    try:
        add_annotated_pie_documents_from_dataset(retriever=retriever, verbose=verbose, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to add annotated PIE documents from dataset: {e}")
    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def download_processed_documents(
    retriever: DocumentAwareSpanRetriever,
    file_name: str = "retriever_store",
) -> Optional[str]:
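    """Save the retriever store as a zip archive in the temporary directory and
    return the path to the archive, or None if the docstore is empty.
    """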
    if len(retriever.docstore) == 0:
        gr.Warning("No documents to download.")
        return None
    # zip the retriever store directory
    file_path = os.path.join(tempfile.gettempdir(), file_name)
    gr.Info(f"Zipping the retriever store to '{file_name}' ...")
    result_file_path = retriever.save_to_archive(base_name=file_path, format="zip")
    return result_file_path


def upload_processed_documents(
    file_name: str,
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
) -> pd.DataFrame:
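    """Load previously processed documents from a zip archive or directory into the
    retriever and return an overview of its docstore.
    """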
    # load the documents from the zip file or directory
    retriever.load_from_disc(file_name)
    # return the overview of the document store
    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def process_text_from_arxiv(
    arxiv_id: str, retriever: DocumentAwareSpanRetriever, abstract_only: bool = False, **kwargs
) -> str:
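    """Fetch the fulltext (or just the abstract) of the given arXiv paper and process
    it via wrapped_process_text. Returns the resulting document ID.
    """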
    try:
        text, doc_id = load_text_from_arxiv(arxiv_id=arxiv_id, abstract_only=abstract_only)
    except Exception as e:
        raise gr.Error(f"Failed to load text from arXiv: {e}")
    return wrapped_process_text(doc_id=doc_id, text=text, retriever=retriever, **kwargs)


def render_annotated_document(
    retriever: DocumentAwareSpanRetrieverWithRelations,
    document_id: str,
    render_with: str,
    render_kwargs_json: str,
    highlight_span_ids: Optional[List[str]] = None,
) -> str:
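    """Render the spans and binary relations of the given document as HTML, either
    as a pretty table or with displaCy. render_kwargs_json is a JSON string of
    additional keyword arguments for the chosen renderer.
    """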
    text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
        retriever=retriever, document_id=document_id
    )
    render_kwargs = json.loads(render_kwargs_json)
    if render_with == RENDER_WITH_PRETTY_TABLE:
        html = render_pretty_table(
            text=text,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=relations,
            **render_kwargs,
        )
    elif render_with == RENDER_WITH_DISPLACY:
        html = render_displacy(
            text=text,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=relations,
            highlight_span_ids=highlight_span_ids,
            **render_kwargs,
        )
    else:
        raise ValueError(f"Unknown render_with value: {render_with}")
    return html