OCRArena / app.py
AkashDataScience's picture
Execution changes
1402288
raw
history blame
5.07 kB
from PyPDF2 import PdfReader
import gradio as gr
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.datamodel.base_models import InputFormat
from paddleocr import PPStructureV3
from pdf2image import convert_from_path
import numpy as np
import torch
from docling_core.types.doc import DoclingDocument
from docling_core.types.doc.document import DocTagsDocument
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image
from pathlib import Path
# Run the SmolDocling model on the GPU when torch can see one, else on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Docling converter: PDFs go through a PDF pipeline created with
# enable_remote_services=True.
pipeline_options = PdfPipelineOptions(enable_remote_services=True)
_pdf_format = PdfFormatOption(pipeline_options=pipeline_options)
converter = DocumentConverter(format_options={InputFormat.PDF: _pdf_format})

# PaddleOCR PP-StructureV3 pipeline, constructed with its default arguments.
pipeline = PPStructureV3()

# SmolDocling processor + model loaded from the same Hub checkpoint, in bfloat16.
# FlashAttention-2 is only requested on CUDA; the CPU path uses eager attention.
_SMOLDOCLING_REPO = "ds4sd/SmolDocling-256M-preview"
_SMOLDOCLING_ATTN = "flash_attention_2" if DEVICE == "cuda" else "eager"
processor = AutoProcessor.from_pretrained(_SMOLDOCLING_REPO)
model = AutoModelForVision2Seq.from_pretrained(
    _SMOLDOCLING_REPO,
    torch_dtype=torch.bfloat16,
    _attn_implementation=_SMOLDOCLING_ATTN,
).to(DEVICE)
def get_pdf_page_count(pdf_path):
    """Return how many pages the PDF at ``pdf_path`` has, as reported by PyPDF2.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        The length of the reader's ``pages`` collection.
    """
    document = PdfReader(pdf_path)
    page_total = len(document.pages)
    return page_total
def get_page_image(pdf_path, page_num):
    """Rasterise one page of a PDF with pdf2image.

    Args:
        pdf_path: Path to the PDF file.
        page_num: Number of the page to render (the UI slider starts at 1).

    Returns:
        The first image pdf2image produces for the one-page range
        ``page_num``..``page_num``.
    """
    page_window = {"first_page": page_num, "last_page": page_num}
    rendered_pages = convert_from_path(pdf_path, **page_window)
    return rendered_pages[0]
def get_docling_ocr(pdf_path, page_num):
    """Convert a single PDF page with Docling and return it as Markdown.

    Args:
        pdf_path: Path to the PDF file.
        page_num: Page to convert; handed to Docling as the page range
            ``(page_num, page_num)``.

    Returns:
        The converted document exported via ``export_to_markdown()``.
    """
    single_page = (page_num, page_num)
    conversion = converter.convert(pdf_path, page_range=single_page)
    return conversion.document.export_to_markdown()
def get_paddle_ocr(pdf_path, page_num):
    """Run PaddleOCR's PP-StructureV3 pipeline on one PDF page and return Markdown.

    Args:
        pdf_path: Path to the PDF file.
        page_num: Page to process (the UI slider starts at 1).

    Returns:
        The Markdown produced by ``pipeline.concatenate_markdown_pages`` over the
        ``markdown`` of every result the pipeline yields for the page.
    """
    # Force a 3-channel RGB image so the channel flip below is always valid.
    page_image = get_page_image(pdf_path, page_num).convert("RGB")
    # PIL arrays are RGB, but PaddleOCR/PaddleX treat ndarray input as
    # OpenCV-style BGR. Reverse the channel axis so colours are not swapped.
    # Copy to a contiguous buffer because the reversed view has a negative stride.
    bgr_page = np.ascontiguousarray(np.asarray(page_image)[:, :, ::-1])
    output = pipeline.predict(input=bgr_page)
    markdown_list = [res.markdown for res in output]
    return pipeline.concatenate_markdown_pages(markdown_list)
