xiaoyao9184 committed (verified)
Commit 4cdb157 · Parent: a19db79

Synced repo using 'sync_with_huggingface' Github Action

Files changed (2):
  1. gradio_app.py +75 -152
  2. requirements.txt +4 -2
gradio_app.py CHANGED
```diff
@@ -16,13 +16,14 @@ from typing import List
 import pypdfium2
 import gradio as gr
 
-from surya.common.surya.schema import TaskNames
 from surya.models import load_predictors
 
 from surya.debug.draw import draw_polys_on_image, draw_bboxes_on_image
 
 from surya.debug.text import draw_text_on_image
-from PIL import Image, ImageDraw
+
+from PIL import Image
+from surya.recognition.languages import CODE_TO_LANGUAGE, replace_lang_with_code
 from surya.table_rec import TableResult
 from surya.detection import TextDetectionResult
 from surya.recognition import OCRResult
```
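For context on the two helpers imported here, a minimal sketch of how the language list flows from the UI into recognition. The printed values are assumptions inferred from how `replace_lang_with_code` and `CODE_TO_LANGUAGE` are used later in this diff:

```python
from surya.recognition.languages import CODE_TO_LANGUAGE, replace_lang_with_code

langs = ["English", "French"]  # values as picked in the lang_dd dropdown
replace_lang_with_code(langs)  # mutates in place: full names -> language codes (assumed)
print(langs)                   # e.g. ['en', 'fr'] (assumed output)
print(CODE_TO_LANGUAGE["en"])  # 'English'
```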
```diff
@@ -32,18 +33,15 @@ from surya.common.util import rescale_bbox, expand_bbox
 
 
 # just copy from streamlit_app.py
-def ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
+def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
     from pdftext.extraction import plain_text_output
-
     with tempfile.NamedTemporaryFile(suffix=".pdf") as f:
         f.write(pdf_file.getvalue())
         f.seek(0)
 
         # Sample the text from the middle of the PDF
         page_middle = page_count // 2
-        page_range = range(
-            max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count)
-        )
+        page_range = range(max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count))
         text = plain_text_output(f.name, page_range=page_range)
 
     sample_gap = len(text) // max_samples
```
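The page-window arithmetic above, as a standalone worked example in pure Python:

```python
# Sample up to max_pages pages on each side of the middle page,
# clamped to the bounds of the document.
page_count = 40
max_pages = 15
page_middle = page_count // 2  # 20
page_range = range(max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count))
print(list(page_range)[:3], "...", list(page_range)[-1])  # [5, 6, 7] ... 34
```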
```diff
@@ -56,14 +54,24 @@ def ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=1
     # Split the text into samples for the model
     samples = []
     for i in range(0, len(text), sample_gap):
-        samples.append(text[i : i + sample_len])
+        samples.append(text[i:i + sample_len])
 
     results = predictors["ocr_error"](samples)
     label = "This PDF has good text."
-    if results.labels.count("bad") / len(results.labels) > 0.2:
+    if results.labels.count("bad") / len(results.labels) > .2:
         label = "This PDF may have garbled or bad OCR text."
     return label, results.labels
 
+# just copy from streamlit_app.py
+def inline_detection(img) -> (Image.Image, TextDetectionResult):
+    text_pred = predictors["detection"]([img])[0]
+    text_boxes = [p.bbox for p in text_pred.bboxes]
+
+    inline_pred = predictors["inline_detection"]([img], [text_boxes], include_maps=True)[0]
+    inline_polygons = [p.polygon for p in inline_pred.bboxes]
+    det_img = draw_polys_on_image(inline_polygons, img.copy(), color='blue')
+    return det_img, text_pred, inline_pred
+
 # just copy from streamlit_app.py
 def text_detection(img) -> (Image.Image, TextDetectionResult):
     text_pred = predictors["detection"]([img])[0]
```
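A minimal sketch of the sampling-and-threshold heuristic above, with the `predictors["ocr_error"]` model stubbed out by a hypothetical null-byte check (the real clamping of `sample_gap` sits outside this hunk; it is simplified here):

```python
def looks_garbled(text: str, sample_len: int = 512, max_samples: int = 10) -> bool:
    # Split the text into fixed-length samples, classify each one,
    # and flag the document when more than 20% come back "bad".
    sample_gap = max(len(text) // max_samples, sample_len)
    samples = [text[i:i + sample_len] for i in range(0, len(text), sample_gap)]
    labels = ["bad" if "\x00" in s else "good" for s in samples]  # stand-in classifier
    return labels.count("bad") / len(labels) > .2

print(looks_garbled("clean text " * 500))   # False
print(looks_garbled("garbl\x00ed " * 200))  # True
```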
```diff
@@ -75,35 +83,27 @@ def text_detection(img) -> (Image.Image, TextDetectionResult):
 def layout_detection(img) -> (Image.Image, LayoutResult):
     pred = predictors["layout"]([img])[0]
     polygons = [p.polygon for p in pred.bboxes]
-    labels = [
-        f"{p.label}-{p.position}-{round(p.top_k[p.label], 2)}" for p in pred.bboxes
-    ]
-    layout_img = draw_polys_on_image(
-        polygons, img.copy(), labels=labels, label_font_size=18
-    )
+    labels = [f"{p.label}-{p.position}" for p in pred.bboxes]
+    layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
     return layout_img, pred
 
 # just copy from streamlit_app.py
-def table_recognition(
-    img, highres_img, skip_table_detection: bool
-) -> (Image.Image, List[TableResult]):
+def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
     if skip_table_detection:
         layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
         table_imgs = [highres_img]
     else:
         _, layout_pred = layout_detection(img)
-        layout_tables_lowres = [
-            line.bbox
-            for line in layout_pred.bboxes
-            if line.label in ["Table", "TableOfContents"]
-        ]
+        layout_tables_lowres = [l.bbox for l in layout_pred.bboxes if l.label in ["Table", "TableOfContents"]]
         table_imgs = []
         layout_tables = []
         for tb in layout_tables_lowres:
             highres_bbox = rescale_bbox(tb, img.size, highres_img.size)
             # Slightly expand the box
             highres_bbox = expand_bbox(highres_bbox)
-            table_imgs.append(highres_img.crop(highres_bbox))
+            table_imgs.append(
+                highres_img.crop(highres_bbox)
+            )
             layout_tables.append(highres_bbox)
 
     table_preds = predictors["table_rec"](table_imgs)
```
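For reference, a plain-Python sketch of the low-res to high-res box mapping used above. The assumption (not confirmed by this diff) is that `surya.common.util.rescale_bbox` scales each coordinate by the ratio of the two image sizes:

```python
def rescale_bbox_sketch(bbox, old_size, new_size):
    # Scale an [x0, y0, x1, y1] box from one image resolution to another.
    sx = new_size[0] / old_size[0]
    sy = new_size[1] / old_size[1]
    x0, y0, x1, y1 = bbox
    return [x0 * sx, y0 * sy, x1 * sx, y1 * sy]

print(rescale_bbox_sketch([10, 20, 110, 220], (612, 792), (1224, 1584)))
# [20.0, 40.0, 220.0, 440.0]
```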
```diff
@@ -115,72 +115,29 @@ def table_recognition(
         colors = []
 
         for item in results.cells:
-            adjusted_bboxes.append(
-                [
-                    (item.bbox[0] + table_bbox[0]),
-                    (item.bbox[1] + table_bbox[1]),
-                    (item.bbox[2] + table_bbox[0]),
-                    (item.bbox[3] + table_bbox[1]),
-                ]
-            )
+            adjusted_bboxes.append([
+                (item.bbox[0] + table_bbox[0]),
+                (item.bbox[1] + table_bbox[1]),
+                (item.bbox[2] + table_bbox[0]),
+                (item.bbox[3] + table_bbox[1])
+            ])
             labels.append(item.label)
             if "Row" in item.label:
                 colors.append("blue")
             else:
                 colors.append("red")
-        table_img = draw_bboxes_on_image(
-            adjusted_bboxes,
-            highres_img,
-            labels=labels,
-            label_font_size=18,
-            color=colors,
-        )
+        table_img = draw_bboxes_on_image(adjusted_bboxes, highres_img, labels=labels, label_font_size=18, color=colors)
     return table_img, table_preds
 
 # just copy from streamlit_app.py
-def ocr(
-    img: Image.Image,
-    highres_img: Image.Image,
-    skip_text_detection: bool = False,
-    recognize_math: bool = True,
-    with_bboxes: bool = True,
-) -> (Image.Image, OCRResult):
-    if skip_text_detection:
-        img = highres_img
-        bboxes = [[[0, 0, img.width, img.height]]]
-    else:
-        bboxes = None
-
-    if with_bboxes:
-        tasks = [TaskNames.ocr_with_boxes]
-    else:
-        tasks = [TaskNames.ocr_without_boxes]
-
-    img_pred = predictors["recognition"](
-        [img],
-        task_names=tasks,
-        bboxes=bboxes,
-        det_predictor=predictors["detection"],
-        highres_images=[highres_img],
-        math_mode=recognize_math,
-        return_words=True,
-    )[0]
-
-    bboxes = [line.bbox for line in img_pred.text_lines]
-    text = [line.text for line in img_pred.text_lines]
-    rec_img = draw_text_on_image(bboxes, text, img.size)
-
-    word_boxes = []
-    for line in img_pred.text_lines:
-        if line.words:
-            word_boxes.extend([word.bbox for word in line.words])
-
-    box_img = img.copy()
-    draw = ImageDraw.Draw(box_img)
-    for word_box in word_boxes:
-        draw.rectangle(word_box, outline="red", width=2)
-
-    return rec_img, img_pred, box_img
+def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult):
+    replace_lang_with_code(langs)
+    img_pred = predictors["recognition"]([img], [langs], predictors["detection"], highres_images=[highres_img])[0]
+
+    bboxes = [l.bbox for l in img_pred.text_lines]
+    text = [l.text for l in img_pred.text_lines]
+    rec_img = draw_text_on_image(bboxes, text, img.size, langs)
+    return rec_img, img_pred
 
 def open_pdf(pdf_file):
     return pypdfium2.PdfDocument(pdf_file)
```
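The cell-offset arithmetic above, in isolation: cell boxes are predicted in the cropped table's coordinate frame, so they are shifted back by the table's top-left corner to land on the full high-res page:

```python
table_bbox = [100, 400, 500, 700]  # table location on the page
cell_bbox = [10, 20, 60, 40]       # cell location inside the crop
adjusted = [
    cell_bbox[0] + table_bbox[0],  # x0 shifted by table x-origin
    cell_bbox[1] + table_bbox[1],  # y0 shifted by table y-origin
    cell_bbox[2] + table_bbox[0],
    cell_bbox[3] + table_bbox[1],
]
print(adjusted)  # [110, 420, 160, 440]
```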
```diff
@@ -230,38 +187,20 @@ with gr.Blocks(title="Surya") as demo:
             in_num = gr.Slider(label="Page number", minimum=1, maximum=100, value=1, step=1)
             in_img = gr.Image(label="Select page of Image", type="pil", sources=None)
 
-            ocr_errors_btn = gr.Button("Run bad PDF text detection")
             text_det_btn = gr.Button("Run Text Detection")
+            inline_det_btn = gr.Button("Run Inline Math Detection")
             layout_det_btn = gr.Button("Run Layout Analysis")
 
-            skip_text_detection_ckb = gr.Checkbox(label="Skip text detection", value=False, info="OCR only: Skip text detection and treat the whole image as a single line.")
-            recognize_math_ckb = gr.Checkbox(label="Recognize math in OCR", value=True, info="Enable math mode in OCR - this will recognize math.")
-            ocr_with_boxes_ckb = gr.Checkbox(label="OCR with boxes", value=True, info="Enable OCR with boxes - this will predict character-level boxes.")
+            lang_dd = gr.Dropdown(label="Languages", choices=sorted(list(CODE_TO_LANGUAGE.values())), multiselect=True, max_choices=4, info="Select the languages in the image (if known) to improve OCR accuracy. Optional.")
             text_rec_btn = gr.Button("Run OCR")
 
+            use_pdf_boxes_ckb = gr.Checkbox(label="Use PDF table boxes", value=True, info="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.")
             skip_table_detection_ckb = gr.Checkbox(label="Skip table detection", value=False, info="Table recognition only: Skip table detection and treat the whole image/page as a table.")
             table_rec_btn = gr.Button("Run Table Rec")
+
+            ocr_errors_btn = gr.Button("Run bad PDF text detection")
         with gr.Column():
-            result_img = gr.Gallery(label="Result images", show_label=True,
-                elem_id="gallery", columns=[1], rows=[2], object_fit="contain", height="auto")
-
-            gr.HTML("""
-                <style>
-                    #gallery {
-                        height: auto !important;
-                        max-height: none !important;
-                        overflow: visible !important;
-                    }
-                    #gallery .gallery-item {
-                        flex-direction: column !important;
-                    }
-                    #gallery .gallery-item img {
-                        width: 100% !important;
-                        height: auto !important;
-                        object-fit: contain !important;
-                    }
-                </style>
-            """)
+            result_img = gr.Image(label="Result image")
             result_json = gr.JSON(label="Result json")
 
     def show_image(file, num=1):
```
  def show_image(file, num=1):
@@ -290,90 +229,74 @@ with gr.Blocks(title="Surya") as demo:
290
 
291
  # Run Text Detection
292
  def text_det_img(pil_image):
293
- det_img, pred = text_detection(pil_image)
294
- det_json = pred.model_dump(exclude=["heatmap", "affinity_map"])
295
- return (
296
- gr.update(label="Result image: text detected", value=[det_img], rows=[1], height=det_img.height),
297
- gr.update(label="Result json: " + str(len(det_json['bboxes'])) + " text boxes detected", value=det_json)
298
- )
299
  text_det_btn.click(
300
  fn=text_det_img,
301
  inputs=[in_img],
302
  outputs=[result_img, result_json]
303
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
  # Run layout
306
  def layout_det_img(pil_image):
307
  layout_img, pred = layout_detection(pil_image)
308
- layout_json = pred.model_dump(exclude=["segmentation_map"])
309
- return (
310
- gr.update(label="Result image: layout detected", value=[layout_img], rows=[1], height=layout_img.height),
311
- gr.update(label="Result json: " + str(len(layout_json['bboxes'])) + " layout labels detected", value=layout_json)
312
- )
313
  layout_det_btn.click(
314
  fn=layout_det_img,
315
  inputs=[in_img],
316
  outputs=[result_img, result_json]
317
  )
318
-
319
  # Run OCR
320
- def text_rec_img(pil_image, in_file, page_number, skip_text_detection, recognize_math, ocr_with_boxes):
321
  if in_file.endswith('.pdf'):
322
  pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
323
  else:
324
  pil_image_highres = pil_image
325
- rec_img, pred, box_img = ocr(
326
- pil_image,
327
- pil_image_highres,
328
- skip_text_detection,
329
- recognize_math,
330
- with_bboxes=ocr_with_boxes,
331
- )
332
- text_img = [(rec_img, "Text"), (box_img, "Boxes")]
333
- text_json = pred.model_dump()
334
- return (
335
- gr.update(label="Result image: text recognized", value=text_img, rows=[2], height=rec_img.height + box_img.height),
336
- gr.update(label="Result json: " + str(len(text_json['text_lines'])) + " text lines recognized", value=text_json)
337
- )
338
  text_rec_btn.click(
339
  fn=text_rec_img,
340
- inputs=[in_img, in_file, in_num, skip_text_detection_ckb, recognize_math_ckb, ocr_with_boxes_ckb],
341
  outputs=[result_img, result_json]
342
  )
343
-
344
  # Run Table Recognition
345
- def table_rec_img(pil_image, in_file, page_number, skip_table_detection):
346
  if in_file.endswith('.pdf'):
347
  pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
348
  else:
349
  pil_image_highres = pil_image
350
  table_img, pred = table_recognition(pil_image, pil_image_highres, skip_table_detection)
351
- table_json = [p.model_dump() for p in pred]
352
- return (
353
- gr.update(label="Result image: table recognized", value=[table_img], rows=[1], height=table_img.height),
354
- gr.update(label="Result json: " + str(len(table_json)) + " table tree recognized", value=table_json)
355
- )
356
  table_rec_btn.click(
357
  fn=table_rec_img,
358
- inputs=[in_img, in_file, in_num, skip_table_detection_ckb],
359
  outputs=[result_img, result_json]
360
  )
361
-
362
  # Run bad PDF text detection
363
- def ocr_errors_pdf(in_file):
364
- if not in_file.endswith('.pdf'):
 
 
365
  raise gr.Error("This feature only works with PDFs.", duration=5)
366
- page_count = page_counter(in_file)
367
- io_file = io.BytesIO(open(in_file.name, "rb").read())
368
- layout_label, layout_json = ocr_errors(io_file, page_count)
369
- return (
370
- gr.update(label="Result image: NONE", value=None),
371
- gr.update(label="Result json: " + layout_label, value=layout_json)
372
- )
373
  ocr_errors_btn.click(
374
  fn=ocr_errors_pdf,
375
- inputs=[in_file],
376
- outputs=[result_img, result_json]
377
  )
378
 
379
  if __name__ == "__main__":
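A hedged sketch of what `page_counter` (referenced above but not shown in this diff) might look like, using the pypdfium2 API the app already imports; `PdfDocument` implements `len()` as the page count:

```python
import pypdfium2

def page_counter_sketch(pdf_path: str) -> int:
    # Open the document, read the page count, and release the handle.
    doc = pypdfium2.PdfDocument(pdf_path)
    try:
        return len(doc)
    finally:
        doc.close()
```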
 
requirements.txt CHANGED
```diff
@@ -1,9 +1,11 @@
-torch==2.7.0
-surya-ocr==0.14.5
+torch==2.5.1
+surya-ocr==0.13.1
 gradio==5.8.0
 huggingface-hub==0.26.3
 # gradio app need pdftext for run_ocr_errors
 pdftext==0.5.0
 
+# fix compatibility issue keep same with poetry lock file
+transformers==4.48.1
 # fix https://github.com/gradio-app/gradio/issues/10662
 pydantic==2.10.5
```
 
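A small stdlib-only sanity check that an environment matches these pins (assumes the packages are installed; names are the PyPI distribution names):

```python
from importlib.metadata import version

# Compare installed versions against the pins from requirements.txt.
for pkg, want in [("torch", "2.5.1"), ("surya-ocr", "0.13.1"),
                  ("transformers", "4.48.1"), ("pydantic", "2.10.5")]:
    got = version(pkg)
    assert got == want, f"{pkg}: expected {want}, found {got}"
```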