Pavan147 commited on
Commit
df46f51
·
verified ·
1 Parent(s): 0214886

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -54
app.py CHANGED
@@ -1,65 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from transformers import AutoProcessor, AutoModelForImageTextToText
3
- from PIL import Image
 
 
 
4
  import re
 
5
 
6
- # Load SmolDocling model & processor once
7
  processor = AutoProcessor.from_pretrained("ds4sd/SmolDocling-256M-preview")
8
- model = AutoModelForImageTextToText.from_pretrained("ds4sd/SmolDocling-256M-preview")
9
-
10
- def extract_fcel_values_from_image(image, prompt_text):
11
- """Run SmolDocling on an image and return numeric values inside <fcel> tags."""
12
- # Prepare prompt for the model
13
- messages = [
14
- {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
15
- ]
16
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
17
- inputs = processor(text=prompt, images=[image], return_tensors="pt")
18
-
19
- # Generate output
20
- outputs = model.generate(**inputs, max_new_tokens=2048)
21
- prompt_length = inputs.input_ids.shape[1]
22
- generated = outputs[:, prompt_length:]
23
- result = processor.batch_decode(generated, skip_special_tokens=False)[0]
24
- clean_text = result.replace("<end_of_utterance>", "").strip()
25
-
26
- # Extract only <fcel> values
27
- values = re.findall(r"<fcel>([\d.]+)", clean_text)
28
- values = [float(v) for v in values] # convert to floats
29
-
30
- return values, clean_text
31
-
32
- def compare_images(image1, image2, prompt_text):
33
- # Extract fcel values from both images
34
- values1, raw1 = extract_fcel_values_from_image(image1, prompt_text)
35
- values2, raw2 = extract_fcel_values_from_image(image2, prompt_text)
36
-
37
- # Calculate accuracy
38
- if len(values1) == len(values2) and values1 == values2:
39
- accuracy = 100.0
40
- else:
41
- matches = sum(1 for a, b in zip(values1, values2) if a == b)
42
- total = max(len(values1), len(values2))
43
- accuracy = (matches / total) * 100 if total > 0 else 0
44
-
45
- return {
46
- # "Extracted Values 1": values1,
47
- # "Extracted Values 2": values2,
48
- "Accuracy (%)": accuracy
49
  }
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # Gradio UI
52
  demo = gr.Interface(
53
- fn=compare_images,
54
- inputs=[
55
- gr.Image(type="pil", label="Upload First Table Image"),
56
- gr.Image(type="pil", label="Upload Second Table Image"),
57
- gr.Textbox(lines=1, placeholder="Enter prompt (e.g. Extract table as OTSL)", label="Prompt")
58
- ],
59
- outputs="json",
60
- title="Table Data Accuracy Checker (SmolDocling)",
61
- description="Uploads two table images, extracts only <fcel> values from OTSL output, and compares them for accuracy."
62
  )
63
 
64
- demo.launch()
 
65
 
 
1
+ # import gradio as gr
2
+ # from transformers import AutoProcessor, AutoModelForImageTextToText
3
+ # from PIL import Image
4
+ # import re
5
+
6
+ # # Load SmolDocling model & processor once
7
+ # processor = AutoProcessor.from_pretrained("ds4sd/SmolDocling-256M-preview")
8
+ # model = AutoModelForImageTextToText.from_pretrained("ds4sd/SmolDocling-256M-preview")
9
+
10
+ # def extract_fcel_values_from_image(image, prompt_text):
11
+ # """Run SmolDocling on an image and return numeric values inside <fcel> tags."""
12
+ # # Prepare prompt for the model
13
+ # messages = [
14
+ # {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
15
+ # ]
16
+ # prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
17
+ # inputs = processor(text=prompt, images=[image], return_tensors="pt")
18
+
19
+ # # Generate output
20
+ # outputs = model.generate(**inputs, max_new_tokens=2048)
21
+ # prompt_length = inputs.input_ids.shape[1]
22
+ # generated = outputs[:, prompt_length:]
23
+ # result = processor.batch_decode(generated, skip_special_tokens=False)[0]
24
+ # clean_text = result.replace("<end_of_utterance>", "").strip()
25
+
26
+ # # Extract only <fcel> values
27
+ # values = re.findall(r"<fcel>([\d.]+)", clean_text)
28
+ # values = [float(v) for v in values] # convert to floats
29
+
30
+ # return values, clean_text
31
+
32
+ # def compare_images(image1, image2, prompt_text):
33
+ # # Extract fcel values from both images
34
+ # values1, raw1 = extract_fcel_values_from_image(image1, prompt_text)
35
+ # values2, raw2 = extract_fcel_values_from_image(image2, prompt_text)
36
+
37
+ # # Calculate accuracy
38
+ # if len(values1) == len(values2) and values1 == values2:
39
+ # accuracy = 100.0
40
+ # else:
41
+ # matches = sum(1 for a, b in zip(values1, values2) if a == b)
42
+ # total = max(len(values1), len(values2))
43
+ # accuracy = (matches / total) * 100 if total > 0 else 0
44
+
45
+ # return {
46
+ # # "Extracted Values 1": values1,
47
+ # # "Extracted Values 2": values2,
48
+ # "Accuracy (%)": accuracy
49
+ # }
50
+
51
+ # # Gradio UI
52
+ # demo = gr.Interface(
53
+ # fn=compare_images,
54
+ # inputs=[
55
+ # gr.Image(type="pil", label="Upload First Table Image"),
56
+ # gr.Image(type="pil", label="Upload Second Table Image"),
57
+ # gr.Textbox(lines=1, placeholder="Enter prompt (e.g. Extract table as OTSL)", label="Prompt")
58
+ # ],
59
+ # outputs="json",
60
+ # title="Table Data Accuracy Checker (SmolDocling)",
61
+ # description="Uploads two table images, extracts only <fcel> values from OTSL output, and compares them for accuracy."
62
+ # )
63
+
64
+ # demo.launch()
65
+
66
  import gradio as gr
67
+ from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
68
+ from transformers.image_utils import load_image
69
+ from threading import Thread
70
+ import torch
71
+ import html
72
  import re
73
+ from PIL import Image, ImageOps
74
 
75
+ # Load model & processor once at startup
76
  processor = AutoProcessor.from_pretrained("ds4sd/SmolDocling-256M-preview")
77
+ model = AutoModelForVision2Seq.from_pretrained("ds4sd/SmolDocling-256M-preview", torch_dtype=torch.bfloat16).to("cuda")
78
+
79
+ def add_random_padding(image, min_percent=0.1, max_percent=0.10):
80
+ image = image.convert("RGB")
81
+ width, height = image.size
82
+ pad_w_percent = random.uniform(min_percent, max_percent)
83
+ pad_h_percent = random.uniform(min_percent, max_percent)
84
+ pad_w = int(width * pad_w_percent)
85
+ pad_h = int(height * pad_h_percent)
86
+ corner_pixel = image.getpixel((0, 0)) # Top-left corner
87
+ padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
88
+ return padded_image
89
+
90
+ def extract_table(image_file):
91
+ # Load image
92
+ image = load_image(image_file)
93
+
94
+ # Optionally add padding if needed for model robustness (optional)
95
+ image = add_random_padding(image)
96
+
97
+ # Fixed prompt to extract table only (modify if needed)
98
+ text = "Convert this table to OTSL."
99
+
100
+ # Build the message structure for processor
101
+ resulting_messages = [{
102
+ "role": "user",
103
+ "content": [{"type": "image"}] + [{"type": "text", "text": text}]
104
+ }]
105
+
106
+ prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
107
+ inputs = processor(text=prompt, images=[image], return_tensors="pt").to('cuda')
108
+
109
+ generation_args = {
110
+ "input_ids": inputs.input_ids,
111
+ "pixel_values": inputs.pixel_values,
112
+ "attention_mask": inputs.attention_mask,
113
+ "max_new_tokens": 8192,
114
+ "num_return_sequences": 1,
 
 
 
115
  }
116
 
117
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
118
+ generation_args["streamer"] = streamer
119
+
120
+ thread = Thread(target=model.generate, kwargs=generation_args)
121
+ thread.start()
122
+
123
+ output_text = ""
124
+ for new_text in streamer:
125
+ output_text += new_text
126
+
127
+ # Clean and return output
128
+ cleaned_output = output_text.replace("<end_of_utterance>", "").strip()
129
+
130
+ # Optionally convert <chart> tags to <otsl> if present
131
+ if "<chart>" in cleaned_output:
132
+ cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
133
+ cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
134
+
135
+ return cleaned_output or "No table found or unable to extract."
136
+
137
  # Gradio UI
138
  demo = gr.Interface(
139
+ fn=extract_table,
140
+ inputs=gr.Image(type="filepath", label="Upload Table Image"),
141
+ outputs=gr.Textbox(label="Extracted Table (OTSL Format)"),
142
+ title="Table Extraction from Image using SmolDocling-256M",
143
+ description="Upload an image containing a table. The model will extract the table and output it in OTSL format."
 
 
 
 
144
  )
145
 
146
+ demo.launch(debug=True)
147
+
148