Svngoku commited on
Commit
a000dd7
·
verified ·
1 Parent(s): 2567893

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -166
app.py CHANGED
@@ -1,18 +1,17 @@
1
  import os
2
  import base64
3
  import gradio as gr
4
- from mistralai import Mistral, ImageURLChunk
5
- from mistralai.models import OCRResponse
6
- from typing import Union, List, Tuple
7
  import requests
8
  import shutil
9
  import time
10
  import pymupdf as fitz
11
  import logging
 
 
 
12
  from tenacity import retry, stop_after_attempt, wait_exponential
13
  from concurrent.futures import ThreadPoolExecutor
14
- import socket
15
- from requests.exceptions import ConnectionError, Timeout
16
 
17
  # Constants
18
  SUPPORTED_IMAGE_TYPES = [".jpg", ".png", ".jpeg"]
@@ -61,19 +60,15 @@ class OCRProcessor:
61
  def _save_uploaded_file(file_input: Union[str, bytes], filename: str) -> str:
62
  clean_filename = os.path.basename(filename).replace(os.sep, "_")
63
  file_path = os.path.join(UPLOAD_FOLDER, f"{int(time.time())}_{clean_filename}")
64
-
65
  try:
66
  if isinstance(file_input, str) and file_input.startswith("http"):
67
- logger.info(f"Downloading from URL: {file_input}")
68
  response = requests.get(file_input, timeout=30)
69
  response.raise_for_status()
70
  with open(file_path, 'wb') as f:
71
  f.write(response.content)
72
  elif isinstance(file_input, str) and os.path.exists(file_input):
73
- logger.info(f"Copying local file: {file_input}")
74
  shutil.copy2(file_input, file_path)
75
  else:
76
- logger.info(f"Saving file object: {filename}")
77
  with open(file_path, 'wb') as f:
78
  if hasattr(file_input, 'read'):
79
  shutil.copyfileobj(file_input, f)
@@ -81,7 +76,6 @@ class OCRProcessor:
81
  f.write(file_input)
82
  if not os.path.exists(file_path):
83
  raise FileNotFoundError(f"Failed to save file at {file_path}")
84
- logger.info(f"File saved to: {file_path}")
85
  return file_path
86
  except Exception as e:
87
  logger.error(f"Error saving file {filename}: {str(e)}")
@@ -94,7 +88,7 @@ class OCRProcessor:
94
  return base64.b64encode(image_file.read()).decode('utf-8')
95
  except Exception as e:
96
  logger.error(f"Error encoding image {image_path}: {str(e)}")
97
- raise ValueError("Failed to encode image")
98
 
99
  @staticmethod
100
  def _pdf_to_images(pdf_path: str) -> List[Tuple[str, str]]:
@@ -140,208 +134,179 @@ class OCRProcessor:
140
  document=ImageURLChunk(image_url=base64_url),
141
  include_image_base64=True
142
  )
143
- logger.info("OCR API call successful")
144
  return response
145
- except (ConnectionError, Timeout, socket.error) as e:
146
- logger.error(f"Network error during OCR API call: {str(e)}")
147
  raise
148
 
149
- def ocr_uploaded_pdf(self, pdf_file: Union[str, bytes]) -> Tuple[str, List[str]]:
150
- file_name = getattr(pdf_file, 'name', f"pdf_{int(time.time())}.pdf")
151
- logger.info(f"Processing uploaded PDF: {file_name}")
152
- try:
153
- self._check_file_size(pdf_file)
154
- pdf_path = self._save_uploaded_file(pdf_file, file_name)
155
-
156
- if not os.path.exists(pdf_path):
157
- raise FileNotFoundError(f"Saved PDF not found at: {pdf_path}")
158
-
159
- image_data = self._pdf_to_images(pdf_path)
 
 
 
 
 
 
160
  if not image_data:
161
- raise ValueError("No pages converted from PDF")
162
 
163
  ocr_results = []
164
  image_paths = [path for path, _ in image_data]
165
- for i, (_, encoded) in enumerate(image_data):
166
  response = self._call_ocr_api(encoded)
167
- markdown_with_images = self._get_combined_markdown_with_images(response, image_paths, i)
168
- ocr_results.append(markdown_with_images)
169
-
170
- return "\n\n".join(ocr_results), image_paths
171
- except Exception as e:
172
- return self._handle_error("uploaded PDF processing", e), []
173
 
