# ocr-dimt / app.py
# Shouvik
# pushing the changes..
# 1b470fb
import os
import json
from io import BytesIO
from PIL import Image
import torch
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from transformers import (
AutoProcessor,
LayoutLMv3Model,
T5ForConditionalGeneration,
AutoTokenizer
)
# ASGI application object; the route decorators further down register on it.
app = FastAPI()
# ── 1) CONFIG & CHECKPOINT ────────────────────────────────────────────────
# Hugging Face Hub repository and file name of the fine-tuned checkpoint.
HF_REPO = "shouvik27/LayoutLMv3_T5"
CKPT_NAME = "pytorch_model.bin"
# Runs at import time: resolve the checkpoint to a local file path through the
# Hub client, then deserialize it with every tensor mapped onto the CPU.
ckpt_path = hf_hub_download(repo_id=HF_REPO, filename=CKPT_NAME)
# `ckpt` is indexed later in this file with the keys "layout_model",
# "t5_model" and "projection", each passed to `load_state_dict`.
# NOTE(review): torch.load unpickles the downloaded file; on PyTorch versions
# where `weights_only` defaults to False this can execute arbitrary code from
# the checkpoint. If the file holds only state dicts, passing
# `weights_only=True` would be safer — confirm against the checkpoint contents.
ckpt = torch.load(ckpt_path, map_location="cpu")
# ── 2) BUILD MODELS ───────────────────────────────────────────────────────
# Everything in this section runs once at import time and leaves the
# module-level objects `processor`, `layout_model`, `t5_model`, `tokenizer`
# and `projection` ready for the request handlers (CPU, eval mode).

# apply_ocr=False: the processor does not run OCR itself; words and boxes are
# passed in explicitly by the inference code.
processor = AutoProcessor.from_pretrained(
    "microsoft/layoutlmv3-base", apply_ocr=False
)

# Start from the public base weights, then overlay the fine-tuned tensors.
# NOTE(review): strict=False silently ignores missing/unexpected keys, so a
# mismatched checkpoint loads without any error — confirm this is intended.
layout_model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
layout_model.load_state_dict(ckpt["layout_model"], strict=False)
layout_model.eval().to("cpu")

