Svngoku commited on
Commit
d5d1102
·
verified ·
1 Parent(s): be2e6ae

Few updates

Browse files
Files changed (1) hide show
  1. app.py +183 -112
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import os
2
  import base64
3
  import gradio as gr
4
- from mistralai import Mistral, ImageURLChunk
 
5
  from mistralai.models import OCRResponse
6
- from typing import Union, List, Tuple
7
  import requests
8
  import shutil
9
  import time
@@ -13,6 +14,11 @@ from tenacity import retry, stop_after_attempt, wait_exponential
13
  from concurrent.futures import ThreadPoolExecutor
14
  import socket
15
  from requests.exceptions import ConnectionError, Timeout
 
 
 
 
 
16
 
17
  # Constants
18
  SUPPORTED_IMAGE_TYPES = [".jpg", ".png", ".jpeg"]
@@ -30,6 +36,32 @@ logging.basicConfig(
30
  )
31
  logger = logging.getLogger(__name__)
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  class OCRProcessor:
34
  def __init__(self, api_key: str):
35
  if not api_key or not isinstance(api_key, str):
@@ -91,10 +123,12 @@ class OCRProcessor:
91
  def _encode_image(image_path: str) -> str:
92
  try:
93
  with open(image_path, "rb") as image_file:
94
- return base64.b64encode(image_file.read()).decode('utf-8')
 
 
95
  except Exception as e:
96
  logger.error(f"Error encoding image {image_path}: {str(e)}")
97
- raise ValueError("Failed to encode image")
98
 
99
  @staticmethod
100
  def _pdf_to_images(pdf_path: str) -> List[Tuple[str, str]]:
@@ -110,10 +144,14 @@ class OCRProcessor:
110
  range(pdf_document.page_count)
111
  ))
112
  pdf_document.close()
113
- return [data for data in image_data if data]
 
 
 
 
114
  except Exception as e:
115
  logger.error(f"Error converting PDF to images: {str(e)}")
116
- return []
117
 
118
  @staticmethod
119
  def _convert_page(pdf_path: str, page_num: int) -> Tuple[str, str]:
@@ -132,128 +170,151 @@ class OCRProcessor:
132
 
133
  @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
134
  def _call_ocr_api(self, encoded_image: str) -> OCRResponse:
 
 
135
  base64_url = f"data:image/png;base64,{encoded_image}"
136
  try:
137
  logger.info("Calling OCR API")
138
  response = self.client.ocr.process(
139
- model="mistral-ocr-latest",
140
  document=ImageURLChunk(image_url=base64_url),
 
141
  include_image_base64=True
142
  )
143
  logger.info("OCR API call successful")
 
 
 
 
 
 
 
 
144
  return response
145
- except (ConnectionError, Timeout, socket.error) as e:
146
  logger.error(f"Network error during OCR API call: {str(e)}")
147
  raise
148
-
149
- def ocr_uploaded_pdf(self, pdf_file: Union[str, bytes]) -> Tuple[str, List[str]]:
150
- file_name = getattr(pdf_file, 'name', f"pdf_{int(time.time())}.pdf")
151
- logger.info(f"Processing uploaded PDF: {file_name}")
152
- try:
153
- self._check_file_size(pdf_file)
154
- pdf_path = self._save_uploaded_file(pdf_file, file_name)
155
-
156
- if not os.path.exists(pdf_path):
157
- raise FileNotFoundError(f"Saved PDF not found at: {pdf_path}")
158
-
159
- image_data = self._pdf_to_images(pdf_path)
160
- if not image_data:
161
- raise ValueError("No pages converted from PDF")
162
-
163
- ocr_results = []
164
- image_paths = [path for path, _ in image_data]
165
- for i, (_, encoded) in enumerate(image_data):
166
- response = self._call_ocr_api(encoded)
167
- markdown_with_images = self._get_combined_markdown_with_images(response, image_paths, i)
168
- ocr_results.append(markdown_with_images)
169
-
170
- return "\n\n".join(ocr_results), image_paths
171
  except Exception as e:
172
- return self._handle_error("uploaded PDF processing", e), []
 
