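# olmocr-demo / app.py
# Hugging Face file-view header (page residue, not part of the program):
#   avatar caption: leonarb's picture
#   commit: "Update app.py" (822eba7, verified)
#   page links: raw | history | blame
#   file size: 5.06 kB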
import base64
import html
import tempfile
import uuid
from io import BytesIO
from pathlib import Path

import fitz  # PyMuPDF
import gradio as gr
import torch
from ebooklib import epub
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
# Model / processor setup. Runs once, at import time.

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# olmOCR weights, loaded in bfloat16 and switched to eval mode.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "allenai/olmOCR-7B-0225-preview",
    torch_dtype=torch.bfloat16,
)
model.eval()

# The processor is taken from the Qwen2-VL instruct checkpoint, not from the
# olmOCR repository.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

model.to(device)
def _build_prompt(anchor_text):
    """Return the instruction prompt with the page's raw extracted text embedded."""
    return (
        "Below is the image of one page of a document, as well as some raw textual content that was previously "
        "extracted for it. Just return the plain text representation of this document as if you were reading it naturally.\n"
        "Do not hallucinate.\n"
        "RAW_TEXT_START\n"
        f"{anchor_text}\n"
        "RAW_TEXT_END"
    )


def _transcribe_page(pdf_path, page_num, image_base64):
    """Run the vision model on one rendered page and return the decoded text.

    Args:
        pdf_path: Path of the PDF on disk.
        page_num: 1-based page number.
        image_base64: Base64-encoded PNG rendering of that page.

    Returns:
        The stripped model output, or a bracketed placeholder string when the
        model produced no new tokens or decoding failed.
    """
    anchor_text = get_anchor_text(pdf_path, page_num, pdf_engine="pdfreport", target_length=4000)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": _build_prompt(anchor_text)},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image = Image.open(BytesIO(base64.b64decode(image_base64)))
    inputs = processor(
        text=[text],
        images=[image],
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    output = model.generate(
        **inputs,
        temperature=0.8,
        max_new_tokens=512,
        num_return_sequences=1,
        do_sample=True,
    )

    # Drop the first `prompt_length` positions so only newly generated tokens
    # are decoded.
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = output[:, prompt_length:].detach().cpu()
    if new_tokens.shape[1] == 0:
        return "[Model returned no new tokens]"

    try:
        decoded_list = processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    except Exception as decode_error:
        return f"[Decoding error on page {page_num}: {str(decode_error)}]"
    return decoded_list[0].strip() if decoded_list else "[No output generated]"


def process_pdf_to_epub(pdf_file, title, author):
    """Convert an uploaded PDF into an EPUB with one chapter per page.

    Each page is rendered to a PNG, transcribed by the model, and stored as
    its own XHTML chapter. The rendering of page 1 becomes the cover image.
    A page that fails is kept in the book with a bracketed error message as
    its text, so one bad page does not abort the whole conversion.

    Args:
        pdf_file: The uploaded file as handed over by the Gradio File input;
            either a path string or an object exposing the path as ``.name``.
        title: EPUB title. When blank, the PDF's file stem is used.
        author: Author string. When blank, no author entry is added.

    Returns:
        Path of the written ``.epub`` file under ``/tmp``.

    Raises:
        gr.Error: If no file was uploaded.
    """
    if pdf_file is None:
        raise gr.Error("Please upload a PDF file.")

    # NOTE(review): which form arrives presumably depends on the Gradio
    # version; accept both a plain path and a wrapper with `.name`.
    pdf_path = getattr(pdf_file, "name", pdf_file)

    # Only the page count is needed, so close the document right away.
    with fitz.open(pdf_path) as doc:
        num_pages = len(doc)

    title = (title or "").strip()
    author = (author or "").strip()

    book = epub.EpubBook()
    # A fresh identifier per book; a constant one makes every output share an ID.
    book.set_identifier(f"urn:uuid:{uuid.uuid4()}")
    book.set_title(title or Path(pdf_path).stem)
    if author:
        book.add_author(author)

    chapters = []
    for page_num in range(1, num_pages + 1):
        print(f"Processing page {page_num}...")

        # Reset per page so a failed render can never leave this unbound or
        # carry over the previous page's image.
        image_base64 = None
        try:
            image_base64 = render_pdf_to_base64png(pdf_path, page_num, target_longest_image_dim=1024)
            decoded = _transcribe_page(pdf_path, page_num, image_base64)
        except Exception as processing_error:
            # Deliberately broad: this is a per-page best-effort boundary.
            decoded = f"[Processing error on page {page_num}: {str(processing_error)}]"
        print(f"Decoded content for page {page_num}: {decoded}")

        # The model output is plain text: escape it so characters such as
        # '<' and '&' cannot break the XHTML, and keep its line breaks visible.
        body = html.escape(decoded).replace("\n", "<br/>")
        chapter = epub.EpubHtml(title=f"Page {page_num}", file_name=f"page_{page_num}.xhtml", lang="en")
        chapter.content = f"<h1>Page {page_num}</h1><p>{body}</p>"
        book.add_item(chapter)
        chapters.append(chapter)

        # The rendered page is already PNG data, so it is used as the cover
        # directly. Skipped when page 1 could not be rendered.
        if page_num == 1 and image_base64 is not None:
            book.set_cover("cover.png", base64.b64decode(image_base64))

    book.toc = tuple(chapters)
    book.add_item(epub.EpubNcx())
    book.add_item(epub.EpubNav())
    book.spine = ['nav'] + chapters

    # Reserve a unique name under /tmp, close the handle, then write the EPUB
    # to that path so nothing else holds the file open during the write.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".epub", dir="/tmp") as tmp:
        epub_path = tmp.name
    epub.write_epub(epub_path, book)
    return epub_path
# Gradio UI definition: three inputs (PDF upload, title, author) are passed to
# process_pdf_to_epub, and its return value is shown in a File output.
iface = gr.Interface(
    title="PDF to EPUB Converter (with olmOCR)",
    description="Uploads a PDF, extracts text from each page with vision + prompt, and builds an EPUB using the outputs. Sets the first page as cover.",
    fn=process_pdf_to_epub,
    # Input order must match the function's (pdf_file, title, author) parameters.
    inputs=[
        gr.File(file_types=[".pdf"], label="Upload PDF"),
        gr.Textbox(label="EPUB Title"),
        gr.Textbox(label="Author(s)"),
    ],
    outputs=gr.File(label="Download EPUB"),
    allow_flagging="never",
)
# Script entry point: start the Gradio server only when run directly.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "debug": True,
        # process_pdf_to_epub writes its EPUB output under /tmp.
        "allowed_paths": ["/tmp"],
    }
    iface.launch(**launch_options)