174
- def ocr_pdf_url(self, pdf_url: str) -> Tuple[str, List[str]]:
175
- logger.info(f"Processing PDF URL: {pdf_url}")
176
- try:
177
- file_name = pdf_url.split('/')[-1] or f"pdf_{int(time.time())}.pdf"
178
- pdf_path = self._save_uploaded_file(pdf_url, file_name)
179
-
180
- if not os.path.exists(pdf_path):
181
- raise FileNotFoundError(f"Saved PDF not found at: {pdf_path}")
182
-
183
- image_data = self._pdf_to_images(pdf_path)
 
 
 
 
 
 
184
  if not image_data:
185
- raise ValueError("No pages converted from PDF")
186
 
187
  ocr_results = []
188
  image_paths = [path for path, _ in image_data]
189
- for i, (_, encoded) in enumerate(image_data):
190
  response = self._call_ocr_api(encoded)
191
- markdown_with_images = self._get_combined_markdown_with_images(response, image_paths, i)
192
- ocr_results.append(markdown_with_images)
193
-
194
- return "\n\n".join(ocr_results), image_paths
195
- except Exception as e:
196
- return self._handle_error("PDF URL processing", e), []
197
-
198
- def ocr_uploaded_image(self, image_file: Union[str, bytes]) -> Tuple[str, str]:
199
- file_name = getattr(image_file, 'name', f"image_{int(time.time())}.jpg")
200
- logger.info(f"Processing uploaded image: {file_name}")
201
- try:
202
- self._check_file_size(image_file)
203
- image_path = self._save_uploaded_file(image_file, file_name)
204
- encoded_image = self._encode_image(image_path)
205
- response = self._call_ocr_api(encoded_image)
206
- return self._get_combined_markdown_with_images(response), image_path
207
- except Exception as e:
208
- return self._handle_error("image processing", e), None
209
 
210
  @staticmethod
211
- def _get_combined_markdown_with_images(response: OCRResponse, image_paths: List[str] = None, page_index: int = None) -> str:
 
212
  markdown_parts = []
213
- for i, page in enumerate(response.pages):
214
- if page.markdown.strip():
215
- markdown = page.markdown
216
- logger.info(f"Page {i} markdown: {markdown}")
217
- if hasattr(page, 'images') and page.images:
218
- logger.info(f"Found {len(page.images)} images in page {i}")
219
- for img in page.images:
220
- if img.image_base64:
221
- logger.info(f"Replacing image {img.id} with base64")
222
- markdown = markdown.replace(
223
- f"![{img.id}]({img.id})",
224
- f"![{img.id}](data:image/png;base64,{img.image_base64})"
225
- )
226
- else:
227
- logger.warning(f"No base64 data for image {img.id}")
228
- if image_paths and page_index is not None and page_index < len(image_paths):
229
- local_encoded = OCRProcessor._encode_image(image_paths[page_index])
230
- markdown = markdown.replace(
231
- f"![{img.id}]({img.id})",
232
- f"![{img.id}](data:image/png;base64,{local_encoded})"
233
- )
234
- else:
235
- logger.warning(f"No images found in page {i}")
236
- # Replace known placeholders or append the local image
237
- if image_paths and page_index is not None and page_index < len(image_paths):
238
- local_encoded = OCRProcessor._encode_image(image_paths[page_index])
239
- # Replace placeholders like img-0.jpeg
240
- placeholder = f"img-{i}.jpeg"
241
- if placeholder in markdown:
242
- markdown = markdown.replace(
243
- placeholder,
244
- f"![Page {i} Image](data:image/png;base64,{local_encoded})"
245
- )
246
- else:
247
- # Append the image if no placeholder is found
248
- markdown += f"\n\n![Page {i} Image](data:image/png;base64,{local_encoded})"
249
- markdown_parts.append(markdown)
250
- return "\n\n".join(markdown_parts) or "No text or images detected"
251
-
252
- @staticmethod
253
- def _handle_error(context: str, error: Exception) -> str:
254
- logger.error(f"Error in {context}: {str(error)}")
255
- return f"**Error in {context}:** {str(error)}"
256
 
257
  def create_interface():
258
  css = """
259
  .output-markdown {font-size: 14px; max-height: 500px; overflow-y: auto;}
260
  .status {color: #666; font-style: italic;}
 
261
  """
262
-
263
  with gr.Blocks(title="Mistral OCR Demo", css=css) as demo:
264
- gr.Markdown("# Mistral OCR App\nUpload images or PDFs, or provide a PDF URL for OCR processing")
265
-
 
 
 
 
 
 
266
  with gr.Row():
