Uddipan Basu Bir
Download checkpoint from HF hub in OcrReorderPipeline
4617aca
raw
history blame
2.26 kB
import os
import json
import base64
from io import BytesIO
from PIL import Image
import gradio as gr
from inference import OcrReorderPipeline
from transformers import AutoProcessor, LayoutLMv3Model, AutoTokenizer
# ── 1) Load model + tokenizer + processor ─────────────────────────
# Everything in this section runs once, at import time, and binds the
# module-level objects that `infer` relies on.

# Identifier handed to every `from_pretrained` call below
# (presumably a Hugging Face Hub repo id — verify).
repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"
# The LayoutLMv3 model is loaded from the repository root.
model = LayoutLMv3Model.from_pretrained(repo)
# Tokenizer and processor are both read from the "preprocessor" subfolder.
tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="preprocessor")
# apply_ocr=False: `infer` passes words and boxes taken from the uploaded
# JSON rather than asking the processor to run OCR.
processor = AutoProcessor.from_pretrained(repo, subfolder="preprocessor", apply_ocr=False)
# NOTE(review): device=0 presumably selects the first CUDA device — confirm
# how OcrReorderPipeline interprets it and whether CPU-only hosts are supported.
pipe = OcrReorderPipeline(model, tokenizer, processor, device=0)
# ── 2) Inference function ──────────────────────────────────────────
def infer(image_path, json_file):
img_name = os.path.basename(image_path)
# Parse NDJSON entries from uploaded file
data = []
with open(json_file.name, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
data.append(json.loads(line))
# Find matching entry for this image
entry = next((e for e in data if e["img_name"] == img_name), None)
if entry is None:
return f"❌ No JSON entry found for image '{img_name}'"
words = entry["src_word_list"]
boxes = entry["src_wordbox_list"]
# Read and encode image to base64
img = Image.open(image_path).convert("RGB")
buf = BytesIO()
img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode()
# Call pipeline with `inputs` keyword plus extra args
reordered = pipe(inputs=b64, words=words, boxes=boxes)[0]
return reordered
# ── 3) Gradio interface ─────────────────────────────────────────────
# Two upload widgets feed `infer`; whatever it returns is rendered as text.
demo = gr.Interface(
    title="OCR Reorder Pipeline",
    fn=infer,
    inputs=[
        gr.Image(label="Upload Image", type="filepath"),
        gr.File(label="Upload JSON (NDJSON)"),
    ],
    outputs="text",
)

if __name__ == "__main__":
    # Launch only when run as a script; pass share=True for a public link.
    demo.launch()