Spaces:
Running
Running
import gradio as gr | |
import torch | |
import base64 | |
import fitz # PyMuPDF | |
import tempfile | |
from io import BytesIO | |
from PIL import Image | |
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration | |
from olmocr.data.renderpdf import render_pdf_to_base64png | |
from olmocr.prompts.anchor import get_anchor_text | |
from latex2mathml.converter import convert as latex_to_mathml | |
import markdown2 | |
import html | |
import json | |
import re | |
# Load model and processor.
# The compute device is chosen first; the olmOCR checkpoint is then loaded
# through the Qwen2-VL generation class in bfloat16, put into eval mode, and
# moved onto that device. The processor (tokenizer + image preprocessing) is
# taken from the base Qwen/Qwen2-VL-7B-Instruct repository.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "allenai/olmOCR-7B-0225-preview",
    torch_dtype=torch.bfloat16,
)
model.eval()
model.to(device)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
def convert_latex(text):
    r"""Replace LaTeX math fragments in *text* with MathML markup.

    Both inline math ``\( ... \)`` and display math ``\[ ... \]`` are converted
    with ``latex_to_mathml``. That function already returns a complete
    ``<math>`` element, so its output is inserted directly. It is not wrapped
    in a second ``<math>`` tag, which would produce invalid nested MathML.
    Display math may span line breaks. A fragment the converter rejects is
    kept verbatim (HTML-escaped), so one bad formula does not break the page.

    Args:
        text: Text (Markdown in this app) that may contain LaTeX fragments.

    Returns:
        The text with every convertible fragment replaced by MathML.
    """
    def replacer(match):
        try:
            return latex_to_mathml(match.group(1))
        except Exception:
            # Best effort: show the original LaTeX (escaped) instead of failing
            # the whole document on one unparsable formula.
            return html.escape(match.group(0))

    text = re.sub(r'\\\((.*?)\\\)', replacer, text)
    # DOTALL lets display math blocks span line breaks.
    text = re.sub(r'\\\[(.*?)\\\]', replacer, text, flags=re.DOTALL)
    return text
def stitch_paragraphs(pages):
    """Join page texts and collapse single line breaks into spaces.

    Pages are concatenated with a single newline. Every newline that is not
    part of a run of two or more (i.e. not a paragraph break) is then replaced
    by one space, merging wrapped lines and page boundaries into their
    paragraph. Runs of two or more newlines are kept untouched.

    Args:
        pages: Sequence of page texts.

    Returns:
        The stitched text ("" for no pages).
    """
    lone_newline = re.compile(r"(?<!\n)\n(?!\n)")
    combined = "\n".join(pages)
    return lone_newline.sub(" ", combined)
def _build_page_prompt(anchor_text):
    """Return the per-page instruction prompt embedding the pre-extracted anchor text."""
    return (
        "Below is the image of one page of a document, as well as some raw textual content that was previously "
        "extracted for it. Just return the plain text representation of this document as if you were reading it naturally.\n"
        "Do not hallucinate.\n"
        "RAW_TEXT_START\n"
        f"{anchor_text}\n"
        "RAW_TEXT_END"
    )


def _extract_natural_text(raw):
    """Return the page text contained in the model's decoded output *raw*.

    If *raw* parses as a JSON object, its ``natural_text`` field is used (or
    *raw* itself when the field is absent). A ``null`` ``natural_text`` maps
    to "" so the literal text "None" never lands in the document. Anything
    that is not a JSON object is returned unchanged.
    """
    try:
        parsed = json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return raw
    if not isinstance(parsed, dict):
        return raw
    natural = parsed.get("natural_text", raw)
    return "" if natural is None else str(natural)


