Spaces:
Running
Running
few shots + reasoning
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2 |
import base64
|
3 |
import gradio as gr
|
4 |
import json
|
|
|
5 |
from mistralai import Mistral, DocumentURLChunk, ImageURLChunk, TextChunk
|
6 |
from mistralai.models import OCRResponse
|
7 |
from typing import Union, List, Tuple, Dict
|
@@ -55,7 +56,6 @@ class StructuredOCR(BaseModel):
|
|
55 |
ocr_contents: dict
|
56 |
|
57 |
def model_dump_json(self, **kwargs):
|
58 |
-
# Custom JSON serialization to handle Language enums
|
59 |
data = self.model_dump(exclude_unset=True, by_alias=True, mode='json')
|
60 |
for key, value in data.items():
|
61 |
if isinstance(value, list) and all(isinstance(item, Language) for item in value):
|
@@ -66,6 +66,7 @@ class OCRProcessor:
|
|
66 |
def __init__(self, api_key: str):
|
67 |
if not api_key or not isinstance(api_key, str):
|
68 |
raise ValueError("Valid API key must be provided")
|
|
|
69 |
self.client = Mistral(api_key=api_key)
|
70 |
self._validate_client()
|
71 |
|
@@ -74,6 +75,7 @@ class OCRProcessor:
|
|
74 |
models = self.client.models.list()
|
75 |
if not models:
|
76 |
raise ValueError("No models available")
|
|
|
77 |
except Exception as e:
|
78 |
raise ValueError(f"API key validation failed: {str(e)}")
|
79 |
|
@@ -170,11 +172,11 @@ class OCRProcessor:
|
|
170 |
|
171 |
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
|
172 |
def _call_ocr_api(self, encoded_image: str) -> OCRResponse:
|
|
|
173 |
if not isinstance(encoded_image, str):
|
174 |
raise TypeError(f"Expected encoded_image to be a string, got {type(encoded_image)}")
|
175 |
base64_url = f"data:image/png;base64,{encoded_image}"
|
176 |
try:
|
177 |
-
logger.info("Calling OCR API")
|
178 |
response = self.client.ocr.process(
|
179 |
document=ImageURLChunk(image_url=base64_url),
|
180 |
model="mistral-ocr-latest",
|
@@ -190,26 +192,19 @@ class OCRProcessor:
|
|
190 |
except Exception as log_err:
|
191 |
logger.warning(f"Failed to log raw OCR response: {str(log_err)}")
|
192 |
return response
|
193 |
-
except (ConnectionError, TimeoutError, socket.error) as e:
|
194 |
-
logger.error(f"Network error during OCR API call: {str(e)}")
|
195 |
-
raise
|
196 |
-
except TypeError as e:
|
197 |
-
logger.error(f"TypeError in OCR API call: {str(e)}", exc_info=True)
|
198 |
-
raise
|
199 |
except Exception as e:
|
200 |
-
logger.error(f"
|
201 |
raise
|
202 |
|
203 |
def _process_pdf_with_ocr(self, pdf_path: str) -> Tuple[str, List[str], List[Dict]]:
|
204 |
try:
|
205 |
-
|
206 |
uploaded_file = self.client.files.upload(
|
207 |
file={"file_name": Path(pdf_path).stem, "content": Path(pdf_path).read_bytes()},
|
208 |
purpose="ocr",
|
209 |
)
|
210 |
signed_url = self.client.files.get_signed_url(file_id=uploaded_file.id, expiry=1).url
|
211 |
|
212 |
-
# Process with OCR
|
213 |
ocr_response = self.client.ocr.process(
|
214 |
document=DocumentURLChunk(document_url=signed_url),
|
215 |
model="mistral-ocr-latest",
|
@@ -217,7 +212,6 @@ class OCRProcessor:
|
|
217 |
)
|
218 |
markdown, base64_images = self._get_combined_markdown(ocr_response)
|
219 |
json_results = self._convert_to_structured_json(markdown, pdf_path)
|
220 |
-
# Fallback to local images if OCR images are missing
|
221 |
image_paths = []
|
222 |
if not any(page.images for page in ocr_response.pages):
|
223 |
logger.warning("No images found in OCR response; using local images")
|
@@ -251,7 +245,8 @@ class OCRProcessor:
|
|
251 |
image_data = {}
|
252 |
for img in page.images:
|
253 |
if img.image_base64:
|
254 |
-
|
|
|
255 |
image_data[img.id] = base64_url
|
256 |
base64_images.append(base64_url)
|
257 |
logger.info(f"Base64 image {img.id} length: {len(img.image_base64)}")
|
@@ -395,17 +390,12 @@ def create_interface():
|
|
395 |
markdown, image_paths, json_data = processor.ocr_pdf_url(pdf_url)
|
396 |
else:
|
397 |
return "Please upload a PDF or provide a valid URL", [], {}
|
398 |
-
|
399 |
-
image_components = []
|
400 |
-
for path in image_paths:
|
401 |
-
if path and os.path.exists(path):
|
402 |
-
image_components.append(gr.Image(path, label=f"Page Image"))
|
403 |
-
return markdown, image_paths, json_data, gr.Column(*image_components) if image_components else gr.Markdown("No images available")
|
404 |
|
405 |
process_pdf_btn.click(
|
406 |
fn=process_pdf,
|
407 |
inputs=[processor_state, pdf_input, pdf_url_input],
|
408 |
-
outputs=[pdf_output, pdf_gallery, pdf_json_output
|
409 |
)
|
410 |
|
411 |
return demo
|
|
|
2 |
import base64
|
3 |
import gradio as gr
|
4 |
import json
|
5 |
+
import re # Added to fix NameError
|
6 |
from mistralai import Mistral, DocumentURLChunk, ImageURLChunk, TextChunk
|
7 |
from mistralai.models import OCRResponse
|
8 |
from typing import Union, List, Tuple, Dict
|
|
|
56 |
ocr_contents: dict
|
57 |
|
58 |
def model_dump_json(self, **kwargs):
|
|
|
59 |
data = self.model_dump(exclude_unset=True, by_alias=True, mode='json')
|
60 |
for key, value in data.items():
|
61 |
if isinstance(value, list) and all(isinstance(item, Language) for item in value):
|
|
|
66 |
def __init__(self, api_key: str):
|
67 |
if not api_key or not isinstance(api_key, str):
|
68 |
raise ValueError("Valid API key must be provided")
|
69 |
+
self.api_key = api_key
|
70 |
self.client = Mistral(api_key=api_key)
|
71 |
self._validate_client()
|
72 |
|
|
|
75 |
models = self.client.models.list()
|
76 |
if not models:
|
77 |
raise ValueError("No models available")
|
78 |
+
logger.info("API key validated successfully")
|
79 |
except Exception as e:
|
80 |
raise ValueError(f"API key validation failed: {str(e)}")
|
81 |
|
|
|
172 |
|
173 |
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
|
174 |
def _call_ocr_api(self, encoded_image: str) -> OCRResponse:
|
175 |
+
logger.info(f"Calling OCR API with API key: {self.api_key[:4]}...") # Log partial key for debugging
|
176 |
if not isinstance(encoded_image, str):
|
177 |
raise TypeError(f"Expected encoded_image to be a string, got {type(encoded_image)}")
|
178 |
base64_url = f"data:image/png;base64,{encoded_image}"
|
179 |
try:
|
|
|
180 |
response = self.client.ocr.process(
|
181 |
document=ImageURLChunk(image_url=base64_url),
|
182 |
model="mistral-ocr-latest",
|
|
|
192 |
except Exception as log_err:
|
193 |
logger.warning(f"Failed to log raw OCR response: {str(log_err)}")
|
194 |
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
except Exception as e:
|
196 |
+
logger.error(f"OCR API error: {str(e)}", exc_info=True)
|
197 |
raise
|
198 |
|
199 |
def _process_pdf_with_ocr(self, pdf_path: str) -> Tuple[str, List[str], List[Dict]]:
|
200 |
try:
|
201 |
+
logger.info(f"Processing PDF with API key: {self.api_key[:4]}...")
|
202 |
uploaded_file = self.client.files.upload(
|
203 |
file={"file_name": Path(pdf_path).stem, "content": Path(pdf_path).read_bytes()},
|
204 |
purpose="ocr",
|
205 |
)
|
206 |
signed_url = self.client.files.get_signed_url(file_id=uploaded_file.id, expiry=1).url
|
207 |
|
|
|
208 |
ocr_response = self.client.ocr.process(
|
209 |
document=DocumentURLChunk(document_url=signed_url),
|
210 |
model="mistral-ocr-latest",
|
|
|
212 |
)
|
213 |
markdown, base64_images = self._get_combined_markdown(ocr_response)
|
214 |
json_results = self._convert_to_structured_json(markdown, pdf_path)
|
|
|
215 |
image_paths = []
|
216 |
if not any(page.images for page in ocr_response.pages):
|
217 |
logger.warning("No images found in OCR response; using local images")
|
|
|
245 |
image_data = {}
|
246 |
for img in page.images:
|
247 |
if img.image_base64:
|
248 |
+
# Use correct MIME type based on image format (assuming JPEG from logs)
|
249 |
+
base64_url = f"data:image/jpeg;base64,{img.image_base64}"
|
250 |
image_data[img.id] = base64_url
|
251 |
base64_images.append(base64_url)
|
252 |
logger.info(f"Base64 image {img.id} length: {len(img.image_base64)}")
|
|
|
390 |
markdown, image_paths, json_data = processor.ocr_pdf_url(pdf_url)
|
391 |
else:
|
392 |
return "Please upload a PDF or provide a valid URL", [], {}
|
393 |
+
return markdown, image_paths, json_data
|
|
|
|
|
|
|
|
|
|
|
394 |
|
395 |
process_pdf_btn.click(
|
396 |
fn=process_pdf,
|
397 |
inputs=[processor_state, pdf_input, pdf_url_input],
|
398 |
+
outputs=[pdf_output, pdf_gallery, pdf_json_output]
|
399 |
)
|
400 |
|
401 |
return demo
|