File size: 5,637 Bytes
df46f51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d887fd5
df46f51
 
 
 
 
0214886
df46f51
99c8757
df46f51
99c8757
df46f51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0214886
 
df46f51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0214886
99c8757
df46f51
 
 
 
 
99c8757
 
df46f51
 
0214886
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# import gradio as gr
# from transformers import AutoProcessor, AutoModelForImageTextToText
# from PIL import Image
# import re

# # Load SmolDocling model & processor once
# processor = AutoProcessor.from_pretrained("ds4sd/SmolDocling-256M-preview")
# model = AutoModelForImageTextToText.from_pretrained("ds4sd/SmolDocling-256M-preview")

# def extract_fcel_values_from_image(image, prompt_text):
#     """Run SmolDocling on an image and return numeric values inside <fcel> tags."""
#     # Prepare prompt for the model
#     messages = [
#         {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
#     ]
#     prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
#     inputs = processor(text=prompt, images=[image], return_tensors="pt")

#     # Generate output
#     outputs = model.generate(**inputs, max_new_tokens=2048)
#     prompt_length = inputs.input_ids.shape[1]
#     generated = outputs[:, prompt_length:]
#     result = processor.batch_decode(generated, skip_special_tokens=False)[0]
#     clean_text = result.replace("<end_of_utterance>", "").strip()

#     # Extract only <fcel> values
#     values = re.findall(r"<fcel>([\d.]+)", clean_text)
#     values = [float(v) for v in values]  # convert to floats

#     return values, clean_text

# def compare_images(image1, image2, prompt_text):
#     # Extract fcel values from both images
#     values1, raw1 = extract_fcel_values_from_image(image1, prompt_text)
#     values2, raw2 = extract_fcel_values_from_image(image2, prompt_text)

#     # Calculate accuracy
#     if len(values1) == len(values2) and values1 == values2:
#         accuracy = 100.0
#     else:
#         matches = sum(1 for a, b in zip(values1, values2) if a == b)
#         total = max(len(values1), len(values2))
#         accuracy = (matches / total) * 100 if total > 0 else 0

#     return {
#         # "Extracted Values 1": values1,
#         # "Extracted Values 2": values2,
#         "Accuracy (%)": accuracy
#     }

# # Gradio UI
# demo = gr.Interface(
#     fn=compare_images,
#     inputs=[
#         gr.Image(type="pil", label="Upload First Table Image"),
#         gr.Image(type="pil", label="Upload Second Table Image"),
#         gr.Textbox(lines=1, placeholder="Enter prompt (e.g. Extract table as OTSL)", label="Prompt")
#     ],
#     outputs="json",
#     title="Table Data Accuracy Checker (SmolDocling)",
#     description="Uploads two table images, extracts only <fcel> values from OTSL output, and compares them for accuracy."
# )

# demo.launch()

import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import torch
import html
import re
from PIL import Image, ImageOps

# Load the processor and the model a single time, at import, so every request
# served by the UI reuses the same objects.
MODEL_ID = "ds4sd/SmolDocling-256M-preview"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
model = model.to("cuda")

def add_random_padding(image, min_percent=0.1, max_percent=0.10):
    """Return an RGB copy of *image* surrounded by a randomly sized border.

    The horizontal and vertical padding fractions are drawn independently
    from ``random.uniform(min_percent, max_percent)`` and scaled by the image
    width and height respectively. With the default arguments both bounds are
    equal, so the padding is always 10% of each dimension.

    Args:
        image: A PIL image (anything providing ``convert``, ``size`` and
            ``getpixel``).
        min_percent: Lower bound of the padding, as a fraction of the
            corresponding image dimension.
        max_percent: Upper bound of the padding, as a fraction of the
            corresponding image dimension.

    Returns:
        A new RGB image with ``pad_w`` pixels added on the left and right and
        ``pad_h`` pixels added on the top and bottom, filled with the colour
        of the original top-left pixel.
    """
    # Imported locally because the module-level imports never bring ``random``
    # into scope; without this the first call raised NameError.
    import random

    image = image.convert("RGB")
    width, height = image.size

    # Width fraction is drawn first, then height, each from the same range.
    pad_w = int(width * random.uniform(min_percent, max_percent))
    pad_h = int(height * random.uniform(min_percent, max_percent))

    # Use the top-left pixel as the fill so the border blends with the
    # image background instead of introducing a hard-coded colour.
    corner_pixel = image.getpixel((0, 0))
    return ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)

def extract_table(image_file):
    """Run SmolDocling on a table image and return the generated OTSL text.

    Args:
        image_file: Image reference accepted by ``load_image``; the UI in this
            file passes a file path. ``None`` means no image was supplied.

    Returns:
        The cleaned model output, or a short human-readable message when no
        image was supplied or the model produced no text.
    """
    # The UI submits None when the form is sent without an image; bail out
    # before touching the model instead of failing inside load_image.
    if image_file is None:
        return "No image provided. Please upload a table image."

    # Load the image and pad it before handing it to the model.
    image = add_random_padding(load_image(image_file))

    # Fixed prompt to extract table only (modify if needed)
    text = "Convert this table to OTSL."

    # Build the message structure for processor
    resulting_messages = [{
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": text}],
    }]

    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    # Move the inputs to wherever the model lives rather than assuming CUDA.
    inputs = processor(text=prompt, images=[image], return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = {
        "input_ids": inputs.input_ids,
        "pixel_values": inputs.pixel_values,
        "attention_mask": inputs.attention_mask,
        "max_new_tokens": 8192,
        "num_return_sequences": 1,
        "streamer": streamer,
    }

    # generate() runs in a worker thread and feeds the streamer; this thread
    # drains the streamer until generation finishes.
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()
    output_text = "".join(streamer)
    # Wait for the worker so no generation thread outlives the request.
    thread.join()

    # NOTE(review): with skip_special_tokens=True the end-of-utterance marker
    # is presumably already dropped by the streamer; the replace is kept as a
    # harmless safeguard - confirm against the tokenizer config.
    cleaned_output = output_text.replace("<end_of_utterance>", "").strip()

    # Optionally convert <chart> tags to <otsl> if present
    if "<chart>" in cleaned_output:
        cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
        # Drop the single tag that immediately follows the last <loc_500>.
        cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)

    return cleaned_output or "No table found or unable to extract."

# Gradio UI: a single image upload feeds extract_table, whose text result is
# shown in a textbox.
table_image_input = gr.Image(type="filepath", label="Upload Table Image")
otsl_text_output = gr.Textbox(label="Extracted Table (OTSL Format)")

demo = gr.Interface(
    fn=extract_table,
    inputs=table_image_input,
    outputs=otsl_text_output,
    title="Table Extraction from Image using SmolDocling-256M",
    description="Upload an image containing a table. The model will extract the table and output it in OTSL format.",
)

demo.launch(debug=True)