# olmocr-demo / app.py
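"""Gradio demo for olmOCR: render one page of an uploaded PDF to an image and
extract its text with the allenai/olmOCR-7B-0225-preview model."""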
import torch
import base64
from io import BytesIO
from PIL import Image
import gradio as gr
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
# Load processor and model
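# olmOCR-7B-0225-preview is a fine-tune of Qwen2-VL-7B-Instruct, so the base
# model's processor is used for image preprocessing and chat templating.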
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "allenai/olmOCR-7B-0225-preview", torch_dtype=torch.bfloat16
).eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def process_pdf(file, page=1):
    # Gradio passes the upload as a temp-file object (older versions) or a path string
    file_path = file.name if hasattr(file, "name") else file
    page = int(page)

    # Render the selected PDF page to a base64-encoded PNG
    image_base64 = render_pdf_to_base64png(file_path, page, target_longest_image_dim=1024)
    main_image = Image.open(BytesIO(base64.b64decode(image_base64)))

    # Extract anchor text from the PDF and build the olmOCR prompt
    anchor_text = get_anchor_text(file_path, page, pdf_engine="pdfreport", target_length=4000)
    prompt = build_finetuning_prompt(anchor_text)
    # Construct chat message
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
            ],
        }
    ]
    # Tokenize inputs
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[main_image], return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Run model
    with torch.no_grad():
        output = model.generate(
            **inputs,
            temperature=0.8,
            max_new_tokens=256,
            num_return_sequences=1,
            do_sample=True,
        )
    # Decode only the newly generated tokens, skipping the prompt
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = output[:, prompt_len:]
    decoded = processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    return decoded[0]
# Gradio interface
iface = gr.Interface(
    fn=process_pdf,
    inputs=[
        gr.File(label="Upload PDF"),
        gr.Number(value=1, label="Page Number"),
    ],
    outputs="text",
    title="olmOCR PDF Text Extractor",
    description="Upload a PDF and select a page to extract text using the olmOCR model.",
)
if __name__ == "__main__":
    iface.launch()
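# Usage (a sketch, assuming the olmocr package and its poppler dependency are
# installed; a GPU is optional but recommended for the 7B model):
#   python app.py
# then open the local Gradio URL printed to the console, upload a PDF, and
# pick a page number to extract.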