import os
import json
import base64
from io import BytesIO
from PIL import Image
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import (
    AutoProcessor,
    LayoutLMv3Model,
    T5ForConditionalGeneration,
    AutoTokenizer
)

# ── 1) MODEL SETUP ─────────────────────────────────────────────────────
repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"

# Processor for LayoutLMv3
processor = AutoProcessor.from_pretrained(
    repo,
    subfolder="preprocessor",
    apply_ocr=False
)

# LayoutLMv3 encoder
layout_model = LayoutLMv3Model.from_pretrained(repo)
layout_model.eval()

# T5 decoder & tokenizer
t5_model = T5ForConditionalGeneration.from_pretrained(repo)
t5_model.eval()
tokenizer = AutoTokenizer.from_pretrained(
    repo, subfolder="preprocessor"
)

# Ensure decoder_start_token_id is set. T5 starts decoding from the pad
# token, and the T5 tokenizer has no BOS token, so fall back to pad_token_id.
if t5_model.config.decoder_start_token_id is None:
    bos = tokenizer.bos_token_id
    t5_model.config.decoder_start_token_id = (
        bos if bos is not None else t5_model.config.pad_token_id
    )

# Projection head: load from checkpoint
ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
ckpt       = torch.load(ckpt_file, map_location="cpu")
proj_state = ckpt["projection"]
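# The projection head maps LayoutLMv3's 768-dim hidden states into T5's
# embedding space (d_model) so the T5 decoder can cross-attend over them.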
projection = torch.nn.Sequential(
    torch.nn.Linear(768, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU()
)
projection.load_state_dict(proj_state)
projection.eval()

# Move models to CPU (Spaces are CPU-only)
device = torch.device("cpu")
layout_model.to(device)
t5_model.to(device)
projection.to(device)
repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"

# Processor for LayoutLMv3
processor = AutoProcessor.from_pretrained(
    repo,
    subfolder="preprocessor",
    apply_ocr=False
)

# LayoutLMv3 encoder
layout_model = LayoutLMv3Model.from_pretrained(repo)
layout_model.eval()

# T5 decoder & tokenizer
t5_model = T5ForConditionalGeneration.from_pretrained(repo)
t5_model.eval()
tokenizer = AutoTokenizer.from_pretrained(
    repo, subfolder="preprocessor"
)

# Projection head: load from checkpoint
ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
ckpt      = torch.load(ckpt_file, map_location="cpu")
proj_state= ckpt["projection"]
projection = torch.nn.Sequential(
    torch.nn.Linear(768, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU()
)
projection.load_state_dict(proj_state)
projection.eval()

# Move models to CPU (Spaces are CPU-only)
device = torch.device("cpu")
layout_model.to(device)
t5_model.to(device)
projection.to(device)

# ── 2) INFERENCE FUNCTION ─────────────────────────────────────────────
def infer(image_path, json_file):
    img_name = os.path.basename(image_path)

    # 2.a) Load NDJSON file (one JSON object per line)
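    # A line is expected to look like this (hypothetical example; only these
    # keys are read below):
    # {"img_name": "page_001.png",
    #  "src_word_list": ["Hello", "world"],
    #  "src_wordbox_list": [[10, 12, 52, 24], [58, 12, 101, 24]]}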
    data = []
    # gr.File may hand us a tempfile object or a plain path string,
    # depending on the Gradio version
    json_path = getattr(json_file, "name", json_file)
    with open(json_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            data.append(json.loads(line))

    # 2.b) Find entry matching uploaded image
    entry = next((e for e in data if e.get("img_name") == img_name), None)
    if entry is None:
        return f"❌ No JSON entry found for image '{img_name}'"

    words = entry.get("src_word_list", [])
    boxes = entry.get("src_wordbox_list", [])

    # 2.c) Open and preprocess the image + tokens + boxes
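    # Note: the LayoutLMv3 processor expects word boxes on a 0-1000
    # normalized scale; we assume the NDJSON already stores them that way.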
    img = Image.open(image_path).convert("RGB")
    encoding = processor(
        [img], [words], boxes=[boxes],
        return_tensors="pt", padding=True, truncation=True
    )
    pixel_values   = encoding.pixel_values.to(device)
    input_ids      = encoding.input_ids.to(device)
    attention_mask = encoding.attention_mask.to(device)
    bbox           = encoding.bbox.to(device)

    # 2.d) Forward pass
    with torch.no_grad():
        # LayoutLMv3 encoding
        lm_out = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox
        )
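        # LayoutLMv3's last_hidden_state holds text-token states followed by
        # image-patch states; slice off the patches so the text features
        # stay aligned with attention_mask.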
        seq_len    = input_ids.size(1)
        text_feats = lm_out.last_hidden_state[:, :seq_len, :]

        # Projection -> T5 decoding
        proj_feats = projection(text_feats)
        gen_ids = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=512,
            decoder_start_token_id=t5_model.config.decoder_start_token_id
        )

    # Decode to text
    result = tokenizer.batch_decode(
        gen_ids, skip_special_tokens=True
    )[0]
    return result

# ── 3) GRADIO UI ───────────────────────────────────────────────────────
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.File(label="Upload JSON (NDJSON)")
    ],
    outputs="text",
    title="OCR Reorder Pipeline"
)

if __name__ == "__main__":
    demo.launch(share=True)