267
- api_key = gr.Textbox(label="Mistral API Key", type="password", placeholder="Enter your API key")
268
  set_key_btn = gr.Button("Set API Key", variant="primary")
269
-
270
  processor_state = gr.State()
271
  status = gr.Markdown("Please enter API key", elem_classes="status")
272
 
273
  def init_processor(key):
274
  try:
275
  processor = OCRProcessor(key)
276
- return processor, "✅ API key validated successfully"
277
  except Exception as e:
278
  return None, f"❌ Error: {str(e)}"
 
 
279
 
280
- set_key_btn.click(
281
- fn=init_processor,
282
- inputs=api_key,
283
- outputs=[processor_state, status]
284
- )
285
-
286
- with gr.Tab("Image OCR"):
287
  with gr.Row():
288
- image_input = gr.File(
289
- label=f"Upload Image (max {MAX_FILE_SIZE/1024/1024}MB)",
290
- file_types=SUPPORTED_IMAGE_TYPES
291
- )
292
- image_preview = gr.Image(label="Preview", height=300)
293
- image_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown")
294
- process_image_btn = gr.Button("Process Image", variant="primary")
295
 
296
- def process_image(processor, image):
297
- if not processor or not image:
298
- return "Please set API key and upload an image", None
299
- return processor.ocr_uploaded_image(image)
300
 
301
- process_image_btn.click(
302
- fn=process_image,
303
- inputs=[processor_state, image_input],
304
- outputs=[image_output, image_preview]
 
305
  )
306
 
307
- with gr.Tab("PDF OCR"):
 
308
  with gr.Row():
309
- with gr.Column():
310
- pdf_input = gr.File(
311
- label=f"Upload PDF (max {MAX_FILE_SIZE/1024/1024}MB, {MAX_PDF_PAGES} pages)",
312
- file_types=SUPPORTED_PDF_TYPES
313
- )
314
- pdf_url_input = gr.Textbox(
315
- label="Or Enter PDF URL",
316
- placeholder="e.g., https://arxiv.org/pdf/2201.04234.pdf"
317
- )
318
- pdf_gallery = gr.Gallery(label="PDF Pages", height=300)
319
- pdf_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown")
320
- process_pdf_btn = gr.Button("Process PDF", variant="primary")
321
 
322
- def process_pdf(processor, pdf_file, pdf_url):
323
- if not processor:
324
- return "Please set API key first", []
325
- logger.info(f"Received inputs - PDF file: {pdf_file}, PDF URL: {pdf_url}")
326
- if pdf_file is not None and hasattr(pdf_file, 'name'):
327
- logger.info(f"Processing as uploaded PDF: {pdf_file.name}")
328
- return processor.ocr_uploaded_pdf(pdf_file)
329
- elif pdf_url and pdf_url.strip():
330
- logger.info(f"Processing as PDF URL: {pdf_url}")
331
- return processor.ocr_pdf_url(pdf_url)
332
- return "Please upload a PDF or provide a valid URL", []
 
333
 
334
- process_pdf_btn.click(
335
- fn=process_pdf,
336
- inputs=[processor_state, pdf_input, pdf_url_input],
337
- outputs=[pdf_output, pdf_gallery]
 
338
  )
339
 
 
 
 
 
 
 
 
 
 
340
  return demo
341
 
342
  if __name__ == "__main__":
343
  os.environ['START_TIME'] = time.strftime('%Y-%m-%d %H:%M:%S')
344
  print(f"===== Application Startup at {os.environ['START_TIME']} =====")
345
- create_interface().launch(
346
- share=True,
347
- )
 
1
  import os
2
  import base64
3
  import gradio as gr
 
 
 
4
  import requests
5
  import shutil
6
  import time
7
  import pymupdf as fitz
8
  import logging
9
+ from mistralai import Mistral, ImageURLChunk
10
+ from mistralai.models import OCRResponse
11
+ from typing import Union, List, Tuple, Optional, Dict
12
  from tenacity import retry, stop_after_attempt, wait_exponential
13
  from concurrent.futures import ThreadPoolExecutor
14
+ import tempfile
 
15
 
16
  # Constants
17
  SUPPORTED_IMAGE_TYPES = [".jpg", ".png", ".jpeg"]
 
60
  def _save_uploaded_file(file_input: Union[str, bytes], filename: str) -> str:
61
  clean_filename = os.path.basename(filename).replace(os.sep, "_")
62
  file_path = os.path.join(UPLOAD_FOLDER, f"{int(time.time())}_{clean_filename}")
 
63
  try:
