Uddipan Basu Bir committed
Commit 419d02f · 1 Parent(s): 4956f20

Download checkpoint from HF hub in OcrReorderPipeline

Files changed (1):
  1. app.py +88 -26
app.py CHANGED
@@ -4,49 +4,112 @@ import base64
 from io import BytesIO
 from PIL import Image
 import gradio as gr
+import torch
+from huggingface_hub import hf_hub_download
+from transformers import (
+    AutoProcessor,
+    LayoutLMv3Model,
+    T5ForConditionalGeneration,
+    AutoTokenizer
+)
+
+# ── 1) MODEL SETUP ─────────────────────────────────────────────────────
+repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"
+
+# Processor for LayoutLMv3
+processor = AutoProcessor.from_pretrained(
+    repo,
+    subfolder="preprocessor",
+    apply_ocr=False
+)
+
+# LayoutLMv3 encoder
+layout_model = LayoutLMv3Model.from_pretrained(repo)
+layout_model.eval()
 
-from inference import OcrReorderPipeline
-from transformers import AutoProcessor, LayoutLMv3Model, AutoTokenizer
+# T5 decoder & tokenizer
+t5_model = T5ForConditionalGeneration.from_pretrained(repo)
+t5_model.eval()
+tokenizer = AutoTokenizer.from_pretrained(
+    repo, subfolder="preprocessor"
+)
+
+# Projection head: load from checkpoint
+ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
+ckpt = torch.load(ckpt_file, map_location="cpu")
+proj_state = ckpt["projection"]
+projection = torch.nn.Sequential(
+    torch.nn.Linear(768, t5_model.config.d_model),
+    torch.nn.LayerNorm(t5_model.config.d_model),
+    torch.nn.GELU()
+)
+projection.load_state_dict(proj_state)
+projection.eval()
 
-# ── 1) Load model + tokenizer + processor ─────────────────────────
-repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"
-model = LayoutLMv3Model.from_pretrained(repo)
-tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="preprocessor")
-processor = AutoProcessor.from_pretrained(repo, subfolder="preprocessor", apply_ocr=False)
-pipe = OcrReorderPipeline(model, tokenizer, processor, device=0)
+# Move models to CPU (Spaces are CPU-only)
+device = torch.device("cpu")
+layout_model.to(device)
+t5_model.to(device)
+projection.to(device)
 
-# ── 2) Inference function ──────────────────────────────────────────
+# ── 2) INFERENCE FUNCTION ─────────────────────────────────────────────
 def infer(image_path, json_file):
     img_name = os.path.basename(image_path)
 
-    # Parse NDJSON entries from uploaded file
+    # 2.a) Load NDJSON file (one JSON object per line)
     data = []
     with open(json_file.name, "r", encoding="utf-8") as f:
         for line in f:
-            line = line.strip()
-            if not line:
+            if not line.strip():
                 continue
             data.append(json.loads(line))
 
-    # Find matching entry for this image
-    entry = next((e for e in data if e["img_name"] == img_name), None)
+    # 2.b) Find entry matching uploaded image
+    entry = next((e for e in data if e.get("img_name") == img_name), None)
     if entry is None:
         return f"❌ No JSON entry found for image '{img_name}'"
 
-    words = entry["src_word_list"]
-    boxes = entry["src_wordbox_list"]
+    words = entry.get("src_word_list", [])
+    boxes = entry.get("src_wordbox_list", [])
 
-    # Read and encode image to base64
+    # 2.c) Open and preprocess the image + tokens + boxes
     img = Image.open(image_path).convert("RGB")
-    buf = BytesIO()
-    img.save(buf, format="PNG")
-    b64 = base64.b64encode(buf.getvalue()).decode()
+    encoding = processor(
+        [img], [words], boxes=[boxes],
+        return_tensors="pt", padding=True, truncation=True
+    )
+    pixel_values = encoding.pixel_values.to(device)
+    input_ids = encoding.input_ids.to(device)
+    attention_mask = encoding.attention_mask.to(device)
+    bbox = encoding.bbox.to(device)
+
+    # 2.d) Forward pass
+    with torch.no_grad():
+        # LayoutLMv3 encoding
+        lm_out = layout_model(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            bbox=bbox
+        )
+        seq_len = input_ids.size(1)
+        text_feats = lm_out.last_hidden_state[:, :seq_len, :]
+
+        # Projection → T5 decoding
+        proj_feats = projection(text_feats)
+        gen_ids = t5_model.generate(
+            inputs_embeds=proj_feats,
+            attention_mask=attention_mask,
+            max_length=512
+        )
 
-    # Call pipeline with `inputs` keyword plus extra args
-    reordered = pipe(inputs=b64, words=words, boxes=boxes)[0]
-    return reordered
+    # Decode to text
+    result = tokenizer.batch_decode(
+        gen_ids, skip_special_tokens=True
+    )[0]
+    return result
 
-# ── 3) Gradio interface ─────────────────────────────────────────────
+# ── 3) GRADIO UI ───────────────────────────────────────────────────────
 demo = gr.Interface(
     fn=infer,
     inputs=[
@@ -58,5 +121,4 @@ demo = gr.Interface(
 )
 
 if __name__ == "__main__":
-    # set share=True if you want a public link
-    demo.launch()
+    demo.launch(share=True)
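
Note: a quick way to sanity-check the new checkpoint-loading path outside the Space is to download the file and inspect its keys before building the projection head. A minimal sketch, assuming (as the diff does) that pytorch_model.bin is a plain torch.save dict containing a "projection" state dict:

import torch
from huggingface_hub import hf_hub_download

repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"
ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
ckpt = torch.load(ckpt_file, map_location="cpu")

# app.py expects a "projection" entry; fail early with a readable message if it is absent
assert "projection" in ckpt, f"unexpected checkpoint keys: {list(ckpt.keys())}"
print({k: type(v).__name__ for k, v in ckpt.items()})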
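Note: infer matches the uploaded image to an NDJSON record by file basename and reads the keys img_name, src_word_list, and src_wordbox_list. An illustrative input file (values hypothetical; the LayoutLMv3 processor with apply_ocr=False expects box coordinates normalized to 0–1000):

import json

entry = {
    "img_name": "sample.png",  # must equal os.path.basename of the uploaded image
    "src_word_list": ["Hello", "world"],
    "src_wordbox_list": [[10, 10, 160, 40], [180, 10, 330, 40]],  # one box per word
}
with open("input.ndjson", "w", encoding="utf-8") as f:
    f.write(json.dumps(entry) + "\n")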