# Source metadata (recovered from extraction header): file size 4,259 bytes, blob 1b470fb.
import os
import json
from io import BytesIO
from PIL import Image
import torch
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from transformers import (
AutoProcessor,
LayoutLMv3Model,
T5ForConditionalGeneration,
AutoTokenizer
)
app = FastAPI()
# -- 1) CONFIG & CHECKPOINT -------------------------------------------------
# Download the combined fine-tuned checkpoint (layout encoder + T5 decoder +
# projection head, stored under separate keys) from the Hugging Face Hub.
HF_REPO = "shouvik27/LayoutLMv3_T5"
CKPT_NAME = "pytorch_model.bin"
ckpt_path = hf_hub_download(repo_id=HF_REPO, filename=CKPT_NAME)
# NOTE(review): torch.load unpickles arbitrary objects and this file comes
# from a remote repo — consider weights_only=True (PyTorch >= 2.0) if the
# checkpoint format allows it.
ckpt = torch.load(ckpt_path, map_location="cpu")
# -- 2) BUILD MODELS --------------------------------------------------------
# apply_ocr=False: word/box annotations are supplied by the client's NDJSON
# upload, so the processor must not run its own OCR.
processor = AutoProcessor.from_pretrained(
"microsoft/layoutlmv3-base", apply_ocr=False
)
layout_model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
# strict=False tolerates key mismatches between the base model and the
# fine-tuned state dict — presumably intentional; TODO confirm nothing
# important is silently skipped.
layout_model.load_state_dict(ckpt["layout_model"], strict=False)
layout_model.eval().to("cpu")
t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
t5_model.load_state_dict(ckpt["t5_model"], strict=False)
t5_model.eval().to("cpu")
tokenizer = AutoTokenizer.from_pretrained("t5-small")
# Projection head mapping LayoutLMv3 hidden states (768-d) into the T5
# embedding space; weights come from the same checkpoint.
proj_state = ckpt["projection"]
projection = torch.nn.Sequential(
torch.nn.Linear(768, t5_model.config.d_model),
torch.nn.LayerNorm(t5_model.config.d_model),
torch.nn.GELU()
)
projection.load_state_dict(proj_state)
projection.eval().to("cpu")
# T5 has no BOS token by default; fall back to pad so generate() has a
# valid decoder start token.
if t5_model.config.decoder_start_token_id is None:
t5_model.config.decoder_start_token_id = tokenizer.bos_token_id or tokenizer.pad_token_id
if t5_model.config.bos_token_id is None:
t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id
# ββ 3) INFERENCE βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def _find_entry(json_file: UploadFile, img_name: str):
    """Scan an NDJSON upload and return the first record whose "img_name"
    equals *img_name*, or None when no record matches.

    Blank lines are skipped; any other malformed line raises
    json.JSONDecodeError (same behavior as before the refactor).
    """
    for raw in json_file.file:
        line = raw.decode("utf-8").strip()
        if not line:
            continue
        obj = json.loads(line)
        if obj.get("img_name") == img_name:
            return obj
    return None


def infer_from_files(image_file: UploadFile, json_file: UploadFile):
    """Run layout-aware inference for one uploaded image.

    Looks up the image's word/box annotations in the uploaded NDJSON file,
    encodes image + words + boxes with LayoutLMv3, projects the text-token
    features into the T5 embedding space, and decodes an output string
    with T5.

    Returns:
        {"result": <decoded text>} on success, or {"error": <message>}
        when the NDJSON has no entry for the image or the entry lacks the
        expected annotation keys.
    """
    image_bytes = image_file.file.read()
    img_name = os.path.basename(image_file.filename)

    entry = _find_entry(json_file, img_name)
    if entry is None:
        return {"error": f"No JSON entry for: {img_name}"}

    # Fix: a matching entry without annotation lists used to raise an
    # unhandled KeyError (HTTP 500); report it as a regular error payload,
    # consistent with the missing-entry case above.
    try:
        words = entry["src_word_list"]
        boxes = entry["src_wordbox_list"]
    except KeyError as missing:
        return {"error": f"Entry for {img_name} is missing key: {missing.args[0]}"}

    img = Image.open(BytesIO(image_bytes)).convert("RGB")
    enc = processor(
        [img], [words], boxes=[boxes],
        return_tensors="pt", padding=True, truncation=True,
    )
    pixel_values = enc.pixel_values.to("cpu")
    input_ids = enc.input_ids.to("cpu")
    attention_mask = enc.attention_mask.to("cpu")
    bbox = enc.bbox.to("cpu")

    with torch.no_grad():
        out = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
        )
        # Keep only the first seq_len positions of the hidden states —
        # presumably the text tokens (LayoutLMv3 appends visual patch
        # tokens after them) — so features align with attention_mask.
        seq_len = input_ids.size(1)
        text_feats = out.last_hidden_state[:, :seq_len, :]
        proj_feats = projection(text_feats)
        gen_ids = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=512,
            decoder_start_token_id=t5_model.config.decoder_start_token_id,
        )

    result = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    return {"result": result}
# ββ 4) FASTAPI ENDPOINT ββββββββββββββββββββββββββββββββββββββββββββββββββ
@app.post("/infer")
async def infer_api(
    image_file: UploadFile = File(..., description="The image file"),
    json_file: UploadFile = File(..., description="The NDJSON file"),
):
    """Accept an image plus its NDJSON annotation file and return the
    inference result (or error) as a JSON response."""
    payload = infer_from_files(image_file, json_file)
    return JSONResponse(content=payload)
@app.get("/")
def healthcheck():
    """Liveness probe: report that the server is up."""
    status = {"message": "OCR FastAPI server is running."}
    return status