olmocr-demo / app.py
leonarb's picture
Fixes math/headers/tables/etc...
2a16ca6 verified
raw
history blame
5.63 kB
import gradio as gr
import torch
import base64
import fitz # PyMuPDF
import tempfile
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from latex2mathml.converter import convert as latex_to_mathml
import markdown2
import html
import json
import re
# Model setup, done once at import time.
# The weights come from the olmOCR checkpoint, while the processor (tokenizer +
# image preprocessing) comes from the base Qwen2-VL-7B-Instruct repo —
# presumably because olmOCR is a fine-tune of that model; confirm if upgrading.
_OLMOCR_CHECKPOINT = "allenai/olmOCR-7B-0225-preview"
_PROCESSOR_CHECKPOINT = "Qwen/Qwen2-VL-7B-Instruct"

model = Qwen2VLForConditionalGeneration.from_pretrained(
    _OLMOCR_CHECKPOINT, torch_dtype=torch.bfloat16
)
model.eval()  # inference only
processor = AutoProcessor.from_pretrained(_PROCESSOR_CHECKPOINT)

# Prefer the GPU when one is visible; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
def convert_latex(text):
    r"""Replace LaTeX math spans in *text* with MathML markup.

    Inline math ``\( ... \)`` becomes an inline ``<math>`` element and display
    math ``\[ ... \]`` a ``display="block"`` one. ``latex_to_mathml`` already
    returns a complete ``<math ...>...</math>`` element, so its output is used
    as-is (wrapping it again would produce invalid nested ``<math>`` tags).

    A span that fails to convert is left in place as HTML-escaped LaTeX, so one
    bad formula never aborts the whole document.

    Args:
        text: Plain/Markdown text that may contain LaTeX math delimiters.

    Returns:
        The text with every convertible math span replaced by MathML.
    """
    def _to_mathml(match, display):
        try:
            return latex_to_mathml(match.group(1), display=display)
        except Exception:
            # Best effort: keep the original source visible rather than failing.
            return html.escape(match.group(0))

    # Inline math stays within one line; display math may span several lines.
    text = re.sub(r"\\\((.*?)\\\)", lambda m: _to_mathml(m, "inline"), text)
    text = re.sub(
        r"\\\[(.*?)\\\]", lambda m: _to_mathml(m, "block"), text, flags=re.DOTALL
    )
    return text
# Lines that open a Markdown/HTML block element. These are never merged with
# their neighbours; otherwise tables, lists and headings collapse into one
# run-on paragraph.
_STRUCTURAL_LINE_RE = re.compile(
    r"^\s*(?:"
    r"#{1,6}(?:\s|$)"        # ATX heading: "# Title"
    r"|[-*+]\s"              # bullet list item
    r"|\d{1,9}[.)]\s"        # ordered list item: "1. ", "2) "
    r"|>"                    # blockquote
    r"|\|"                   # pipe-table row (incl. "|---|" separator)
    r"|<[A-Za-z/!]"          # raw HTML line (e.g. injected TOC headings)
    r"|(?:[-*_=]\s*){3,}$"   # horizontal rule / setext heading underline
    r")"
)


def stitch_paragraphs(pages):
    """Join per-page OCR text and re-flow hard-wrapped prose lines.

    Pages are concatenated with a single newline, so a paragraph that runs over
    a page break can be stitched back together. In the combined text, two
    consecutive non-blank prose lines are joined with a single space. Blank or
    whitespace-only lines still separate paragraphs.

    Lines that begin a Markdown/HTML block are always kept on their own line:
    headings, list items, blockquotes, table rows, raw HTML and rules. Every
    line inside a ``` / ~~~ fence is kept as well.

    Args:
        pages: Iterable of page texts, in document order.

    Returns:
        The stitched text.
    """
    output_lines = []
    in_fence = False
    prev_is_prose = False
    for line in "\n".join(pages).split("\n"):
        stripped = line.strip()
        is_fence = stripped.startswith(("```", "~~~"))
        is_prose = bool(stripped) and not (
            in_fence or is_fence or _STRUCTURAL_LINE_RE.match(line)
        )
        if is_prose and prev_is_prose:
            # Hard wrap inside a paragraph: re-join with a space.
            output_lines[-1] = f"{output_lines[-1]} {line}"
        else:
            output_lines.append(line)
        if is_fence:
            in_fence = not in_fence
        prev_is_prose = is_prose
    return "\n".join(output_lines)
def _parse_model_output(raw):
    """Return the page text contained in an olmOCR reply string.

    If *raw* is a JSON object with a ``natural_text`` key, that field is
    returned stripped. A JSON ``null`` there (presumably a page with no text)
    yields "". Anything else — invalid JSON, a non-object, or a missing key —
    is returned unchanged.
    """
    try:
        parsed = json.loads(raw)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        return raw
    if not isinstance(parsed, dict) or "natural_text" not in parsed:
        return raw
    natural_text = parsed["natural_text"]
    if natural_text is None:
        return ""
    return str(natural_text).strip()


