Spaces:
Sleeping
Sleeping
import base64
import json
import os
import warnings
from io import BytesIO

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    LayoutLMv3Model,
    T5ForConditionalGeneration,
)
# ── 1) CONFIG & CHECKPOINT ────────────────────────────────────────────────
# Hub repository holding the fine-tuned checkpoint, and the file inside it.
HF_REPO, CKPT_NAME = "shouvik27/LayoutLMv3_T5", "model.bin"

# 1a) Fetch the checkpoint file from the Hub (hf_hub_download returns a local
#     path) and load it onto the CPU. The loaded object is indexed later by
#     "layout_model", "t5_model" and "projection", each holding a state dict.
# NOTE(security): torch.load unpickles the file. With weights_only=False (the
#     default before PyTorch 2.6) a tampered checkpoint can execute arbitrary
#     code. Consider weights_only=True if the checkpoint holds only tensors —
#     TODO confirm its contents before switching.
ckpt_path = hf_hub_download(repo_id=HF_REPO, filename=CKPT_NAME)
ckpt = torch.load(ckpt_path, map_location="cpu")
# ── 2) BUILD MODELS ───────────────────────────────────────────────────────
def _load_checkpoint_weights(module, state_dict, label):
    """Load *state_dict* into *module* non-strictly, warning on key mismatches.

    ``strict=False`` tolerates a checkpoint that omits or adds parameters, but
    it also hides a wholesale key mismatch (for example a prefixed checkpoint),
    which would silently leave the pretrained base weights in place. Loading
    stays best-effort; any mismatch is surfaced as a ``UserWarning``.

    Args:
        module: The ``torch.nn.Module`` to load into.
        state_dict: Mapping of parameter names to tensors.
        label: Short name used in the warning message.

    Returns:
        *module*, for convenience.
    """
    result = module.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"{label}: {len(result.missing_keys)} missing and "
            f"{len(result.unexpected_keys)} unexpected checkpoint keys "
            f"(e.g. missing={result.missing_keys[:3]}, "
            f"unexpected={result.unexpected_keys[:3]})",
            stacklevel=2,
        )
    return module


# 2a) Processor for LayoutLMv3. apply_ocr=False: words and boxes are supplied
#     explicitly at inference time instead of running the built-in OCR.
processor = AutoProcessor.from_pretrained(
    "microsoft/layoutlmv3-base", apply_ocr=False
)

# 2b) LayoutLMv3 encoder: base weights first, then the fine-tuned overrides.
layout_model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
_load_checkpoint_weights(layout_model, ckpt["layout_model"], "layout_model")
layout_model.eval().to("cpu")