def get_smoldocling_ocr(pdf_path, page_num):
    """Transcribe one PDF page with the SmolDocling model and return Markdown.

    The page is rasterised and sent to the model together with the instruction
    "Convert this page to docling.". The generated DocTags text is rebuilt into
    a ``DoclingDocument`` and exported as Markdown.

    Args:
        pdf_path: Path to the PDF file.
        page_num: Page to transcribe (the UI slider starts at 1).

    Returns:
        Markdown exported from the reconstructed document.
    """
    rgb_page = load_image(get_page_image(pdf_path, page_num))

    # Single user turn: an image placeholder followed by the instruction.
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Convert this page to docling."},
            ],
        },
    ]
    chat_prompt = processor.apply_chat_template(chat, add_generation_prompt=True)
    model_inputs = processor(
        text=chat_prompt, images=[rgb_page], return_tensors="pt"
    ).to(DEVICE)

    output_ids = model.generate(**model_inputs, max_new_tokens=8192)
    # Drop the echoed prompt tokens so only the newly generated ones are decoded.
    n_prompt_tokens = model_inputs.input_ids.shape[1]
    new_token_ids = output_ids[:, n_prompt_tokens:]

    # Special tokens are kept in the decoded text (presumably the DocTags markup
    # depends on them).
    decoded_batch = processor.batch_decode(new_token_ids, skip_special_tokens=False)
    doctags = decoded_batch[0].lstrip()

    tagged_pages = DocTagsDocument.from_doctags_and_image_pairs([doctags], [rgb_page])
    document = DoclingDocument.load_from_doctags(tagged_pages, document_name="Document")
    return document.export_to_markdown()
# Header text rendered at the top of the app.
title = "OCR Arena"
description = "A simple Gradio interface to extract text from PDFs and compare OCR models"

# One row per sample; each row fills the single ``pdf`` input of gr.Examples.
examples = [
    [sample_pdf]
    for sample_pdf in ("data/amazon-10-k-2024.pdf", "data/goog-10-k-2023.pdf")
]
with gr.Blocks(theme=gr.themes.Glass()) as demo:
    gr.Markdown(f"# {title}\n{description}")
    with gr.Row():
        # Left column: PDF upload plus the dynamically rendered page picker.
        with gr.Column():
            pdf = gr.File(label="Input PDFs", file_types=[".pdf"])

            @gr.render(inputs=pdf)
            def show_slider(pdf_path):
                """Rebuild the page picker and action buttons when the upload changes.

                Args:
                    pdf_path: Current value of the ``pdf`` File component, or
                        ``None`` when no file is loaded.
                """
                if pdf_path is None:
                    # Render no Submit button without a file: clicking it would
                    # pass None and this Markdown's text (as the page number)
                    # into get_page_image, which then fails.
                    gr.Markdown("## No Input Provided")
                    return

                page_count = get_pdf_page_count(pdf_path)
                page_num = gr.Slider(1, page_count, value=1, step=1, label="Page Number")
                with gr.Row():
                    gr.ClearButton(components=[pdf, page_num])
                    submit_btn = gr.Button("Submit", variant="primary")

                # Show the page image first, then run the OCR engines one after
                # another so each output box fills as soon as its step finishes.
                # NOTE: the output components are created in the right-hand
                # column below; referencing them here relies on this render
                # function running at event time, after the layout is built.
                submit_btn.click(
                    get_page_image, inputs=[pdf, page_num], outputs=original
                ).then(
                    get_docling_ocr, inputs=[pdf, page_num], outputs=docling_ocr_out
                ).then(
                    get_paddle_ocr, inputs=[pdf, page_num], outputs=paddle_ocr_out
                ).then(
                    get_smoldocling_ocr,
                    inputs=[pdf, page_num],
                    outputs=smoldocling_ocr_out,
                )

        # Right column: the original page and one read-only box per OCR engine.
        with gr.Column():
            original = gr.Image(width=640, height=640, label="Original Page", interactive=False)
            docling_ocr_out = gr.Textbox(label="Docling OCR Output", type="text", interactive=False)
            paddle_ocr_out = gr.Textbox(label="Paddle OCR Output", type="text", interactive=False)
            smoldocling_ocr_out = gr.Textbox(label="SmolDocling OCR Output", type="text", interactive=False)

    examples_obj = gr.Examples(examples=examples, inputs=[pdf])

demo.launch()