64
  if isinstance(file_input, str) and file_input.startswith("http"):
 
65
  response = requests.get(file_input, timeout=30)
66
  response.raise_for_status()
67
  with open(file_path, 'wb') as f:
68
  f.write(response.content)
69
  elif isinstance(file_input, str) and os.path.exists(file_input):
 
70
  shutil.copy2(file_input, file_path)
71
  else:
 
72
  with open(file_path, 'wb') as f:
73
  if hasattr(file_input, 'read'):
74
  shutil.copyfileobj(file_input, f)
 
76
  f.write(file_input)
77
  if not os.path.exists(file_path):
78
  raise FileNotFoundError(f"Failed to save file at {file_path}")
 
79
  return file_path
80
  except Exception as e:
81
  logger.error(f"Error saving file {filename}: {str(e)}")
 
88
  return base64.b64encode(image_file.read()).decode('utf-8')
89
  except Exception as e:
90
  logger.error(f"Error encoding image {image_path}: {str(e)}")
91
+ raise ValueError(f"Failed to encode image: {str(e)}")
92
 
93
  @staticmethod
94
  def _pdf_to_images(pdf_path: str) -> List[Tuple[str, str]]:
 
134
  document=ImageURLChunk(image_url=base64_url),
135
  include_image_base64=True
136
  )
 
137
  return response
138
+ except Exception as e:
139
+ logger.error(f"OCR API call failed: {str(e)}")
140
  raise
141
 
142
+ def process_file(self, file: gr.File) -> Tuple[str, str, List[str]]:
143
+ """Process uploaded file (image or PDF)."""
144
+ if not file:
145
+ return "## No file provided", "", []
146
+
147
+ file_name = file.name
148
+ self._check_file_size(file)
149
+ file_path = self._save_uploaded_file(file, file_name)
150
+
151
+ if file_name.lower().endswith(tuple(SUPPORTED_IMAGE_TYPES)):
152
+ encoded_image = self._encode_image(file_path)
153
+ response = self._call_ocr_api(encoded_image)
154
+ markdown = self._combine_markdown(response)
155
+ return markdown, file_path, [file_path]
156
+
157
+ elif file_name.lower().endswith('.pdf'):
158
+ image_data = self._pdf_to_images(file_path)
159
  if not image_data:
160
+ return "## No pages converted from PDF", file_path, []
161
 
162
  ocr_results = []
163
  image_paths = [path for path, _ in image_data]
164
+ for _, encoded in image_data:
165
  response = self._call_ocr_api(encoded)
166
+ markdown = self._combine_markdown(response)
167
+ ocr_results.append(markdown)
168
+ return "\n\n".join(ocr_results), file_path, image_paths
169
+
170
+ return "## Unsupported file type", file_path, []
 
171
 
172
+ def process_url(self, url: str) -> Tuple[str, str, List[str]]:
173
+ """Process URL (image or PDF)."""
174
+ if not url:
175
+ return "## No URL provided", "", []
176
+
177
+ file_name = url.split('/')[-1] or f"file_{int(time.time())}"
178
+ file_path = self._save_uploaded_file(url, file_name)
179
+
180
+ if file_name.lower().endswith(tuple(SUPPORTED_IMAGE_TYPES)):
181
+ encoded_image = self._encode_image(file_path)
182
+ response = self._call_ocr_api(encoded_image)
183
+ markdown = self._combine_markdown(response)
184
+ return markdown, url, [file_path]
185
+
186
+ elif file_name.lower().endswith('.pdf'):
187
+ image_data = self._pdf_to_images(file_path)
188
  if not image_data:
189
+ return "## No pages converted from PDF", url, []
190
 
191
  ocr_results = []
192
  image_paths = [path for path, _ in image_data]
193
+ for _, encoded in image_data:
194
  response = self._call_ocr_api(encoded)
195
+ markdown = self._combine_markdown(response)
196
+ ocr_results.append(markdown)
197
+ return "\n\n".join(ocr_results), url, image_paths
198
+
199
+ return "## Unsupported URL content type", url, []
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  @staticmethod
202
+ def _combine_markdown(response: OCRResponse) -> str:
203
+ """Combine markdown from OCR response."""
204
  markdown_parts = []
205
+ for page in response.pages:
206
+ if not page.markdown.strip():
207
+ continue
208
+ markdown = page.markdown
209
+ if hasattr(page, 'images') and page.images:
210
+ for img in page.images:
211
+ if img.image_base64:
212
+ markdown = markdown.replace(
213
+ f"![{img.id}]({img.id})",
214
+ f"![{img.id}](data:image/png;base64,{img.image_base64})"
215
+ )
216
+ markdown_parts.append(markdown)
217
+ return "\n\n".join(markdown_parts) or "## No text detected"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
  def create_interface():
