Svngoku commited on
Commit
86ba735
·
verified ·
1 Parent(s): 167a8e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -146
app.py CHANGED
@@ -1,26 +1,20 @@
1
  import os
2
  import base64
3
  import gradio as gr
4
- from mistralai import Mistral, DocumentURLChunk, ImageURLChunk, TextChunk
5
  from mistralai.models import OCRResponse
6
- from pathlib import Path
7
- import pycountry
8
- import json
9
- import logging
10
- from tenacity import retry, stop_after_attempt, wait_exponential
11
- import tempfile
12
- from typing import Union, Dict, List, Optional, Tuple
13
- from contextlib import contextmanager
14
  import requests
15
  import shutil
16
- from concurrent.futures import ThreadPoolExecutor
17
  import time
 
 
 
 
18
 
19
  # Constants
20
- DEFAULT_LANGUAGE = "English"
21
  SUPPORTED_IMAGE_TYPES = [".jpg", ".png", ".jpeg"]
22
  SUPPORTED_PDF_TYPES = [".pdf"]
23
- TEMP_FILE_EXPIRY = 7200 # 2 hours in seconds
24
  UPLOAD_FOLDER = "./uploads"
25
  MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
26
  MAX_PDF_PAGES = 50
@@ -36,15 +30,10 @@ logger = logging.getLogger(__name__)
36
 
37
  class OCRProcessor:
38
  def __init__(self, api_key: str):
39
- self.api_key = self._validate_api_key(api_key)
40
- self.client = Mistral(api_key=self.api_key)
41
- self._validate_client()
42
-
43
- @staticmethod
44
- def _validate_api_key(api_key: str) -> str:
45
  if not api_key or not isinstance(api_key, str):
46
  raise ValueError("Valid API key must be provided")
47
- return api_key
 
48
 
49
  def _validate_client(self) -> None:
50
  try:
@@ -60,21 +49,12 @@ class OCRProcessor:
60
  size = os.path.getsize(file_input)
61
  elif hasattr(file_input, 'read'):
62
  size = len(file_input.read())
63
- file_input.seek(0) # Reset file pointer
64
  else:
65
  size = len(file_input)
66
  if size > MAX_FILE_SIZE:
67
  raise ValueError(f"File size exceeds {MAX_FILE_SIZE/1024/1024}MB limit")
68
 
69
- @staticmethod
70
- def _encode_image(image_path: str) -> Optional[str]:
71
- try:
72
- with open(image_path, "rb") as image_file:
73
- return base64.b64encode(image_file.read()).decode('utf-8')
74
- except Exception as e:
75
- logger.error(f"Error encoding image {image_path}: {str(e)}")
76
- return None
77
-
78
  @staticmethod
79
  def _save_uploaded_file(file_input: Union[str, bytes], filename: str) -> str:
80
  clean_filename = os.path.basename(filename).replace(os.sep, "_")
@@ -102,7 +82,16 @@ class OCRProcessor:
102
  raise
103
 
104
  @staticmethod
105
- def _pdf_to_images(pdf_path: str) -> List[str]:
 
 
 
 
 
 
 
 
 
106
  try:
107
  pdf_document = fitz.open(pdf_path)
108
  if pdf_document.page_count > MAX_PDF_PAGES:
@@ -110,42 +99,40 @@ class OCRProcessor:
110
  raise ValueError(f"PDF exceeds maximum page limit of {MAX_PDF_PAGES}")
111
 
112
  with ThreadPoolExecutor() as executor:
113
- image_paths = list(executor.map(
114
  lambda i: OCRProcessor._convert_page(pdf_path, i),
115
  range(pdf_document.page_count)
116
  ))
117
  pdf_document.close()
118
- return [path for path in image_paths if path]
119
  except Exception as e:
120
  logger.error(f"Error converting PDF to images: {str(e)}")
121
  return []
122
 
123
  @staticmethod
124
- def _convert_page(pdf_path: str, page_num: int) -> Optional[str]:
125
  try:
126
  pdf_document = fitz.open(pdf_path)
127
  page = pdf_document[page_num]
128
  pix = page.get_pixmap(dpi=150)
129
  image_path = os.path.join(UPLOAD_FOLDER, f"page_{page_num + 1}_{int(time.time())}.png")
130
  pix.save(image_path)
 
131
  pdf_document.close()
132
- return image_path
133
  except Exception as e:
