Christof Bless
first working mvp
b23f8b6 unverified
import gradio as gr
import numpy as np
import pymupdf4llm
import spacy
from transformers import AutoTokenizer, AutoModel
from adapters import AutoAdapterModel
from extract_citations import fetch_citations_for_dois
from extract_embeddings import (
prune_contexts,
embed_abstracts,
embed_contexts,
restore_inverted_abstract,
calculate_distances
)
from extract_mentions import extract_citation_contexts
def extract_text(pdf_file):
    """Convert a PDF into Markdown text with ``pymupdf4llm``.

    Args:
        pdf_file: Path of the PDF to convert (callers in this module pass the
            uploaded file's ``name``). A falsy value means nothing was uploaded.

    Returns:
        The Markdown text on success. On failure a human-readable message is
        returned instead of raising, so it can be shown directly in the UI.
    """
    if pdf_file:
        try:
            markdown = pymupdf4llm.to_markdown(pdf_file)
        except Exception as err:
            # Report conversion failures as text rather than propagating them.
            return f"Error when processing PDF. {err}"
        return markdown
    return "Please upload a PDF file."
def extract_citations(doi):
    """Fetch the works cited by the paper identified by ``doi``.

    Args:
        doi: DOI string of the citing paper.

    Returns:
        Whatever ``fetch_citations_for_dois`` returns for ``[doi]``. If that
        call raises, an error message string is returned instead.
    """
    try:
        result = fetch_citations_for_dois([doi])
    except Exception as err:
        return f"Please submit a valid DOI. {err}"
    else:
        return result
# Citation contexts whose fraction of tokenizer-known tokens (as reported by
# ``prune_contexts``) is below this value are neither embedded nor scored.
MIN_KNOWN_TOKENS_FRACTION = 0.7


def _abstract_record(paper):
    """Build the ``{"title", "abstract"}`` dict passed to ``embed_abstracts``.

    ``paper`` is one work record fetched from OpenAlex. Its abstract is stored
    as an inverted index, which is restored to text. The abstract is ``None``
    when no index is present.
    """
    inverted = paper["abstract_inverted_index"]
    return {
        "title": paper["title"],
        "abstract": restore_inverted_abstract(inverted) if inverted is not None else None,
    }