def _ocr_page(pdf_path, page_num, image_base64):
    """Transcribe one page with the olmOCR model and return its text.

    *image_base64* is the base64 PNG rendering of the same page. The PDF's own
    text layer is passed to the model as "anchor" text next to the image.
    Returns "" when the model generates no new tokens.
    """
    anchor_text = get_anchor_text(pdf_path, page_num, pdf_engine="pdfreport", target_length=4000)
    prompt = (
        "Below is the image of one page of a document, as well as some raw textual content that was previously "
        "extracted for it. Just return the plain text representation of this document as if you were reading it naturally.\n"
        "Do not hallucinate.\n"
        "RAW_TEXT_START\n"
        f"{anchor_text}\n"
        "RAW_TEXT_END"
    )
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
            ],
        }
    ]
    chat_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image = Image.open(BytesIO(base64.b64decode(image_base64)))
    inputs = processor(text=[chat_text], images=[image], padding=True, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    output = model.generate(
        **inputs,
        temperature=0.8,
        max_new_tokens=5096,  # NOTE(review): possibly meant 4096; kept as-is
        num_return_sequences=1,
        do_sample=True,
    )
    # The output sequence starts with the prompt tokens; keep only the completion.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = output[:, prompt_len:].detach().cpu()
    if new_tokens.shape[1] == 0:
        return ""
    raw = processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return _parse_model_output(raw)


def _toc_headers_html(entries):
    """Render (level, title) TOC entries as <h1>..<h6> lines, one per entry."""
    lines = []
    for level, header in entries:
        tag = f"h{min(max(level, 1), 6)}"
        lines.append(f"<{tag}>{html.escape(header)}</{tag}>")
    return "\n".join(lines)


def _build_html_document(title, author, cover_img_html, rendered):
    """Wrap the rendered body in a standalone HTML page with title/author headings."""
    safe_title = html.escape(title or "")
    safe_author = html.escape(author or "")
    return "\n".join([
        "<!DOCTYPE html>",
        "<html>",
        "<head>",
        "<meta charset='utf-8'>",
        f"<title>{safe_title}</title>",
        "</head>",
        "<body>",
        f"<h1>{safe_title}</h1>",
        f"<h3>{safe_author}</h3>",
        cover_img_html,
        rendered,
        "</body>",
        "</html>",
        "",
    ])


def process_pdf_to_html(pdf_file, title, author):
    """Convert an uploaded PDF into a standalone HTML file and return its path.

    Processing steps:
    - Each page is rendered to PNG and transcribed by the olmOCR model.
    - The PDF outline (TOC) entries become <h1>-<h6> headings placed before
      the text of the page they point to.
    - The first page's image is embedded as a cover.
    - The stitched text is converted from LaTeX/Markdown to HTML.

    A page that fails to render or transcribe does not abort the run. That
    page is replaced by an "[Error on page N: ...]" paragraph.

    Args:
        pdf_file: The uploaded file. This may be an object exposing ``.name``
            or a plain path string, depending on the installed Gradio version.
        title: Used for <title> and the top <h1>; may be empty or None.
        author: Shown under the title; may be empty or None.

    Returns:
        Path of the written ``.html`` file in /tmp.

    Raises:
        gr.Error: If no file was uploaded.
    """
    if pdf_file is None:
        raise gr.Error("Please upload a PDF file.")
    pdf_path = getattr(pdf_file, "name", pdf_file)

    # Only the page count and outline are needed from PyMuPDF; close it promptly.
    with fitz.open(pdf_path) as doc:
        num_pages = len(doc)
        toc_entries = doc.get_toc()
    toc_by_page = {}
    for level, text, page in toc_entries:
        toc_by_page.setdefault(page, []).append((level, text))

    pages_output = []
    cover_img_html = ""
    for page_num in range(1, num_pages + 1):
        print(f"Processing page {page_num}...")
        image_base64 = None  # stays None if rendering itself fails
        try:
            image_base64 = render_pdf_to_base64png(pdf_path, page_num, target_longest_image_dim=1024)
            page_text = _ocr_page(pdf_path, page_num, image_base64)
        except Exception as e:
            # Best effort: one bad page must not abort the whole document.
            print(f"Error on page {page_num}: {e}")
            marker = f"[Error on page {page_num}: {e}]"
            # Surrounding blank lines keep the marker out of neighbouring paragraphs.
            page_text = f"\n{html.escape(marker)}\n"

        if page_num == 1 and image_base64:
            cover_img_html = f'<img src="data:image/png;base64,{image_base64}" alt="cover" style="max-width:100%; height:auto;"><hr>'

        header_html = _toc_headers_html(toc_by_page.get(page_num, []))
        if header_html:
            # Blank lines before/after make the headings a standalone HTML block
            # for markdown2 and keep them from being merged into prose.
            page_text = f"\n{header_html}\n\n{page_text}"
        # Pages without headings are joined by a single newline, so
        # stitch_paragraphs can re-join a paragraph split across a page break.
        pages_output.append(page_text)

    stitched = stitch_paragraphs(pages_output)
    mathml = convert_latex(stitched)
    # "tables" enables pipe-table syntax, which olmOCR output may contain.
    rendered = markdown2.markdown(mathml, extras=["tables"])
    html_doc = _build_html_document(title, author, cover_img_html, rendered)

    with tempfile.NamedTemporaryFile(delete=False, suffix=".html", dir="/tmp", mode="w", encoding="utf-8") as tmp:
        tmp.write(html_doc)
    return tmp.name
# Gradio front end: a PDF upload plus free-text title/author go in, and a
# downloadable HTML file comes out. Components are created in the same order
# as they appear in the UI.
_pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
_title_box = gr.Textbox(label="HTML Title")
_author_box = gr.Textbox(label="Author(s)")
_html_download = gr.File(label="Download HTML")

iface = gr.Interface(
    fn=process_pdf_to_html,
    inputs=[_pdf_upload, _title_box, _author_box],
    outputs=_html_download,
    title="PDF to HTML Converter (Refined with olmOCR)",
    description="Uploads a PDF, extracts text via vision+prompt, stitches paragraphs, adds headers, and converts math and markdown to styled HTML.",
    allow_flagging="never",
)
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, with a share link and debug mode.
    # allowed_paths whitelists /tmp, the directory process_pdf_to_html writes
    # its HTML output to, so the download can be served.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "debug": True,
        "allowed_paths": ["/tmp"],
    }
    iface.launch(**launch_kwargs)