# NOTE(review): removed non-Python extraction residue (file-size line, commit
# hashes, line-number gutter) that would have been a SyntaxError.
import os, json, base64
from io import BytesIO
from PIL import Image
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import (
AutoProcessor,
LayoutLMv3Model,
T5ForConditionalGeneration,
AutoTokenizer
)
# ── 1) CONFIG & CHECKPOINT ────────────────────────────────────────────────
# Hub repo holding the fine-tuned checkpoint; the single .bin file bundles
# three state dicts: "layout_model", "t5_model", and "projection".
HF_REPO = "shouvik27/LayoutLMv3_T5"
CKPT_NAME = "pytorch_model.bin"
# 1a) Download the checkpoint dict from the Hub (cached locally by hf_hub).
ckpt_path = hf_hub_download(repo_id=HF_REPO, filename=CKPT_NAME)
# NOTE(review): torch.load on a downloaded pickle can execute arbitrary code
# if the repo is compromised — consider weights_only=True or safetensors.
ckpt = torch.load(ckpt_path, map_location="cpu")
# ── 2) BUILD MODELS ───────────────────────────────────────────────────────
# 2a) Processor for LayoutLMv3. apply_ocr=False because the NDJSON records
# already supply words and boxes, so no built-in OCR pass is wanted.
processor = AutoProcessor.from_pretrained(
    "microsoft/layoutlmv3-base", apply_ocr=False
)
# 2b) LayoutLMv3 encoder: start from base weights, then overlay the
# fine-tuned state dict. strict=False silently ignores missing/unexpected
# keys, so a mismatched checkpoint would NOT raise here.
layout_model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
layout_model.load_state_dict(ckpt["layout_model"], strict=False)
layout_model.eval().to("cpu")
# 2c) T5 decoder + tokenizer (same strict=False caveat as above).
t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
t5_model.load_state_dict(ckpt["t5_model"], strict=False)
t5_model.eval().to("cpu")
tokenizer = AutoTokenizer.from_pretrained("t5-small")
# 2d) Projection head mapping the encoder's 768-d hidden states into T5's
# embedding space (d_model) so they can be fed to generate as inputs_embeds.
proj_state = ckpt["projection"]
projection = torch.nn.Sequential(
    torch.nn.Linear(768, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU()
)
projection.load_state_dict(proj_state)
projection.eval().to("cpu")
# 2e) Ensure a valid start token for generation. T5 has no BOS token, so
# this normally falls through to pad_token_id — presumably already set on
# t5-small's config, making both branches no-ops; verify if config changes.
if t5_model.config.decoder_start_token_id is None:
    t5_model.config.decoder_start_token_id = tokenizer.bos_token_id or tokenizer.pad_token_id
if t5_model.config.bos_token_id is None:
    t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id
# ── 3) INFERENCE ──────────────────────────────────────────────────────────
def infer(image_path, json_file):
    """Run the LayoutLMv3 → T5 OCR-reordering pipeline on one image.

    Args:
        image_path: Filesystem path to the uploaded image (gr.Image with
            type="filepath").
        json_file: Uploaded NDJSON file object (gr.File); each line is a
            JSON record expected to carry "img_name", "src_word_list",
            and "src_wordbox_list".

    Returns:
        The generated (reordered) text, or an error message string when no
        NDJSON record matches the image's basename.
    """
    img_name = os.path.basename(image_path)

    # 3a) Scan the NDJSON for the record whose img_name matches the upload.
    entry = None
    with open(json_file.name, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            if obj.get("img_name") == img_name:
                entry = obj
                break
    if entry is None:
        # BUGFIX: the original message began with a mojibake "β"; restore a
        # readable error marker.
        return f"❌ No JSON entry for: {img_name}"

    words = entry["src_word_list"]
    boxes = entry["src_wordbox_list"]

    # 3b) Preprocess: image + OCR tokens + boxes. The processor was built
    # with apply_ocr=False, so words/boxes must be supplied explicitly.
    # BUGFIX: close the image file handle (original leaked it).
    with Image.open(image_path) as im:
        img = im.convert("RGB")
    enc = processor([img], [words], boxes=[boxes],
                    return_tensors="pt", padding=True, truncation=True)
    pixel_values = enc.pixel_values.to("cpu")
    input_ids = enc.input_ids.to("cpu")
    attention_mask = enc.attention_mask.to("cpu")
    bbox = enc.bbox.to("cpu")

    # 3c) Encode with LayoutLMv3, project text-token features into T5's
    # embedding space, then generate from those embeddings.
    with torch.no_grad():
        out = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox
        )
        # LayoutLMv3 appends image-patch tokens after the text tokens;
        # keep only the first seq_len (text) positions for the decoder.
        seq_len = input_ids.size(1)
        text_feats = out.last_hidden_state[:, :seq_len, :]
        proj_feats = projection(text_feats)
        gen_ids = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=512,
            decoder_start_token_id=t5_model.config.decoder_start_token_id
        )

    # 3d) Decode generated ids back to a string.
    return tokenizer.decode(gen_ids[0], skip_special_tokens=True)
# ── 4) GRADIO APP ─────────────────────────────────────────────────────────
# Two uploads in, one text box out; `infer` does all the work.
_app_inputs = [
    gr.Image(type="filepath", label="Upload Image"),
    gr.File(label="Upload JSON (NDJSON)"),
]
demo = gr.Interface(
    fn=infer,
    inputs=_app_inputs,
    outputs="text",
    title="OCR Reorder Pipeline",
)

if __name__ == "__main__":
    # share=True publishes a temporary public gradio.live URL.
    demo.launch(share=True)