def get_cite_context_distance(pdf, doi):
    """Score how closely each in-text citation context matches the cited paper.

    Pipeline:
        1. Fetch the works cited by ``doi`` via ``fetch_citations_for_dois``.
        2. Convert the PDF to Markdown and locate the citation contexts.
        3. Prune the contexts.
        4. Embed the contexts and the cited titles/abstracts with
           ``allenai/specter2_base``.
        5. Compute one distance per citation with ``calculate_distances``.

    Args:
        pdf: The uploaded PDF. Either an object whose ``name`` attribute holds
            the file path (as this module's Gradio File input passes it) or a
            plain path string. ``None`` when nothing was uploaded.
        doi: DOI of the citing paper.

    Returns:
        ``"Please upload a PDF file."`` when no PDF was given. Otherwise a dict
        with two keys:

        - ``"score"``: the mean of the non-NaN distances, or NaN when none
          could be scored.
        - ``"individual_citations"``: one dict per scored citation, holding a
          short context snippet, the cited title, the cited id and the
          distance.
    """
    pdf_path = getattr(pdf, "name", pdf)
    if not pdf_path:
        # Previously this crashed with AttributeError on ``pdf.name``.
        return "Please upload a PDF file."

    # NOTE(review): the models are reloaded on every request. Caching them
    # would be much faster, but first confirm that embed_contexts /
    # embed_abstracts do not mutate ``model`` (e.g. by loading adapters).
    tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
    model = AutoAdapterModel.from_pretrained('allenai/specter2_base')
    nlp = spacy.load("en_core_web_sm")

    # Fetch the cited papers from OpenAlex.
    # NOTE(review): the UI labels the DOI as optional, but an empty DOI is
    # passed straight through here. Confirm fetch_citations_for_dois copes.
    citations_by_doi = fetch_citations_for_dois([doi])

    # NOTE(review): extract_text reports failures as an error *string*, which
    # would then be processed below as if it were the paper text.
    text = extract_text(pdf_path)

    citations = extract_citation_contexts(citations_by_doi, text)
    citations["pruned_contexts"], citations["known_tokens_fraction"] = prune_contexts(
        citations, nlp, tokenizer
    )

    # Positional mask of the contexts worth scoring. It is computed once so
    # the embedding inputs and the distance indices cannot disagree.
    keep = (
        (citations["known_tokens_fraction"] >= MIN_KNOWN_TOKENS_FRACTION)
        & citations["pruned_contexts"].notna()
    ).to_numpy()
    context_texts = citations["pruned_contexts"][keep].to_list()
    if not context_texts:
        # Nothing survived filtering, so there is nothing to embed or score.
        return {"score": float("nan"), "individual_citations": []}

    context_embedding = embed_contexts(context_texts, model, tokenizer).detach().numpy()

    # Flatten the fetched dict of work lists into a lookup by work id.
    papers_by_id = {
        work["id"]: work for works in citations_by_doi.values() for work in works
    }

    # One abstract embedding per distinct cited work, in order of first
    # appearance.
    cited_ids = citations["citation_id"].unique().tolist()
    abstract_embedding = embed_abstracts(
        [_abstract_record(papers_by_id[cid]) for cid in cited_ids],
        model,
        tokenizer,
        batch_size=4,
    ).detach().numpy()

    # Map each DataFrame row *position* to its row in the two embedding
    # matrices. Filtered-out rows get (None, None). Dict lookups replace the
    # former O(n) list.index() calls inside the loop.
    context_row = {pos: k for k, pos in enumerate(np.flatnonzero(keep).tolist())}
    abstract_row = {cid: k for k, cid in enumerate(cited_ids)}
    indices = [
        (context_row[pos], abstract_row[cid]) if pos in context_row else (None, None)
        for pos, cid in enumerate(citations["citation_id"])
    ]
    # dtype=float gives a numeric array (None entries become NaN), so the
    # NaN tests below are well-defined.
    distances = np.asarray(
        calculate_distances(context_embedding, abstract_embedding, indices), dtype=float
    )

    # Iterate the columns positionally so this works for any DataFrame index.
    results = []
    columns = zip(
        distances,
        citations["left_context"],
        citations["mention"],
        citations["right_context"],
        citations["citation_id"],
    )
    for dist, left, mention, right, cite_id in columns:
        if np.isnan(dist):
            continue
        left_snippet = left[-50:].replace('\n', '')
        right_snippet = right[:50].replace('\n', '')
        results.append({
            "cite_context_short": f"...{left_snippet}{mention}{right_snippet}...",
            "cited_paper": papers_by_id[cite_id]["title"],
            "cited_paper_id": cite_id,
            "distance": float(dist),
        })

    scored = distances[~np.isnan(distances)]
    # The explicit empty check avoids numpy's "Mean of empty slice" warning.
    score = float(scored.mean()) if scored.size else float("nan")
    return {"score": score, "individual_citations": results}
# Gradio UI: a DOI box and a PDF upload feed get_cite_context_distance, and its
# result is shown in a text box.
with gr.Blocks() as demo:
    gr.Markdown("## Citation Integrity Score")
    doi_input = gr.Textbox(label="Enter DOI (optional)")
    pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
    output = gr.Textbox(label="Extracted Citations", lines=20)
    submit_btn = gr.Button("Submit")
    # Inputs are passed positionally, matching get_cite_context_distance(pdf, doi).
    submit_btn.click(fn=get_cite_context_distance, inputs=[pdf_input, doi_input], outputs=output)

# Start the server only when this file is executed as a script. Importing the
# module (e.g. for tests) no longer launches it as a side effect.
if __name__ == "__main__":
    demo.launch()