134
  logger.error(f"Error converting page {page_num}: {str(e)}")
135
- return None
136
 
137
  @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
138
- def _call_ocr_api(self, document: Union[DocumentURLChunk, ImageURLChunk]) -> OCRResponse:
 
139
  return self.client.ocr.process(
140
  model="mistral-ocr-latest",
141
- document=document,
142
  include_image_base64=True
143
  )
144
 
145
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
146
- def _call_chat_complete(self, model: str, messages: List[Dict], **kwargs) -> Dict:
147
- return self.client.chat.complete(model=model, messages=messages, **kwargs)
148
-
149
  def ocr_uploaded_pdf(self, pdf_file: Union[str, bytes]) -> Tuple[str, List[str]]:
150
  file_name = getattr(pdf_file, 'name', f"pdf_{int(time.time())}.pdf")
151
  logger.info(f"Processing uploaded PDF: {file_name}")
@@ -157,18 +144,45 @@ class OCRProcessor:
157
  if not os.path.exists(pdf_path):
158
  raise FileNotFoundError(f"Saved PDF not found at: {pdf_path}")
159
 
160
- image_paths = self._pdf_to_images(pdf_path)
 
 
161
 
162
- with open(pdf_path, "rb") as f:
163
- uploaded_file = self.client.files.upload(
164
- file={"file_name": file_name, "content": f},
165
- purpose="ocr"
166
- )
167
- signed_url = self.client.files.get_signed_url(file_id=uploaded_file.id, expiry=TEMP_FILE_EXPIRY)
168
- response = self._call_ocr_api(DocumentURLChunk(document_url=signed_url.url))
169
- return self._get_combined_markdown(response), image_paths
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  except Exception as e:
171
- return self._handle_error("PDF processing", e), []
172
 
173
  def ocr_uploaded_image(self, image_file: Union[str, bytes]) -> Tuple[str, str]:
174
  file_name = getattr(image_file, 'name', f"image_{int(time.time())}.jpg")
@@ -177,61 +191,11 @@ class OCRProcessor:
177
  self._check_file_size(image_file)
178
  image_path = self._save_uploaded_file(image_file, file_name)
179
  encoded_image = self._encode_image(image_path)
180
- if not encoded_image:
181
- raise ValueError("Failed to encode image")
182
- base64_url = f"data:image/jpeg;base64,{encoded_image}"
183
- response = self._call_ocr_api(ImageURLChunk(image_url=base64_url))
184
  return self._get_combined_markdown(response), image_path
185
  except Exception as e:
186
  return self._handle_error("image processing", e), None
187
 
188
- def document_understanding(self, doc_url: str, question: str) -> str:
189
- try:
190
- messages = [{"role": "user", "content": [
191
- TextChunk(text=question),
192
- DocumentURLChunk(document_url=doc_url)
193
- ]}]
194
- response = self._call_chat_complete(
195
- model="mistral-small-latest",
196
- messages=messages,
197
- temperature=0.1
198
- )
199
- return response.choices[0].message.content
200
- except Exception as e:
201
- return self._handle_error("document understanding", e)
202
-
203
- def structured_ocr(self, image_file: Union[str, bytes]) -> Tuple[str, str]:
204
- file_name = getattr(image_file, 'name', f"image_{int(time.time())}.jpg")
205
- try:
206
- self._check_file_size(image_file)
207
- image_path = self._save_uploaded_file(image_file, file_name)
208
- encoded_image = self._encode_image(image_path)
209
- if not encoded_image:
210
- raise ValueError("Failed to encode image")
211
- base64_url = f"data:image/jpeg;base64,{encoded_image}"
212
-
213
- ocr_response = self._call_ocr_api(ImageURLChunk(image_url=base64_url))
214
- markdown = self._get_combined_markdown(ocr_response)
215
-
216
- chat_response = self._call_chat_complete(
217
- model="pixtral-12b-latest",
218
- messages=[{
219
- "role": "user",
220
- "content": [
221
- ImageURLChunk(image_url=base64_url),
222
- TextChunk(text=(
223
- f"This is image's OCR in markdown:\n<BEGIN_IMAGE_OCR>\n{markdown}\n<END_IMAGE_OCR>.\n"
224
- "Convert this into a structured JSON response with file_name, topics, languages, and ocr_contents fields"
225
- ))
226
- ]
227
- }],
228
- response_format={"type": "json_object"},
229
- temperature=0.1
230
- )
231
- return self._format_structured_response(image_path, json.loads(chat_response.choices[0].message.content)), image_path
232
- except Exception as e:
233
- return self._handle_error("structured OCR", e), None
234
-
235
  @staticmethod