t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
t5_model.load_state_dict(ckpt["t5_model"], strict=False)
t5_model.eval().to("cpu")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Bridge from the encoder's hidden size to the T5 embedding size. The input
# width is read from the encoder config instead of being hard-coded (768 for
# layoutlmv3-base), so the layer always matches the encoder that was loaded.
proj_state = ckpt["projection"]
projection = torch.nn.Sequential(
    torch.nn.Linear(layout_model.config.hidden_size, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU(),
)
# Strict load (the default): the projection has no pretrained fallback, so
# every tensor must come from the checkpoint.
projection.load_state_dict(proj_state)
projection.eval().to("cpu")

# Make sure generation has a start token. The fallback is chosen with an
# explicit None check: token id 0 is a valid id and must not be treated as
# "missing", which a plain `or` would do.
if t5_model.config.decoder_start_token_id is None:
    t5_model.config.decoder_start_token_id = (
        tokenizer.bos_token_id
        if tokenizer.bos_token_id is not None
        else tokenizer.pad_token_id
    )
if t5_model.config.bos_token_id is None:
    t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id
# ── 3) INFERENCE ─────────────────────────────────────────────────────────
def infer_from_files(
    image_file: UploadFile,
    json_file: UploadFile,
    *,
    max_length: int = 512,
):
    """Run the LayoutLMv3 -> projection -> T5 pipeline for one uploaded image.

    The NDJSON upload is scanned line by line for the first object whose
    ``"img_name"`` equals the base name of the uploaded image. That object's
    ``"src_word_list"`` and ``"src_wordbox_list"`` are fed to the processor
    together with the image.

    Args:
        image_file: Uploaded image; ``.file`` is read fully and ``.filename``
            is used to look up the matching NDJSON entry.
        json_file: Uploaded NDJSON file, one JSON object per line, iterated
            as bytes.
        max_length: Upper bound passed to ``t5_model.generate``.

    Returns:
        ``{"result": <decoded text>}`` on success, or ``{"error": <message>}``
        when the inputs cannot be used (no filename, malformed NDJSON line,
        no matching entry, missing/mismatched word or box lists, unreadable
        image).

    Note:
        Assumes the boxes are already in the coordinate format the processor
        expects — TODO confirm against the data producer.
    """
    if not image_file.filename:
        return {"error": "Uploaded image has no filename to match against the JSON entries"}
    img_name = os.path.basename(image_file.filename)
    image_bytes = image_file.file.read()

    # Find the first NDJSON object describing this image.
    entry = None
    for line_no, raw_line in enumerate(json_file.file, start=1):
        if not raw_line.strip():
            continue
        try:
            # utf-8-sig drops a leading byte-order mark, which json.loads
            # would otherwise reject on the first line of the file.
            obj = json.loads(raw_line.decode("utf-8-sig").strip())
        except ValueError as exc:
            # Covers both UnicodeDecodeError and json.JSONDecodeError.
            return {"error": f"Invalid JSON on line {line_no}: {exc}"}
        # Lines holding a non-object JSON value cannot describe an image.
        if isinstance(obj, dict) and obj.get("img_name") == img_name:
            entry = obj
            break
    if entry is None:
        return {"error": f"No JSON entry for: {img_name}"}

    words = entry.get("src_word_list")
    boxes = entry.get("src_wordbox_list")
    if not words or not boxes:
        return {
            "error": f"JSON entry for {img_name} has no "
                     "'src_word_list' / 'src_wordbox_list' data"
        }
    if len(words) != len(boxes):
        return {
            "error": f"JSON entry for {img_name} has {len(words)} words "
                     f"but {len(boxes)} boxes"
        }

    try:
        # convert() returns a new image, so the source handle can be closed.
        with Image.open(BytesIO(image_bytes)) as raw_img:
            img = raw_img.convert("RGB")
    except OSError as exc:
        return {"error": f"Cannot read image {img_name}: {exc}"}

    enc = processor(
        [img],
        [words],
        boxes=[boxes],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    pixel_values = enc.pixel_values.to("cpu")
    input_ids = enc.input_ids.to("cpu")
    attention_mask = enc.attention_mask.to("cpu")
    bbox = enc.bbox.to("cpu")

    with torch.no_grad():
        out = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
        )
        # Keep only as many encoder positions as there are input token ids so
        # the features line up with `attention_mask` (presumably this drops
        # the visual patch positions that follow the text — confirm).
        seq_len = input_ids.size(1)
        text_feats = out.last_hidden_state[:, :seq_len, :]
        proj_feats = projection(text_feats)
        gen_ids = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=max_length,
            decoder_start_token_id=t5_model.config.decoder_start_token_id,
        )

    result = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    return {"result": result}
# ── 4) FASTAPI ENDPOINT ──────────────────────────────────────────────────
@app.post("/infer")
async def infer_api(
    image_file: UploadFile = File(..., description="The image file"),
    json_file: UploadFile = File(..., description="The NDJSON file"),
):
    """POST /infer: run inference on an uploaded image + NDJSON pair.

    Args:
        image_file: Multipart upload holding the image.
        json_file: Multipart upload holding the NDJSON metadata.

    Returns:
        A JSONResponse wrapping whatever dict ``infer_from_files`` returns.
    """
    # Imported here so this block is self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    # infer_from_files does blocking file reads and CPU-bound model inference.
    # Calling it directly inside an `async def` handler would stall the event
    # loop for every other request, so hand it to the worker thread pool.
    output = await run_in_threadpool(infer_from_files, image_file, json_file)
    return JSONResponse(content=output)
@app.get("/")
def healthcheck():
    """GET /: reply with a fixed status message."""
    status = {"message": "OCR FastAPI server is running."}
    return status