def _ocr_page(pdf_path, page_num, image_base64):
    """Transcribe one page with the vision-language model.

    Args:
        pdf_path: Path of the PDF on disk (used to extract anchor text).
        page_num: 1-based page number.
        image_base64: Base64-encoded PNG rendering of the same page.

    Returns:
        The page text ("" when the model generates no new tokens).

    Raises:
        Whatever anchor extraction, the processor, or the model raise; the
        caller turns those into an inline error marker.
    """
    anchor_text = get_anchor_text(pdf_path, page_num, pdf_engine="pdfreport", target_length=4000)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": _build_page_prompt(anchor_text)},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
            ],
        }
    ]
    chat_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image = Image.open(BytesIO(base64.b64decode(image_base64)))
    inputs = processor(text=[chat_text], images=[image], padding=True, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    output = model.generate(
        **inputs,
        temperature=0.8,
        max_new_tokens=5096,  # NOTE(review): possibly intended as 4096; kept unchanged
        num_return_sequences=1,
        do_sample=True,
    )
    # The output includes the prompt tokens; keep only the newly generated ones.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = output[:, prompt_len:].detach().cpu()
    if new_tokens.shape[1] == 0:
        return ""
    raw = processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return _extract_natural_text(raw)


def _toc_headers_html(entries):
    """Render (level, title) TOC entries as escaped <h1>-<h6> tags, one per paragraph block."""
    tags = []
    for level, header in entries:
        tag = f"h{min(level, 6)}"
        tags.append(f"<{tag}>{html.escape(header)}</{tag}>")
    # Blank-line separation makes markdown2 treat each tag as its own HTML block.
    return "\n\n".join(tags)


def _wrap_html_document(title, author, cover_img_html, body_html):
    """Wrap rendered body HTML in a standalone page with escaped title and author."""
    return f"""<!DOCTYPE html>
<html>
<head>
<meta charset='utf-8'>
<title>{html.escape(title)}</title>
</head>
<body>
<h1>{html.escape(title)}</h1>
<h3>{html.escape(author)}</h3>
{cover_img_html}
{body_html}
</body>
</html>
"""


def process_pdf_to_html(pdf_file, title, author):
    """Convert an uploaded PDF into a single self-contained HTML file.

    Each page is rendered to PNG, paired with anchor text extracted from the
    PDF, and transcribed by the vision-language model. A page that fails is
    replaced by an inline "[Error on page N: ...]" marker instead of aborting
    the conversion. PDF outline (TOC) entries become <h1>-<h6> headers placed
    before their page's text. Page texts are stitched across page boundaries,
    LaTeX math is converted to MathML, and the result is rendered from
    Markdown to HTML. Page 1's rendering is embedded as a cover image.

    Args:
        pdf_file: The Gradio upload. Depending on the Gradio version this is
            either an object exposing the file path as ``.name`` or a plain
            path string; both are accepted.
        title: Document title (also used for <title>); None is treated as "".
        author: Author line shown under the title; None is treated as "".

    Returns:
        Path of the generated .html file (written under /tmp).

    Raises:
        gr.Error: if no file was uploaded.
    """
    if pdf_file is None:
        raise gr.Error("Please upload a PDF file.")
    pdf_path = getattr(pdf_file, "name", pdf_file)
    title = title or ""
    author = author or ""

    # Only the page count and outline are needed from PyMuPDF; close the
    # document right away instead of leaking the handle.
    with fitz.open(pdf_path) as doc:
        num_pages = len(doc)
        toc_by_page = {}
        for level, text, page in doc.get_toc():
            toc_by_page.setdefault(page, []).append((level, text))

    pages_output = []
    cover_img_html = ""
    for page_num in range(1, num_pages + 1):
        print(f"Processing page {page_num}...")
        image_base64 = None  # stays None if rendering itself fails
        try:
            image_base64 = render_pdf_to_base64png(pdf_path, page_num, target_longest_image_dim=1024)
            decoded = _ocr_page(pdf_path, page_num, image_base64)
        except Exception as e:
            # Best effort: one bad page must not abort the whole document.
            decoded = f"[Error on page {page_num}: {str(e)}]"

        # Save first image as cover (only if page 1 actually rendered).
        if page_num == 1 and image_base64 is not None:
            cover_img_html = f'<img src="data:image/png;base64,{image_base64}" alt="cover" style="max-width:100%; height:auto;"><hr>'

        headers = _toc_headers_html(toc_by_page.get(page_num, []))
        if headers:
            # Pages are joined with a single "\n" below. The extra leading "\n"
            # yields a blank line before the headers so they are not stitched
            # into the previous page's paragraph.
            prefix = "\n" if pages_output else ""
            pages_output.append(f"{prefix}{headers}\n\n{decoded}")
        else:
            # No leading newline, so text flowing across the page boundary is
            # stitched into one paragraph.
            pages_output.append(decoded)

    stitched = stitch_paragraphs(pages_output)
    rendered = markdown2.markdown(convert_latex(stitched))
    html_doc = _wrap_html_document(title, author, cover_img_html, rendered)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".html", dir="/tmp", mode="w", encoding="utf-8") as tmp:
        tmp.write(html_doc)
    return tmp.name
# Gradio front end: a PDF upload plus title/author text fields go in, and the
# path of the generated HTML file comes back as a downloadable file.
iface = gr.Interface(
    fn=process_pdf_to_html,
    inputs=[
        gr.File(label="Upload PDF", file_types=[".pdf"]),
        gr.Textbox(label="HTML Title"),
        gr.Textbox(label="Author(s)"),
    ],
    outputs=gr.File(label="Download HTML"),
    title="PDF to HTML Converter (Refined with olmOCR)",
    description=(
        "Uploads a PDF, extracts text via vision+prompt, stitches paragraphs, "
        "adds headers, and converts math and markdown to styled HTML."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    # Listen on all interfaces on port 7860. allowed_paths lets Gradio serve
    # the result files, which process_pdf_to_html writes under /tmp.
    # NOTE(review): share links are not supported on Hugging Face Spaces;
    # share=True only matters when the script runs elsewhere.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        debug=True,
        allowed_paths=["/tmp"],
    )
    iface.launch(**launch_options)