ocr-reorder-space / inference.py
Uddipan Basu Bir
Download checkpoint from HF hub in OcrReorderPipeline
93dce4d
raw
history blame
2.47 kB
import torch
from transformers import Pipeline
from PIL import Image
import base64
from io import BytesIO
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository that hosts the fine-tuned checkpoint.
# OcrReorderPipeline downloads ``pytorch_model.bin`` from this repo and reads
# its 'projection' entry, which it loads as the projection head's state dict.
HF_MODEL_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"
class OcrReorderPipeline(Pipeline):
    """Custom pipeline: image + OCR words/boxes -> generated text.

    The flow implemented here is:

    1. ``preprocess`` decodes a base64 image and runs the processor on the
       image, the word list and the bounding boxes.
    2. ``_forward`` encodes those inputs with ``model.vision_model``, keeps the
       leading ``input_ids``-length slice of the hidden states, maps it through
       a projection head, and calls ``model.text_model.generate`` on the
       projected embeddings.
    3. ``postprocess`` decodes the generated ids with the tokenizer.

    The projection head weights are downloaded from ``HF_MODEL_REPO`` when the
    pipeline is constructed.

    Usage through the standard pipeline entry point::

        pipe = OcrReorderPipeline(model, tokenizer, processor)
        texts = pipe(image_b64, words=words, boxes=boxes)
    """

    def __init__(self, model, tokenizer, processor, device=0):
        """Build the pipeline and load the fine-tuned projection head.

        Args:
            model: Object exposing ``vision_model`` and ``text_model``
                attributes (both are used in ``_forward``).
            tokenizer: Tokenizer used by ``postprocess`` to decode output ids.
            processor: Callable stored as ``feature_extractor``; it receives
                images, words and ``boxes`` in ``preprocess``.
            device: Device passed through to ``Pipeline.__init__``.

        Raises:
            KeyError: If the downloaded checkpoint has no ``"projection"``
                entry.
        """
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            feature_extractor=processor,
            device=device,
        )

        # Download the fine-tuned checkpoint from the Hub.
        ckpt_path = hf_hub_download(
            repo_id=HF_MODEL_REPO, filename="pytorch_model.bin"
        )
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code if the Hub repo is compromised. If the checkpoint is
        # known to hold only tensors/containers, prefer
        # ``torch.load(..., weights_only=True)`` — TODO confirm contents.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        try:
            proj_state = ckpt["projection"]
        except KeyError as err:
            raise KeyError(
                f"checkpoint 'pytorch_model.bin' from {HF_MODEL_REPO!r} "
                "has no 'projection' entry"
            ) from err

        # Rebuild the projection head with the layout the state dict expects.
        encoder_dim = 768  # presumably the LayoutLMv3-base hidden size — confirm
        d_model = 512  # T5-small hidden size
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(encoder_dim, d_model),
            torch.nn.LayerNorm(d_model),
            torch.nn.GELU(),
        )
        self.projection.load_state_dict(proj_state)
        self.projection.to(self.device)

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into preprocess/forward/postprocess dicts.

        ``words`` and ``boxes`` are routed to ``preprocess`` (which requires
        them) and ``max_length`` to ``_forward``. Any other keyword is
        ignored. Without this routing, calling the pipeline object could never
        supply the arguments ``preprocess`` needs.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            tuple of dicts.
        """
        preprocess_kwargs = {
            key: kwargs[key] for key in ("words", "boxes") if key in kwargs
        }
        forward_kwargs = {}
        if "max_length" in kwargs:
            forward_kwargs["max_length"] = kwargs["max_length"]
        return preprocess_kwargs, forward_kwargs, {}

    def preprocess(self, image, words, boxes):
        """Decode the image and run the processor on a batch of one.

        Args:
            image: Base64-encoded image data (passed to ``base64.b64decode``).
            words: Words for the image, forwarded to the processor.
            boxes: Bounding boxes for ``words``, forwarded as ``boxes=``.
                # assumes one box per word — TODO confirm against the processor

        Returns:
            The processor output, requested as PyTorch tensors with padding
            and truncation enabled.
        """
        raw_bytes = base64.b64decode(image)
        rgb_image = Image.open(BytesIO(raw_bytes)).convert("RGB")
        # Each argument is wrapped in a list: the processor gets a batch of 1.
        return self.feature_extractor(
            [rgb_image],
            [words],
            boxes=[boxes],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

    def _forward(self, model_inputs, max_length=512):
        """Encode, project and generate output token ids.

        Args:
            model_inputs: Mapping with ``pixel_values``, ``input_ids``,
                ``attention_mask`` and ``bbox`` tensors.
            max_length: ``max_length`` passed to ``generate`` (default 512).

        Returns:
            ``{"generated_ids": <tensor returned by generate>}``.
        """
        device = self.device
        pixel_values = model_inputs["pixel_values"].to(device)
        input_ids = model_inputs["input_ids"].to(device)
        attention_mask = model_inputs["attention_mask"].to(device)
        bbox = model_inputs["bbox"].to(device)

        encoder_out = self.model.vision_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
        )

        # Keep only the first len(input_ids) positions so the features line up
        # one-to-one with attention_mask; presumably the remaining positions
        # are image-patch tokens appended by the encoder — TODO confirm.
        num_text_tokens = input_ids.size(1)
        text_feats = encoder_out.last_hidden_state[:, :num_text_tokens, :]
        projected = self.projection(text_feats)

        generated_ids = self.model.text_model.generate(
            inputs_embeds=projected,
            attention_mask=attention_mask,
            max_length=max_length,
        )
        return {"generated_ids": generated_ids}

    def postprocess(self, model_outputs):
        """Decode ``model_outputs["generated_ids"]`` to a list of strings.

        Special tokens are skipped during decoding.
        """
        return self.tokenizer.batch_decode(
            model_outputs["generated_ids"],
            skip_special_tokens=True,
        )