AkashDataScience committed on
Commit
be84858
·
1 Parent(s): 1b60d22

Adding smoldocling OCR

Browse files
Files changed (1) hide show
  1. app.py +56 -3
app.py CHANGED
@@ -6,6 +6,14 @@ from docling.datamodel.base_models import InputFormat
6
  from paddleocr import PPStructureV3
7
  from pdf2image import convert_from_path
8
  import numpy as np
 
 
 
 
 
 
 
 
9
 
10
  pipeline_options = PdfPipelineOptions(enable_remote_services=True)
11
  converter = DocumentConverter(
@@ -14,6 +22,15 @@ converter = DocumentConverter(
14
  }
15
  )
16
 
 
 
 
 
 
 
 
 
 
17
  def get_pdf_page_count(pdf_path):
18
  reader = PdfReader(pdf_path)
19
  return len(reader.pages)
@@ -24,7 +41,6 @@ def get_docling_ocr(pdf_path, page_num):
24
  return markdown_text_docling
25
 
26
  def get_paddle_ocr(page_image):
27
- pipeline = PPStructureV3()
28
  output = pipeline.predict(input=np.array(page_image))
29
 
30
  markdown_list = []
@@ -36,13 +52,49 @@ def get_paddle_ocr(page_image):
36
  markdown_text_paddleOCR = pipeline.concatenate_markdown_pages(markdown_list)
37
  return markdown_text_paddleOCR
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def inference(pdf_path, page_num):
40
  docling_ocr = get_docling_ocr(pdf_path, page_num)
41
  # Extract the requested page (page_num) as an image
42
  images = convert_from_path(pdf_path, first_page=page_num, last_page=page_num)
43
  page_image = images[0]
44
  paddle_ocr = get_paddle_ocr(page_image)
45
- return docling_ocr, paddle_ocr
 
46
 
47
  title = "OCR Arena"
48
  description = "A simple Gradio interface to extract text from PDFs and compare OCR models"
@@ -66,11 +118,12 @@ with gr.Blocks(theme=gr.themes.Glass()) as demo:
66
  clear_btn = gr.ClearButton(components=[pdf, page_num])
67
  submit_btn = gr.Button("Submit", variant='primary')
68
 
69
- submit_btn.click(inference, inputs=[pdf, page_num], outputs=[docling_ocr_out, paddle_ocr_out])
70
 
71
  with gr.Column():
72
  docling_ocr_out = gr.Textbox(label="Docling OCR Output", type="text")
73
  paddle_ocr_out = gr.Textbox(label="Paddle OCR Output", type="text")
 
74
 
75
  examples_obj = gr.Examples(examples=examples, inputs=[pdf])
76
 
 
6
  from paddleocr import PPStructureV3
7
  from pdf2image import convert_from_path
8
  import numpy as np
9
+ import torch
10
+ from docling_core.types.doc import DoclingDocument
11
+ from docling_core.types.doc.document import DocTagsDocument
12
+ from transformers import AutoProcessor, AutoModelForVision2Seq
13
+ from transformers.image_utils import load_image
14
+ from pathlib import Path
15
+
16
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  pipeline_options = PdfPipelineOptions(enable_remote_services=True)
19
  converter = DocumentConverter(
 
22
  }
23
  )
24
 
25
+ pipeline = PPStructureV3()
26
+
27
+ processor = AutoProcessor.from_pretrained("ds4sd/SmolDocling-256M-preview")
28
+ model = AutoModelForVision2Seq.from_pretrained(
29
+ "ds4sd/SmolDocling-256M-preview",
30
+ torch_dtype=torch.bfloat16,
31
+ _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager",
32
+ ).to(DEVICE)
33
+
34
  def get_pdf_page_count(pdf_path):
35
  reader = PdfReader(pdf_path)
36
  return len(reader.pages)
 
41
  return markdown_text_docling
42
 
43
  def get_paddle_ocr(page_image):
 
44
  output = pipeline.predict(input=np.array(page_image))
45
 
46
  markdown_list = []
 
52
  markdown_text_paddleOCR = pipeline.concatenate_markdown_pages(markdown_list)
53
  return markdown_text_paddleOCR
54
 
55
+ def get_smoldocling_ocr(page_image):
56
+ image = load_image(page_image)
57
+
58
+ # Create input messages
59
+ messages = [
60
+ {
61
+ "role": "user",
62
+ "content": [
63
+ {"type": "image"},
64
+ {"type": "text", "text": "Convert this page to docling."}
65
+ ]
66
+ },
67
+ ]
68
+
69
+ # Prepare inputs
70
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
71
+ inputs = processor(text=prompt, images=[image], return_tensors="pt")
72
+ inputs = inputs.to(DEVICE)
73
+
74
+ generated_ids = model.generate(**inputs, max_new_tokens=8192)
75
+ prompt_length = inputs.input_ids.shape[1]
76
+ trimmed_generated_ids = generated_ids[:, prompt_length:]
77
+ doctags = processor.batch_decode(
78
+ trimmed_generated_ids,
79
+ skip_special_tokens=False,
80
+ )[0].lstrip()
81
+
82
+ # Populate document
83
+ doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
84
+ doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
85
+
86
+ markdown_text_smoldocling = doc.export_to_markdown()
87
+ return markdown_text_smoldocling
88
+
89
+
90
  def inference(pdf_path, page_num):
91
  docling_ocr = get_docling_ocr(pdf_path, page_num)
92
  # Extract the requested page (page_num) as an image
93
  images = convert_from_path(pdf_path, first_page=page_num, last_page=page_num)
94
  page_image = images[0]
95
  paddle_ocr = get_paddle_ocr(page_image)
96
+ smoldocling_ocr = get_smoldocling_ocr(page_image)
97
+ return docling_ocr, paddle_ocr, smoldocling_ocr
98
 
99
  title = "OCR Arena"
100
  description = "A simple Gradio interface to extract text from PDFs and compare OCR models"
 
118
  clear_btn = gr.ClearButton(components=[pdf, page_num])
119
  submit_btn = gr.Button("Submit", variant='primary')
120
 
121
+ submit_btn.click(inference, inputs=[pdf, page_num], outputs=[docling_ocr_out, paddle_ocr_out, smoldocling_ocr_out])
122
 
123
  with gr.Column():
124
  docling_ocr_out = gr.Textbox(label="Docling OCR Output", type="text")
125
  paddle_ocr_out = gr.Textbox(label="Paddle OCR Output", type="text")
126
+ smoldocling_ocr_out = gr.Textbox(label="SmolDocling OCR Output", type="text")
127
 
128
  examples_obj = gr.Examples(examples=examples, inputs=[pdf])
129