# 2c) T5 decoder + tokenizer, same pattern as the encoder.
t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
_load_checkpoint_weights(t5_model, ckpt["t5_model"], "t5_model")
t5_model.eval().to("cpu")
tokenizer = AutoTokenizer.from_pretrained("t5-small")
# 2d) Projection head: maps LayoutLMv3 hidden states to T5's d_model width so
#     they can be fed to T5 as input embeddings.
proj_state = ckpt["projection"]
projection = torch.nn.Sequential(
    # Input width comes from the encoder config instead of a hard-coded 768,
    # keeping the head consistent with the encoder it projects from.
    torch.nn.Linear(layout_model.config.hidden_size, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU(),
)
# Strict (default) load: the head's layer shapes must match the checkpoint.
projection.load_state_dict(proj_state)
projection.eval().to("cpu")
# 2e) Ensure generation has a valid decoder start token.
#     Compare against None explicitly: 0 is a valid token id, so a truthiness
#     fallback like `bos_token_id or pad_token_id` would wrongly skip it.
if t5_model.config.decoder_start_token_id is None:
    _start_id = tokenizer.bos_token_id
    if _start_id is None:
        _start_id = tokenizer.pad_token_id
    t5_model.config.decoder_start_token_id = _start_id
if t5_model.config.bos_token_id is None:
    t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id
# ── 3) INFERENCE ──────────────────────────────────────────────────────────
def _as_path(upload):
    """Return a filesystem path for a Gradio file-upload value.

    Some Gradio versions pass a tempfile-like object exposing the path as
    ``.name``; others pass the path string directly. Both are accepted.
    """
    return getattr(upload, "name", upload)


def _find_entry(json_path, img_name):
    """Return the first NDJSON object whose ``img_name`` equals *img_name*.

    Blank lines and non-object JSON values are skipped. Returns ``None`` when
    no record matches. Raises ``json.JSONDecodeError`` on a malformed line
    encountered before a match.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            obj = json.loads(line)
            if isinstance(obj, dict) and obj.get("img_name") == img_name:
                return obj
    return None


def _run_model(img, words, boxes):
    """Encode one page with LayoutLMv3, project to T5 space, decode with T5.

    Returns the decoded string for the single page in the batch.
    """
    # NOTE(review): with apply_ocr=False, LayoutLMv3 expects boxes normalized
    # to a 0-1000 scale — confirm the NDJSON boxes are already normalized.
    enc = processor([img], [words], boxes=[boxes],
                    return_tensors="pt", padding=True, truncation=True)
    pixel_values = enc.pixel_values.to("cpu")
    input_ids = enc.input_ids.to("cpu")
    attention_mask = enc.attention_mask.to("cpu")
    bbox = enc.bbox.to("cpu")

    with torch.no_grad():
        out = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
        )
        # The encoder output places the text-token positions first, followed
        # by the visual patch tokens; keep only the text part so it lines up
        # with attention_mask when handed to T5.
        seq_len = input_ids.size(1)
        text_feats = out.last_hidden_state[:, :seq_len, :]
        proj_feats = projection(text_feats)
        gen_ids = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=512,
            decoder_start_token_id=t5_model.config.decoder_start_token_id,
        )
    return tokenizer.decode(gen_ids[0], skip_special_tokens=True)


def infer(image_path, json_file):
    """Run the OCR-reorder pipeline for one uploaded page image.

    Args:
        image_path: Filesystem path of the uploaded image, or ``None`` if
            nothing was uploaded.
        json_file: The uploaded NDJSON file (path string or object with a
            ``.name`` path), or ``None``. Each line is a JSON object; the one
            whose ``img_name`` equals the image's file name must provide
            ``src_word_list`` and ``src_wordbox_list`` of equal length.

    Returns:
        The decoded T5 output, or a human-readable error message (prefixed
        with "❌") when inputs are missing, malformed, or unmatched.
    """
    if not image_path or json_file is None:
        return "❌ Please upload both an image and an NDJSON file."

    img_name = os.path.basename(image_path)

    # 3a) Find the NDJSON record for this image.
    # NOTE(review): matching is by file name, so this relies on the uploaded
    # image keeping its original name on disk — confirm for the Gradio version.
    try:
        entry = _find_entry(_as_path(json_file), img_name)
    except json.JSONDecodeError as err:
        return f"❌ Invalid NDJSON in uploaded file: {err}"
    if entry is None:
        return f"❌ No JSON entry for: {img_name}"

    words = entry.get("src_word_list")
    boxes = entry.get("src_wordbox_list")
    if not words or not boxes:
        return f"❌ JSON entry for {img_name} has no words/boxes."
    if len(words) != len(boxes):
        return (f"❌ JSON entry for {img_name} has {len(words)} words "
                f"but {len(boxes)} boxes.")

    # 3b) Load the page; the context manager closes the file handle once the
    # independent RGB copy has been made.
    with Image.open(image_path) as opened:
        img = opened.convert("RGB")

    # 3c/3d) Encode, project, generate and decode.
    return _run_model(img, words, boxes)
# ── 4) GRADIO APP ─────────────────────────────────────────────────────────
# Two uploads feed infer(image_path, json_file): the page image as a
# filesystem path, and the NDJSON file holding its words and boxes.
_image_input = gr.Image(type="filepath", label="Upload Image")
_ndjson_input = gr.File(label="Upload JSON (NDJSON)")

demo = gr.Interface(
    fn=infer,
    inputs=[_image_input, _ndjson_input],
    outputs="text",
    title="OCR Reorder Pipeline",
)

if __name__ == "__main__":
    # share=True additionally requests a public share link when run locally.
    demo.launch(share=True)