# Author: Uddipan Basu Bir
# Commit 0d4b0fc: "Download checkpoint from HF hub in OcrReorderPipeline"
# (file size at that commit: 5.18 kB)
import os
import json
import base64
from io import BytesIO
from PIL import Image
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import (
AutoProcessor,
LayoutLMv3Model,
T5ForConditionalGeneration,
AutoTokenizer
)
# ── 1) MODEL SETUP ─────────────────────────────────────────────────────
# Each component is loaded exactly once: reloading the T5 model would drop
# the decoder_start_token_id fix-up applied further below.
repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"

# Processor for LayoutLMv3. apply_ocr=False because the words and boxes are
# supplied by the uploaded NDJSON file rather than detected from the image.
processor = AutoProcessor.from_pretrained(
    repo,
    subfolder="preprocessor",
    apply_ocr=False,
)

# LayoutLMv3 encoder
layout_model = LayoutLMv3Model.from_pretrained(repo)
layout_model.eval()

# T5 decoder & tokenizer
t5_model = T5ForConditionalGeneration.from_pretrained(repo)
t5_model.eval()
tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="preprocessor")

# Ensure decoder_start_token_id is set. T5 tokenizers typically define no BOS
# token (bos_token_id is None) and T5 conventionally starts decoding from the
# pad token, so fall back to pad_token_id when no BOS id is available.
if t5_model.config.decoder_start_token_id is None:
    start_id = tokenizer.bos_token_id
    if start_id is None:
        start_id = tokenizer.pad_token_id
    t5_model.config.decoder_start_token_id = start_id

# Projection head: maps LayoutLMv3 hidden states to the T5 model width,
# with weights taken from the "projection" entry of the repo checkpoint.
# NOTE(review): torch.load unpickles the file. It is fetched from this fixed
# repo and therefore trusted; never point this at user-supplied files.
ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
ckpt = torch.load(ckpt_file, map_location="cpu")
proj_state = ckpt["projection"]
del ckpt  # only the "projection" entry is used; let the rest be freed

projection = torch.nn.Sequential(
    # Input width follows the encoder config (768 for the base model).
    torch.nn.Linear(layout_model.config.hidden_size, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU(),
)
projection.load_state_dict(proj_state)
projection.eval()

# Move models to CPU (Spaces are CPU-only)
device = torch.device("cpu")
layout_model.to(device)
t5_model.to(device)
projection.to(device)
# ── 2) INFERENCE FUNCTION ─────────────────────────────────────────────
def _resolve_upload_path(upload):
    """Return a filesystem path for a Gradio ``gr.File`` value (or None).

    Depending on the Gradio version the value is either a path string or a
    tempfile-like object exposing the path as ``.name``; accept both.
    """
    if upload is None or isinstance(upload, (str, os.PathLike)):
        return upload
    return upload.name


def _read_ndjson_entry(path, img_name):
    """Return the first NDJSON record in *path* whose ``img_name`` matches.

    Blank lines and non-object records are skipped, and the file is read
    lazily so the scan stops at the first match. A leading UTF-8 BOM is
    tolerated. Returns ``None`` when no record matches.

    Raises:
        ValueError: a non-blank line is not valid JSON (message carries the
            1-based line number) or the file is not valid UTF-8.
    """
    with open(path, "r", encoding="utf-8-sig") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(f"line {lineno}: {err.msg}") from err
            if isinstance(record, dict) and record.get("img_name") == img_name:
                return record
    return None


def infer(image_path, json_file):
    """Reorder the OCR words of one image with the LayoutLMv3 → T5 pipeline.

    Args:
        image_path: Filesystem path of the uploaded image (``gr.Image`` with
            ``type="filepath"``).
        json_file: The uploaded NDJSON file, either a path or a file-like
            object with a ``.name`` attribute. Each non-blank line is a JSON
            object; the one whose ``img_name`` equals the image's basename
            supplies ``src_word_list`` and ``src_wordbox_list``.

    Returns:
        The decoded text, or a human-readable message starting with "❌"
        when an input is missing, malformed, or has no matching entry.
    """
    if not image_path:
        return "❌ Please upload an image."
    json_path = _resolve_upload_path(json_file)
    if not json_path:
        return "❌ Please upload an NDJSON file."
    img_name = os.path.basename(image_path)

    # 2.a/2.b) Stream the NDJSON file and pick the entry for this image
    try:
        entry = _read_ndjson_entry(json_path, img_name)
    except ValueError as err:
        return f"❌ Invalid NDJSON ({err})"
    if entry is None:
        return f"❌ No JSON entry found for image '{img_name}'"

    words = entry.get("src_word_list") or []
    boxes = entry.get("src_wordbox_list") or []
    # The processor needs exactly one box per word; report problems instead
    # of surfacing a tokenizer stack trace in the UI.
    if not words:
        return f"❌ JSON entry for image '{img_name}' has no words"
    if len(words) != len(boxes):
        return (
            f"❌ JSON entry for image '{img_name}' has "
            f"{len(words)} word(s) but {len(boxes)} box(es)"
        )

    # 2.c) Open and preprocess the image + tokens + boxes.
    # The context manager closes the underlying file once converted.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")
    # NOTE(review): LayoutLMv3 expects boxes normalised to a 0-1000 grid;
    # confirm src_wordbox_list is already in that range.
    encoding = processor(
        [img],
        [words],
        boxes=[boxes],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    pixel_values = encoding.pixel_values.to(device)
    input_ids = encoding.input_ids.to(device)
    attention_mask = encoding.attention_mask.to(device)
    bbox = encoding.bbox.to(device)

    # 2.d) Forward pass (inference only, no autograd bookkeeping)
    with torch.no_grad():
        # LayoutLMv3 encoding
        lm_out = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
        )
        # Keep only the text-token positions; in the HF LayoutLMv3
        # implementation the visual patch embeddings follow the text tokens.
        seq_len = input_ids.size(1)
        text_feats = lm_out.last_hidden_state[:, :seq_len, :]
        # Projection → T5 decoding
        proj_feats = projection(text_feats)
        gen_ids = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=512,
            decoder_start_token_id=t5_model.config.decoder_start_token_id,
        )

    # Decode to text (batch of one)
    return tokenizer.batch_decode(gen_ids, skip_special_tokens=True)[0]
# ── 3) GRADIO UI ───────────────────────────────────────────────────────
# Two uploads feed `infer`: the page image (passed as a file path) and the
# NDJSON file holding the OCR words/boxes; the result is shown as text.
image_input = gr.Image(type="filepath", label="Upload Image")
ndjson_input = gr.File(label="Upload JSON (NDJSON)")

demo = gr.Interface(
    fn=infer,
    inputs=[image_input, ndjson_input],
    outputs="text",
    title="OCR Reorder Pipeline",
)

if __name__ == "__main__":
    # Launch the app and request a public share link.
    demo.launch(share=True)