173
 
174
- def ocr_pdf_url(self, pdf_url: str) -> Tuple[str, List[str]]:
175
- logger.info(f"Processing PDF URL: {pdf_url}")
176
  try:
177
- file_name = pdf_url.split('/')[-1] or f"pdf_{int(time.time())}.pdf"
178
- pdf_path = self._save_uploaded_file(pdf_url, file_name)
179
-
180
- if not os.path.exists(pdf_path):
181
- raise FileNotFoundError(f"Saved PDF not found at: {pdf_path}")
182
-
183
- image_data = self._pdf_to_images(pdf_path)
184
- if not image_data:
185
- raise ValueError("No pages converted from PDF")
186
-
187
- ocr_results = []
188
- image_paths = [path for path, _ in image_data]
189
- for i, (_, encoded) in enumerate(image_data):
190
- response = self._call_ocr_api(encoded)
191
- markdown_with_images = self._get_combined_markdown_with_images(response, image_paths, i)
192
- ocr_results.append(markdown_with_images)
193
-
194
- return "\n\n".join(ocr_results), image_paths
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  except Exception as e:
196
- return self._handle_error("PDF URL processing", e), []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
- def ocr_uploaded_image(self, image_file: Union[str, bytes]) -> Tuple[str, str]:
199
- file_name = getattr(image_file, 'name', f"image_{int(time.time())}.jpg")
200
- logger.info(f"Processing uploaded image: {file_name}")
201
  try:
202
- self._check_file_size(image_file)
203
- image_path = self._save_uploaded_file(image_file, file_name)
204
- encoded_image = self._encode_image(image_path)
205
- response = self._call_ocr_api(encoded_image)
206
- return self._get_combined_markdown_with_images(response), image_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  except Exception as e:
208
- return self._handle_error("image processing", e), None
 
209
 
210
- @staticmethod
211
- def _get_combined_markdown_with_images(response: OCRResponse, image_paths: List[str] = None, page_index: int = None) -> str:
212
- markdown_parts = []
213
- for i, page in enumerate(response.pages):
214
- if page.markdown.strip():
215
- markdown = page.markdown
216
- logger.info(f"Page {i} markdown: {markdown}")
217
- if hasattr(page, 'images') and page.images:
218
- logger.info(f"Found {len(page.images)} images in page {i}")
219
- for img in page.images:
220
- if img.image_base64:
221
- logger.info(f"Replacing image {img.id} with base64")
222
- markdown = markdown.replace(
223
- f"![{img.id}]({img.id})",
224
- f"![{img.id}](data:image/png;base64,{img.image_base64})"
225
- )
226
- else:
227
- logger.warning(f"No base64 data for image {img.id}")
228
- if image_paths and page_index is not None and page_index < len(image_paths):
229
- local_encoded = OCRProcessor._encode_image(image_paths[page_index])
230
- markdown = markdown.replace(
231
- f"![{img.id}]({img.id})",
232
- f"![{img.id}](data:image/png;base64,{local_encoded})"
233
- )
234
- else:
235
- logger.warning(f"No images found in page {i}")
236
- # Replace known placeholders or append the local image
237
- if image_paths and page_index is not None and page_index < len(image_paths):
238
- local_encoded = OCRProcessor._encode_image(image_paths[page_index])
239
- # Replace placeholders like img-0.jpeg
240
- placeholder = f"img-{i}.jpeg"
241
- if placeholder in markdown:
242
- markdown = markdown.replace(
243
- placeholder,
244
- f"![Page {i} Image](data:image/png;base64,{local_encoded})"
245
- )
246
- else:
247
- # Append the image if no placeholder is found
248
- markdown += f"\n\n![Page {i} Image](data:image/png;base64,{local_encoded})"
249
- markdown_parts.append(markdown)
250
- return "\n\n".join(markdown_parts) or "No text or images detected"
251
 
252
  @staticmethod
253
  def _handle_error(context: str, error: Exception) -> str:
254
- logger.error(f"Error in {context}: {str(error)}")
255
  return f"**Error in {context}:** {str(error)}"
