File size: 2,509 Bytes
19918ea
8be5494
 
d45f3e7
8be5494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19918ea
8be5494
 
 
 
 
 
 
 
 
 
19918ea
8be5494
 
 
 
d45f3e7
8be5494
 
 
 
 
 
 
 
 
d45f3e7
8be5494
 
 
 
 
d45f3e7
8be5494
 
d45f3e7
8be5494
 
 
 
d45f3e7
8be5494
 
d45f3e7
 
 
8be5494
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import base64
from io import BytesIO
from PIL import Image
import gradio as gr

from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text

# The processor (chat template, tokenizer and image preprocessing) comes from the
# base Qwen2-VL-7B-Instruct checkpoint; the weights are the olmOCR fine-tune.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "allenai/olmOCR-7B-0225-preview",
    torch_dtype=torch.bfloat16,
)
model.eval()  # inference only: put the network in evaluation mode

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

def process_pdf(file, page=1):
    """Run olmOCR on one page of an uploaded PDF and return the model's text.

    Args:
        file: Value of the ``gr.File`` input. Depending on the Gradio version
            this is either a filepath string or a tempfile wrapper whose
            ``.name`` is the path of the uploaded file; it is ``None`` when the
            user submits without uploading anything.
        page: 1-based page number. ``gr.Number`` yields floats (e.g. ``1.0``)
            unless ``precision=0`` is set, so whole-number floats are accepted
            and converted to ``int`` before use.

    Returns:
        The decoded text of the newly generated tokens (the prompt is sliced
        off before decoding).

    Raises:
        gr.Error: If no file was uploaded or ``page`` is not a positive whole
            number; Gradio displays the message in the UI.
    """
    if file is None:
        raise gr.Error("Please upload a PDF file.")
    # Newer Gradio passes a path string; older versions pass an object whose
    # ``.name`` holds the path of the uploaded temp file.
    file_path = file if isinstance(file, str) else file.name

    try:
        page_num = int(page)
    except (TypeError, ValueError, OverflowError):
        raise gr.Error(f"Page number must be a positive integer, got {page!r}.") from None
    # Reject fractional values (e.g. 2.5) rather than silently truncating them.
    if page_num < 1 or page_num != page:
        raise gr.Error(f"Page number must be a positive integer, got {page!r}.")

    # Render the selected page to a base64 PNG (longest side targeted at 1024 px)
    # and decode it into a PIL image for the processor.
    image_base64 = render_pdf_to_base64png(file_path, page_num, target_longest_image_dim=1024)
    main_image = Image.open(BytesIO(base64.b64decode(image_base64)))

    # Extract anchor text for the page and wrap it in the olmOCR prompt.
    anchor_text = get_anchor_text(file_path, page_num, pdf_engine="pdfreport", target_length=4000)
    prompt = build_finetuning_prompt(anchor_text)

    # Single-turn chat: the prompt text plus the rendered page image.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
            ],
        }
    ]

    # Build the templated prompt string, then tokenize text and image together.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[main_image], return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Sampled generation; NOTE(review): 256 new tokens may truncate dense pages.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            temperature=0.8,
            max_new_tokens=256,
            num_return_sequences=1,
            do_sample=True,
        )

    # generate() returns prompt + completion; keep only the new tokens.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = output[:, prompt_len:]
    decoded = processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    return decoded[0]

# Gradio interface: a PDF upload plus a 1-based page selector, plain-text output.
iface = gr.Interface(
    fn=process_pdf,
    inputs=[
        # Only offer PDFs in the file picker; the pipeline renders PDF pages.
        gr.File(label="Upload PDF", file_types=[".pdf"]),
        # precision=0 makes Gradio deliver an int rather than a float like 1.0,
        # since the value is passed through as a page number.
        gr.Number(value=1, label="Page Number", precision=0),
    ],
    outputs="text",
    title="olmOCR PDF Text Extractor",
    description="Upload a PDF and select a page to extract text using the olmOCR model.",
)

if __name__ == "__main__":
    iface.launch()