236
  def _get_combined_markdown(response: OCRResponse) -> str:
237
  return "\n\n".join(
@@ -244,20 +208,6 @@ class OCRProcessor:
244
  logger.error(f"Error in {context}: {str(error)}")
245
  return f"**Error in {context}:** {str(error)}"
246
 
247
- @staticmethod
248
- def _format_structured_response(file_path: str, content: Dict) -> str:
249
- languages = {lang.alpha_2: lang.name for lang in pycountry.languages if hasattr(lang, 'alpha_2')}
250
- content_languages = content.get("languages", [DEFAULT_LANGUAGE])
251
- valid_langs = [l for l in content_languages if l in languages.values()] or [DEFAULT_LANGUAGE]
252
-
253
- response = {
254
- "file_name": Path(file_path).name,
255
- "topics": content.get("topics", []),
256
- "languages": valid_langs,
257
- "ocr_contents": content.get("ocr_contents", {})
258
- }
259
- return f"```json\n{json.dumps(response, indent=2, ensure_ascii=False)}\n```"
260
-
261
  def create_interface():
262
  css = """
263
  .output-markdown {font-size: 14px; max-height: 500px; overflow-y: auto;}
@@ -265,7 +215,7 @@ def create_interface():
265
  """
266
 
267
  with gr.Blocks(title="Mistral OCR App", css=css) as demo:
268
- gr.Markdown("# Mistral OCR App\nUpload images or PDFs for OCR processing")
269
 
270
  with gr.Row():
271
  api_key = gr.Textbox(label="Mistral API Key", type="password", placeholder="Enter your API key")
@@ -310,45 +260,34 @@ def create_interface():
310
 
311
  with gr.Tab("PDF OCR"):
312
  with gr.Row():
313
- pdf_input = gr.File(
314
- label=f"Upload PDF (max {MAX_FILE_SIZE/1024/1024}MB, {MAX_PDF_PAGES} pages)",
315
- file_types=SUPPORTED_PDF_TYPES
316
- )
 
 
 
 
 
317
  pdf_gallery = gr.Gallery(label="PDF Pages", height=300)
318
  pdf_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown")
319
  process_pdf_btn = gr.Button("Process PDF", variant="primary")
320
 
321
- def process_pdf(processor, pdf):
322
- if not processor or not pdf:
323
- return "Please set API key and upload a PDF", []
324
- return processor.ocr_uploaded_pdf(pdf)
 
 
 
 
325
 
326
  process_pdf_btn.click(
327
  fn=process_pdf,
328
- inputs=[processor_state, pdf_input],
329
  outputs=[pdf_output, pdf_gallery]
330
  )
331
 
332
- with gr.Tab("Structured OCR"):
333
- structured_input = gr.File(
334
- label=f"Upload Image for Structured OCR (max {MAX_FILE_SIZE/1024/1024}MB)",
335
- file_types=SUPPORTED_IMAGE_TYPES
336
- )
337
- structured_output = gr.Markdown(label="Structured Result", elem_classes="output-markdown")
338
- structured_preview = gr.Image(label="Preview", height=300)
339
- process_structured_btn = gr.Button("Process Structured OCR", variant="primary")
340
-
341
- def process_structured(processor, image):
342
- if not processor or not image:
343
- return "Please set API key and upload an image", None
344
- return processor.structured_ocr(image)
345
-
346
- process_structured_btn.click(
347
- fn=process_structured,
348
- inputs=[processor_state, structured_input],
349
- outputs=[structured_output, structured_preview]
350
- )
351
-
352
  return demo
353
 
354
  if __name__ == "__main__":
 
1
  import os
2
  import base64
3
  import gradio as gr
4
+ from mistralai import Mistral, ImageURLChunk
5
  from mistralai.models import OCRResponse
6
+ from typing import Union, List, Tuple
 
 
 
 
 
 
 
7
  import requests
8
  import shutil
 
9
  import time
10
+ import pymupdf as fitz
11
+ import logging
12
+ from tenacity import retry, stop_after_attempt, wait_exponential
13
+ from concurrent.futures import ThreadPoolExecutor
14
 
15
  # Constants
 
16
  SUPPORTED_IMAGE_TYPES = [".jpg", ".png", ".jpeg"]
17
  SUPPORTED_PDF_TYPES = [".pdf"]
 
18
  UPLOAD_FOLDER = "./uploads"
19
  MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
20
  MAX_PDF_PAGES = 50
 
30
 
31
  class OCRProcessor:
32
  def __init__(self, api_key: str):
 
 
 
 
 
 
33
  if not api_key or not isinstance(api_key, str):
34
  raise ValueError("Valid API key must be provided")
35
+ self.client = Mistral(api_key=api_key)
36
+ self._validate_client()
37
 
38
  def _validate_client(self) -> None:
39
  try:
 
49
  size = os.path.getsize(file_input)
50
  elif hasattr(file_input, 'read'):
51
  size = len(file_input.read())
52
+ file_input.seek(0)
53
  else:
54
  size = len(file_input)
55
  if size > MAX_FILE_SIZE:
56
  raise ValueError(f"File size exceeds {MAX_FILE_SIZE/1024/1024}MB limit")
57
 
 
 
 
 
 
 
 
 
 
58
  @staticmethod
59
  def _save_uploaded_file(file_input: Union[str, bytes], filename: str) -> str:
60
  clean_filename = os.path.basename(filename).replace(os.sep, "_")
 
82
  raise
83
 
84
  @staticmethod
85
+ def _encode_image(image_path: str) -> str:
86
+ try:
87
+ with open(image_path, "rb") as image_file:
88
+ return base64.b64encode(image_file.read()).decode('utf-8')
89
+ except Exception as e:
90
+ logger.error(f"Error encoding image {image_path}: {str(e)}")
91
+ raise ValueError("Failed to encode image")
92
+
93
+ @staticmethod
94
+ def _pdf_to_images(pdf_path: str) -> List[Tuple[str, str]]:
95
  try:
96
  pdf_document = fitz.open(pdf_path)
97
  if pdf_document.page_count > MAX_PDF_PAGES:
 
99
  raise ValueError(f"PDF exceeds maximum page limit of {MAX_PDF_PAGES}")
100
 
101
  with ThreadPoolExecutor() as executor:
102
+ image_data = list(executor.map(
103
  lambda i: OCRProcessor._convert_page(pdf_path, i),
104
  range(pdf_document.page_count)
105
  ))
106
  pdf_document.close()
107
+ return [data for data in image_data if data]
108
  except Exception as e:
109
  logger.error(f"Error converting PDF to images: {str(e)}")
110
  return []
111
 
112
  @staticmethod
113
+ def _convert_page(pdf_path: str, page_num: int) -> Tuple[str, str]:
114
  try:
115
  pdf_document = fitz.open(pdf_path)
116
  page = pdf_document[page_num]
117
  pix = page.get_pixmap(dpi=150)
118
  image_path = os.path.join(UPLOAD_FOLDER, f"page_{page_num + 1}_{int(time.time())}.png")
119
  pix.save(image_path)
120
+ encoded = OCRProcessor._encode_image(image_path)
121
  pdf_document.close()
122
+ return image_path, encoded
123
  except Exception as e:
124
  logger.error(f"Error converting page {page_num}: {str(e)}")
125
+ return None, None
126
 
127
  @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
128
+ def _call_ocr_api(self, encoded_image: str) -> OCRResponse:
129
+ base64_url = f"data:image/png;base64,{encoded_image}"
130
  return self.client.ocr.process(
131
  model="mistral-ocr-latest",
132
+ document=ImageURLChunk(image_url=base64_url),
133
  include_image_base64=True
134
  )
135
 
 
 
 
 
136
  def ocr_uploaded_pdf(self, pdf_file: Union[str, bytes]) -> Tuple[str, List[str]]:
137
  file_name = getattr(pdf_file, 'name', f"pdf_{int(time.time())}.pdf")
138
  logger.info(f"Processing uploaded PDF: {file_name}")
 
144
  if not os.path.exists(pdf_path):
145
  raise FileNotFoundError(f"Saved PDF not found at: {pdf_path}")
146
 
147
+ image_data = self._pdf_to_images(pdf_path)
148
+ if not image_data:
149
+ raise ValueError("No pages converted from PDF")
150
 
151
+ # Process each page with OCR
152
+ ocr_results = []
153
+ for _, encoded in image_data:
154
+ response = self._call_ocr_api(encoded)
155
+ markdown = self._get_combined_markdown(response)
156
+ ocr_results.append(markdown)
157
+
158
+ image_paths = [path for path, _ in image_data]
159
+ return "\n\n".join(ocr_results), image_paths
160
+ except Exception as e:
161
+ return self._handle_error("uploaded PDF processing", e), []
162
+
163
+ def ocr_pdf_url(self, pdf_url: str) -> Tuple[str, List[str]]:
164
+ logger.info(f"Processing PDF URL: {pdf_url}")
165
+ try:
166
+ file_name = pdf_url.split('/')[-1] or f"pdf_{int(time.time())}.pdf"
167
+ pdf_path = self._save_uploaded_file(pdf_url, file_name)
168
+
169
+ if not os.path.exists(pdf_path):
170
+ raise FileNotFoundError(f"Saved PDF not found at: {pdf_path}")
171
+
172
+ image_data = self._pdf_to_images(pdf_path)
173
+ if not image_data:
174
+ raise ValueError("No pages converted from PDF")
175
+
176
+ ocr_results = []
177
+ for _, encoded in image_data:
178
+ response = self._call_ocr_api(encoded)
179
+ markdown = self._get_combined_markdown(response)
180
+ ocr_results.append(markdown)
181
+
182
+ image_paths = [path for path, _ in image_data]
183
+ return "\n\n".join(ocr_results), image_paths
184
  except Exception as e:
185
+ return self._handle_error("PDF URL processing", e), []
186
 
187
  def ocr_uploaded_image(self, image_file: Union[str, bytes]) -> Tuple[str, str]:
188
  file_name = getattr(image_file, 'name', f"image_{int(time.time())}.jpg")
 
191
  self._check_file_size(image_file)
192
  image_path = self._save_uploaded_file(image_file, file_name)
193
  encoded_image = self._encode_image(image_path)
194
+ response = self._call_ocr_api(encoded_image)
 
 
 
195
  return self._get_combined_markdown(response), image_path
196
  except Exception as e:
197
  return self._handle_error("image processing", e), None
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  @staticmethod
200
  def _get_combined_markdown(response: OCRResponse) -> str:
201
  return "\n\n".join(
 
208
  logger.error(f"Error in {context}: {str(error)}")
209
  return f"**Error in {context}:** {str(error)}"
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  def create_interface():
212
  css = """
213
  .output-markdown {font-size: 14px; max-height: 500px; overflow-y: auto;}
 
215
  """
216
 
217
  with gr.Blocks(title="Mistral OCR App", css=css) as demo:
218
+ gr.Markdown("# Mistral OCR App\nUpload images or PDFs, or provide a PDF URL for OCR processing")
219
 
220
  with gr.Row():
221
  api_key = gr.Textbox(label="Mistral API Key", type="password", placeholder="Enter your API key")
 
260
 
261
  with gr.Tab("PDF OCR"):
262
  with gr.Row():
263
+ with gr.Column():
264
+ pdf_input = gr.File(
265
+ label=f"Upload PDF (max {MAX_FILE_SIZE/1024/1024}MB, {MAX_PDF_PAGES} pages)",
266
+ file_types=SUPPORTED_PDF_TYPES
267
+ )
268
+ pdf_url_input = gr.Textbox(
269
+ label="Or Enter PDF URL",
270
+ placeholder="e.g., https://arxiv.org/pdf/2201.04234.pdf"
271
+ )
272
  pdf_gallery = gr.Gallery(label="PDF Pages", height=300)
273
  pdf_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown")
274
  process_pdf_btn = gr.Button("Process PDF", variant="primary")
275
 
276
+ def process_pdf(processor, pdf_file, pdf_url):
277
+ if not processor:
278
+ return "Please set API key first", []
279
+ if pdf_file:
280
+ return processor.ocr_uploaded_pdf(pdf_file)
281
+ elif pdf_url:
282
+ return processor.ocr_pdf_url(pdf_url)
283
+ return "Please upload a PDF or provide a URL", []
284
 
285
  process_pdf_btn.click(
286
  fn=process_pdf,
287
+ inputs=[processor_state, pdf_input, pdf_url_input],
288
  outputs=[pdf_output, pdf_gallery]
289
  )
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  return demo
292
 
293
  if __name__ == "__main__":