256
 
 
 
 
 
 
257
  def create_interface():
258
  css = """
259
  .output-markdown {font-size: 14px; max-height: 500px; overflow-y: auto;}
@@ -291,17 +352,19 @@ def create_interface():
291
  )
292
  image_preview = gr.Image(label="Preview", height=300)
293
  image_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown")
 
294
  process_image_btn = gr.Button("Process Image", variant="primary")
295
 
296
  def process_image(processor, image):
297
  if not processor or not image:
298
- return "Please set API key and upload an image", None
299
- return processor.ocr_uploaded_image(image)
 
300
 
301
  process_image_btn.click(
302
  fn=process_image,
303
  inputs=[processor_state, image_input],
304
- outputs=[image_output, image_preview]
305
  )
306
 
307
  with gr.Tab("PDF OCR"):
@@ -317,24 +380,32 @@ def create_interface():
317
  )
318
  pdf_gallery = gr.Gallery(label="PDF Pages", height=300)
319
  pdf_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown")
 
320
  process_pdf_btn = gr.Button("Process PDF", variant="primary")
321
 
322
  def process_pdf(processor, pdf_file, pdf_url):
323
  if not processor:
324
- return "Please set API key first", []
325
  logger.info(f"Received inputs - PDF file: {pdf_file}, PDF URL: {pdf_url}")
326
  if pdf_file is not None and hasattr(pdf_file, 'name'):
327
  logger.info(f"Processing as uploaded PDF: {pdf_file.name}")
328
- return processor.ocr_uploaded_pdf(pdf_file)
329
  elif pdf_url and pdf_url.strip():
330
  logger.info(f"Processing as PDF URL: {pdf_url}")
331
- return processor.ocr_pdf_url(pdf_url)
332
- return "Please upload a PDF or provide a valid URL", []
 
 
 
 
 
 
 
333
 
334
  process_pdf_btn.click(
335
  fn=process_pdf,
336
  inputs=[processor_state, pdf_input, pdf_url_input],
337
- outputs=[pdf_output, pdf_gallery]
338
  )
339
 
340
  return demo
 
1
  import os
2
  import base64
3
  import gradio as gr
4
+ import json
5
+ from mistralai import Mistral, DocumentURLChunk, ImageURLChunk, TextChunk
6
  from mistralai.models import OCRResponse
7
+ from typing import Union, List, Tuple, Dict
8
  import requests
9
  import shutil
10
  import time
 
14
  from concurrent.futures import ThreadPoolExecutor
15
  import socket
16
  from requests.exceptions import ConnectionError, Timeout
17
+ from pathlib import Path
18
+ from pydantic import BaseModel
19
+ import pycountry
20
+ from enum import Enum
21
+ from PIL import Image
22
 
23
  # Constants
24
  SUPPORTED_IMAGE_TYPES = [".jpg", ".png", ".jpeg"]
 
36
  )
37
  logger = logging.getLogger(__name__)
38
 
39
+ # Language Enum for StructuredOCR
40
+ languages = {lang.alpha_2: lang.name for lang in pycountry.languages if hasattr(lang, 'alpha_2')}
41
+
42
+ class LanguageMeta(Enum.__class__):
43
+ def __new__(metacls, cls, bases, classdict):
44
+ for code, name in languages.items():
45
+ classdict[name.upper().replace(' ', '_')] = name
46
+ return super().__new__(metacls, cls, bases, classdict)
47
+
48
+ class Language(Enum, metaclass=LanguageMeta):
49
+ pass
50
+
51
+ class StructuredOCR(BaseModel):
52
+ file_name: str
53
+ topics: list[str]
54
+ languages: list[Language]
55
+ ocr_contents: dict
56
+
57
+ def model_dump_json(self, **kwargs):
58
+ # Custom JSON serialization to handle Language enums
59
+ data = self.model_dump(exclude_unset=True, by_alias=True, mode='json')
60
+ for key, value in data.items():
61
+ if isinstance(value, list) and all(isinstance(item, Language) for item in value):
62
+ data[key] = [item.value for item in value]
63
+ return json.dumps(data, indent=4)
64
+
65
  class OCRProcessor:
66
  def __init__(self, api_key: str):
67
  if not api_key or not isinstance(api_key, str):
 
123
  def _encode_image(image_path: str) -> str:
124
  try:
125
  with open(image_path, "rb") as image_file:
126
+ encoded = base64.b64encode(image_file.read()).decode('utf-8')
127
+ logger.info(f"Encoded image {image_path} to base64 (length: {len(encoded)})")
128
+ return encoded
129
  except Exception as e:
130
  logger.error(f"Error encoding image {image_path}: {str(e)}")
131
+ raise ValueError(f"Failed to encode image: {str(e)}")
132
 
133
  @staticmethod
134
  def _pdf_to_images(pdf_path: str) -> List[Tuple[str, str]]:
 
144
  range(pdf_document.page_count)
145
  ))
146
  pdf_document.close()
147
+ valid_image_data = [(path, encoded) for path, encoded in image_data if path and encoded]
148
+ if not valid_image_data:
149
+ raise ValueError("No valid pages converted from PDF")
150
+ logger.info(f"Converted {len(valid_image_data)} pages to images")
151
+ return valid_image_data
152
  except Exception as e:
153
  logger.error(f"Error converting PDF to images: {str(e)}")
154
+ raise
155
 
156
  @staticmethod
157
  def _convert_page(pdf_path: str, page_num: int) -> Tuple[str, str]:
 
170
 
171
  @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
172
  def _call_ocr_api(self, encoded_image: str) -> OCRResponse:
173
+ if not isinstance(encoded_image, str):
174
+ raise TypeError(f"Expected encoded_image to be a string, got {type(encoded_image)}")
175
  base64_url = f"data:image/png;base64,{encoded_image}"
176
  try:
177
  logger.info("Calling OCR API")
178
  response = self.client.ocr.process(
 
179
  document=ImageURLChunk(image_url=base64_url),
180
+ model="mistral-ocr-latest",
181
  include_image_base64=True
182
  )
183
  logger.info("OCR API call successful")
184
+ try:
185
+ if hasattr(response, 'model_dump_json'):
186
+ response_dict = json.loads(response.model_dump_json())
187
+ else:
188
+ response_dict = {k: v for k, v in response.__dict__.items() if isinstance(v, (str, int, float, list, dict))}
189
+ logger.info(f"Raw OCR response: {json.dumps(response_dict, default=str, indent=4)}")
190
+ except Exception as log_err:
191
+ logger.warning(f"Failed to log raw OCR response: {str(log_err)}")
192
  return response
193
+ except (ConnectionError, TimeoutError, socket.error) as e:
194
  logger.error(f"Network error during OCR API call: {str(e)}")
195
  raise
196
+ except TypeError as e:
197
+ logger.error(f"TypeError in OCR API call: {str(e)}", exc_info=True)
198
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  except Exception as e:
200
+ logger.error(f"Unexpected error in OCR API call: {str(e)}", exc_info=True)
201
+ raise
202
 
203
+ def _process_pdf_with_ocr(self, pdf_path: str) -> Tuple[str, List[str], List[Dict]]:
 
204
  try:
205
+ # Upload PDF and get signed URL
206
+ uploaded_file = self.client.files.upload(
207
+ file={"file_name": Path(pdf_path).stem, "content": Path(pdf_path).read_bytes()},
208
+ purpose="ocr",
209
+ )
210
+ signed_url = self.client.files.get_signed_url(file_id=uploaded_file.id, expiry=1).url
211
+
212
+ # Process with OCR
213
+ ocr_response = self.client.ocr.process(
214
+ document=DocumentURLChunk(document_url=signed_url),
215
+ model="mistral-ocr-latest",
216
+ include_image_base64=True
217
+ )
218
+ markdown, base64_images = self._get_combined_markdown(ocr_response)
219
+ json_results = self._convert_to_structured_json(markdown, pdf_path)
220
+ # Fallback to local images if OCR images are missing
221
+ image_paths = []
222
+ if not any(page.images for page in ocr_response.pages):
223
+ logger.warning("No images found in OCR response; using local images")
224
+ image_data = self._pdf_to_images(pdf_path)
225
+ image_paths = [path for path, _ in image_data]
226
+ else:
227
+ image_paths = [os.path.join(UPLOAD_FOLDER, f"ocr_page_{i}.png") for i in range(len(ocr_response.pages))]
228
+ for i, base64_img in enumerate(base64_images):
229
+ if base64_img:
230
+ try:
231
+ img_data = base64.b64decode(base64_img.split(',')[1])
232
+ with open(image_paths[i], "wb") as f:
233
+ f.write(img_data)
234
+ if os.path.exists(image_paths[i]):
235
+ logger.info(f"Image {image_paths[i]} saved and exists")
236
+ else:
237
+ logger.error(f"Image {image_paths[i]} saved but does not exist")
238
+ except Exception as e:
239
+ logger.error(f"Error saving image {i}: {str(e)}")
240
+ image_paths[i] = None
241
+ image_paths = [path for path in image_paths if path and os.path.exists(path)]
242
+ logger.info(f"Final image paths: {image_paths}")
243
+ return markdown, image_paths, json_results
244
  except Exception as e:
245
+ return self._handle_error("PDF OCR processing", e), [], []
246
+
247
+ def _get_combined_markdown(self, ocr_response: OCRResponse) -> Tuple[str, List[str]]:
248
+ markdowns = []
249
+ base64_images = []
250
+ for i, page in enumerate(ocr_response.pages):
251
+ image_data = {}
252
+ for img in page.images:
253
+ if img.image_base64:
254
+ base64_url = f"data:image/png;base64,{img.image_base64}"
255
+ image_data[img.id] = base64_url
256
+ base64_images.append(base64_url)
257
+ logger.info(f"Base64 image {img.id} length: {len(img.image_base64)}")
258
+ else:
259
+ base64_images.append(None)
260
+ markdown = page.markdown or "No text detected"
261
+ markdown = replace_images_in_markdown(markdown, image_data)
262
+ logger.info(f"Page {i} markdown (first 200 chars): {markdown[:200]}...")
263
+ markdowns.append(markdown)
264
+ return "\n\n".join(markdowns), base64_images
265
 
266
+ def _convert_to_structured_json(self, markdown: str, file_path: str) -> List[Dict]:
 
 
267
  try:
268
+ text_only_markdown = re.sub(r'!\[.*?\]\(data:image/[^)]+\)', '', markdown)
269
+ logger.info(f"Text-only markdown length: {len(text_only_markdown)}")
270
+ logger.info(f"Text-only markdown content: {text_only_markdown[:200]}...")
271
+
272
+ chat_response = self.client.chat.parse(
273
+ model="pixtral-12b-latest",
274
+ messages=[
275
+ {
276
+ "role": "user",
277
+ "content": f"Given OCR output from a PDF about African history and artifacts, convert to JSON with file_name, topics (e.g., African Artifacts, Tribal History), languages (e.g., English), and ocr_contents (title and list of items with descriptions and image refs).\n\nOCR Output:\n{text_only_markdown}"
278
+ },
279
+ ],
280
+ response_format=StructuredOCR,
281
+ temperature=0
282
+ )
283
+ structured_result = chat_response.choices[0].message.parsed
284
+ json_str = structured_result.model_dump_json()
285
+ logger.info(f"Structured JSON: {json_str}")
286
+ return [json.loads(json_str)]
287
  except Exception as e:
288
+ logger.error(f"Error converting to structured JSON: {str(e)}", exc_info=True)
289
+ return [{"error": str(e), "file_name": Path(file_path).stem}]
290
 
291
+ def ocr_uploaded_pdf(self, pdf_file: Union[str, bytes]) -> Tuple[str, List[str], List[Dict]]:
292
+ file_path = self._save_uploaded_file(pdf_file, getattr(pdf_file, 'name', f"pdf_{int(time.time())}.pdf"))
293
+ return self._process_pdf_with_ocr(file_path)
294
+
295
+ def ocr_pdf_url(self, pdf_url: str) -> Tuple[str, List[str], List[Dict]]:
296
+ file_path = self._save_uploaded_file(pdf_url, pdf_url.split('/')[-1] or f"pdf_{int(time.time())}.pdf")
297
+ return self._process_pdf_with_ocr(file_path)
298
+
299
+ def ocr_uploaded_image(self, image_file: Union[str, bytes]) -> Tuple[str, str, Dict]:
300
+ file_path = self._save_uploaded_file(image_file, getattr(image_file, 'name', f"image_{int(time.time())}.jpg"))
301
+ encoded_image = self._encode_image(file_path)
302
+ base64_url = f"data:image/png;base64,{encoded_image}"
303
+ response = self._call_ocr_api(encoded_image)
304
+ markdown, base64_images = self._get_combined_markdown(response)
305
+ json_result = self._convert_to_structured_json(markdown, file_path)[0]
306
+ return markdown, file_path, json_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
  @staticmethod
309
  def _handle_error(context: str, error: Exception) -> str:
310
+ logger.error(f"Error in {context}: {str(error)}", exc_info=True)
311
  return f"**Error in {context}:** {str(error)}"
312
 
313
+ def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
314
+ for img_name, base64_str in images_dict.items():
315
+ markdown_str = markdown_str.replace(f"![{img_name}]({img_name})", f"![{img_name}]({base64_str})")
316
+ return markdown_str
317
+
318
  def create_interface():
319
  css = """