220
  css = """
221
  .output-markdown {font-size: 14px; max-height: 500px; overflow-y: auto;}
222
  .status {color: #666; font-style: italic;}
223
+ .preview {max-height: 300px;}
224
  """
225
+
226
  with gr.Blocks(title="Mistral OCR Demo", css=css) as demo:
227
+ gr.Markdown("# Mistral OCR Demo")
228
+ gr.Markdown(f"""
229
+ Process PDFs and images (max {MAX_FILE_SIZE/1024/1024}MB, {MAX_PDF_PAGES} pages for PDFs) via upload or URL.
230
+ View previews and OCR results with embedded images.
231
+ Learn more at [Mistral OCR](https://mistral.ai/news/mistral-ocr).
232
+ """)
233
+
234
+ # API Key Setup
235
  with gr.Row():
236
+ api_key_input = gr.Textbox(label="Mistral API Key", type="password", placeholder="Enter your API key")
237
  set_key_btn = gr.Button("Set API Key", variant="primary")
 
238
  processor_state = gr.State()
239
  status = gr.Markdown("Please enter API key", elem_classes="status")
240
 
241
  def init_processor(key):
242
  try:
243
  processor = OCRProcessor(key)
244
+ return processor, "✅ API key validated"
245
  except Exception as e:
246
  return None, f"❌ Error: {str(e)}"
247
+
248
+ set_key_btn.click(fn=init_processor, inputs=api_key_input, outputs=[processor_state, status])
249
 
250
+ # File Upload Tab
251
+ with gr.Tab("Upload File"):
 
 
 
 
 
252
  with gr.Row():
253
+ file_input = gr.File(label="Upload PDF/Image", file_types=SUPPORTED_IMAGE_TYPES + SUPPORTED_PDF_TYPES)
254
+ file_preview = gr.Gallery(label="Preview", elem_classes="preview")
255
+ file_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown")
256
+ file_raw_output = gr.Textbox(label="Raw File Path")
257
+ file_button = gr.Button("Process", variant="primary")
 
 
258
 
259
+ def update_file_preview(file):
260
+ return [file.name] if file else []
 
 
261
 
262
+ file_input.change(fn=update_file_preview, inputs=file_input, outputs=file_preview)
263
+ file_button.click(
264
+ fn=lambda p, f: p.process_file(f) if p else ("## Set API key first", "", []),
265
+ inputs=[processor_state, file_input],
266
+ outputs=[file_output, file_raw_output, file_preview]
267
  )
268
 
269
+ # URL Tab
270
+ with gr.Tab("URL Input"):
271
  with gr.Row():
272
+ url_input = gr.Textbox(label="URL to PDF/Image")
273
+ url_preview = gr.Gallery(label="Preview", elem_classes="preview")
274
+ url_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown")
275
+ url_raw_output = gr.Textbox(label="Raw URL")
276
+ url_button = gr.Button("Process", variant="primary")
 
 
 
 
 
 
 
277
 
278
+ def update_url_preview(url):
279
+ if not url:
280
+ return []
281
+ try:
282
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.tmp')
283
+ response = requests.get(url, timeout=10)
284
+ temp_file.write(response.content)
285
+ temp_file.close()
286
+ return [temp_file.name]
287
+ except Exception as e:
288
+ logger.error(f"URL preview error: {str(e)}")
289
+ return []
290
 
291
+ url_input.change(fn=update_url_preview, inputs=url_input, outputs=url_preview)
292
+ url_button.click(
293
+ fn=lambda p, u: p.process_url(u) if p else ("## Set API key first", "", []),
294
+ inputs=[processor_state, url_input],
295
+ outputs=[url_output, url_raw_output, url_preview]
296
  )
297
 
298
+ # Examples
299
+ gr.Examples(
300
+ examples=[
301
+ {"file_input": "receipt.png"},
302
+ {"url_input": "https://arxiv.org/pdf/2410.07073"}
303
+ ],
304
+ inputs=[file_input, url_input]
305
+ )
306
+
307
  return demo
308
 
309
  if __name__ == "__main__":
310
  os.environ['START_TIME'] = time.strftime('%Y-%m-%d %H:%M:%S')
311
  print(f"===== Application Startup at {os.environ['START_TIME']} =====")
312
+ create_interface().launch(share=True, max_threads=1)