Shouvik committed
Commit
cfcfa8a
·
1 Parent(s): 43d869b

pushing app code and dependencies...

Files changed (2)
  1. app.py +121 -0
  2. requirements.txt +4 -0
app.py CHANGED
@@ -0,0 +1,121 @@
+ import os, json, base64
+ from io import BytesIO
+ from PIL import Image
+ import gradio as gr
+ import torch
+ from huggingface_hub import hf_hub_download
+ from transformers import (
+     AutoProcessor,
+     LayoutLMv3Model,
+     T5ForConditionalGeneration,
+     AutoTokenizer
+ )
+
+ # ── 1) CONFIG & CHECKPOINT ────────────────────────────────────────────────
+ HF_REPO = "shouvik27/LayoutLMv3_T5"
+ CKPT_NAME = "model.bin"
+
+ # 1a) Download the checkpoint dict from your Hub
+ ckpt_path = hf_hub_download(repo_id=HF_REPO, filename=CKPT_NAME)
+ ckpt = torch.load(ckpt_path, map_location="cpu")
+
+ # ── 2) BUILD MODELS ───────────────────────────────────────────────────────
+ # 2a) Processor for LayoutLMv3
+ processor = AutoProcessor.from_pretrained(
+     "microsoft/layoutlmv3-base", apply_ocr=False
+ )
+
+ # 2b) LayoutLMv3 encoder
+ layout_model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
+ layout_model.load_state_dict(ckpt["layout_model"], strict=False)
+ layout_model.eval().to("cpu")
+
+ # 2c) T5 decoder + tokenizer
+ t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
+ t5_model.load_state_dict(ckpt["t5_model"], strict=False)
+ t5_model.eval().to("cpu")
+
+ tokenizer = AutoTokenizer.from_pretrained("t5-small")
+
+ # 2d) Projection head
+ proj_state = ckpt["projection"]
+ projection = torch.nn.Sequential(
+     torch.nn.Linear(768, t5_model.config.d_model),
+     torch.nn.LayerNorm(t5_model.config.d_model),
+     torch.nn.GELU()
+ )
+ projection.load_state_dict(proj_state)
+ projection.eval().to("cpu")
+
+ # 2e) Ensure we have a valid start token for generation
+ if t5_model.config.decoder_start_token_id is None:
+     t5_model.config.decoder_start_token_id = tokenizer.bos_token_id or tokenizer.pad_token_id
+ if t5_model.config.bos_token_id is None:
+     t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id
+
+ # ── 3) INFERENCE ─────────────────────────────────────────────────────────
+ def infer(image_path, json_file):
+     img_name = os.path.basename(image_path)
+
+     # 3a) Read the uploaded NDJSON & find the matching record
+     entry = None
+     with open(json_file.name, "r", encoding="utf-8") as f:
+         for line in f:
+             line = line.strip()
+             if not line:
+                 continue
+             obj = json.loads(line)
+             if obj.get("img_name") == img_name:
+                 entry = obj
+                 break
+
+     if entry is None:
+         return f"❌ No JSON entry for: {img_name}"
+
+     words = entry["src_word_list"]
+     boxes = entry["src_wordbox_list"]
+
+     # 3b) Preprocess: image + OCR tokens + boxes
+     img = Image.open(image_path).convert("RGB")
+     enc = processor([img], [words], boxes=[boxes],
+                     return_tensors="pt", padding=True, truncation=True)
+     pixel_values = enc.pixel_values.to("cpu")
+     input_ids = enc.input_ids.to("cpu")
+     attention_mask = enc.attention_mask.to("cpu")
+     bbox = enc.bbox.to("cpu")
+
+     # 3c) Forward pass
+     with torch.no_grad():
+         out = layout_model(
+             pixel_values=pixel_values,
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             bbox=bbox
+         )
+         seq_len = input_ids.size(1)
+         text_feats = out.last_hidden_state[:, :seq_len, :]
+         proj_feats = projection(text_feats)
+
+         gen_ids = t5_model.generate(
+             inputs_embeds=proj_feats,
+             attention_mask=attention_mask,
+             max_length=512,
+             decoder_start_token_id=t5_model.config.decoder_start_token_id
+         )
+
+     # 3d) Decode & return
+     return tokenizer.decode(gen_ids[0], skip_special_tokens=True)
+
+ # ── 4) GRADIO APP ────────────────────────────────────────────────────────
+ demo = gr.Interface(
+     fn=infer,
+     inputs=[
+         gr.Image(type="filepath", label="Upload Image"),
+         gr.File(label="Upload JSON (NDJSON)")
+     ],
+     outputs="text",
+     title="OCR Reorder Pipeline"
+ )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
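
For context on the expected input: infer() matches the uploaded image by its basename against the "img_name" field of each NDJSON line, then reads the parallel "src_word_list" and "src_wordbox_list" arrays. Below is a minimal sketch of one such record; the file name, words, and coordinates are invented for illustration, and the boxes follow the [x0, y0, x1, y1] convention normalized to 0-1000 that the LayoutLM model family expects.

    import json

    # Hypothetical example record: every value below is invented for illustration.
    # The keys are exactly the ones infer() reads.
    record = {
        "img_name": "sample_0001.png",       # must equal the uploaded image's basename
        "src_word_list": ["Invoice", "No.", "42"],
        "src_wordbox_list": [                # [x0, y0, x1, y1], normalized to 0-1000
            [120, 80, 260, 110],
            [270, 80, 330, 110],
            [340, 80, 390, 110],
        ],
    }

    # NDJSON means one JSON object per line.
    with open("ocr.ndjson", "w", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")

Uploading sample_0001.png together with ocr.ndjson would then route this record through the encoder and decoder.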
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ transformers
+ Pillow
+ gradio
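
To try the app outside the Space: pip install -r requirements.txt, then python app.py. Note that huggingface_hub is not listed explicitly but is installed as a dependency of transformers, which covers the hf_hub_download import; with share=True, demo.launch() serves the UI locally and also prints a temporary public Gradio link.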