320
  .output-markdown {font-size: 14px; max-height: 500px; overflow-y: auto;}
 
352
  )
353
  image_preview = gr.Image(label="Preview", height=300)
354
  image_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown")
355
+ image_json_output = gr.JSON(label="Structured JSON Output")
356
  process_image_btn = gr.Button("Process Image", variant="primary")
357
 
358
  def process_image(processor, image):
359
  if not processor or not image:
360
+ return "Please set API key and upload an image", None, {}
361
+ markdown, image_path, json_data = processor.ocr_uploaded_image(image)
362
+ return markdown, image_path, json_data
363
 
364
  process_image_btn.click(
365
  fn=process_image,
366
  inputs=[processor_state, image_input],
367
+ outputs=[image_output, image_preview, image_json_output]
368
  )
369
 
370
  with gr.Tab("PDF OCR"):
 
380
  )
381
  pdf_gallery = gr.Gallery(label="PDF Pages", height=300)
382
  pdf_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown")
383
+ pdf_json_output = gr.JSON(label="Structured JSON Output")
384
  process_pdf_btn = gr.Button("Process PDF", variant="primary")
385
 
386
  def process_pdf(processor, pdf_file, pdf_url):
387
  if not processor:
388
+ return "Please set API key first", [], {}
389
  logger.info(f"Received inputs - PDF file: {pdf_file}, PDF URL: {pdf_url}")
390
  if pdf_file is not None and hasattr(pdf_file, 'name'):
391
  logger.info(f"Processing as uploaded PDF: {pdf_file.name}")
392
+ markdown, image_paths, json_data = processor.ocr_uploaded_pdf(pdf_file)
393
  elif pdf_url and pdf_url.strip():
394
  logger.info(f"Processing as PDF URL: {pdf_url}")
395
+ markdown, image_paths, json_data = processor.ocr_pdf_url(pdf_url)
396
+ else:
397
+ return "Please upload a PDF or provide a valid URL", [], {}
398
+ # Fallback to display images if markdown rendering fails
399
+ image_components = []
400
+ for path in image_paths:
401
+ if path and os.path.exists(path):
402
+ image_components.append(gr.Image(path, label=f"Page Image"))
403
+ return markdown, image_paths, json_data, gr.Column(*image_components) if image_components else gr.Markdown("No images available")
404
 
405
  process_pdf_btn.click(
406
  fn=process_pdf,
407
  inputs=[processor_state, pdf_input, pdf_url_input],
408
+ outputs=[pdf_output, pdf_gallery, pdf_json_output, gr.Column()]
409
  )
410
 
411
  return demo