Spaces:
Running
Running
import os | |
import base64 | |
import gradio as gr | |
from mistralai import Mistral, DocumentURLChunk, ImageURLChunk, TextChunk | |
from mistralai.models import OCRResponse | |
from pathlib import Path | |
import pycountry | |
import json | |
import logging | |
from tenacity import retry, stop_after_attempt, wait_exponential | |
import tempfile | |
from typing import Union, Dict, List, Optional, Tuple | |
from contextlib import contextmanager | |
import requests | |
import shutil | |
from concurrent.futures import ThreadPoolExecutor | |
import time | |
import fitz # PyMuPDF | |
# Constants | |
DEFAULT_LANGUAGE = "English" | |
SUPPORTED_IMAGE_TYPES = [".jpg", ".png", ".jpeg"] | |
SUPPORTED_PDF_TYPES = [".pdf"] | |
TEMP_FILE_EXPIRY = 7200 # 2 hours in seconds | |
UPLOAD_FOLDER = "./uploads" | |
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB | |
MAX_PDF_PAGES = 50 | |
# Configuration | |
os.makedirs(UPLOAD_FOLDER, exist_ok=True) | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
handlers=[logging.StreamHandler()] | |
) | |
logger = logging.getLogger(__name__) | |
class OCRProcessor: | |
def __init__(self, api_key: str): | |
self.api_key = self._validate_api_key(api_key) | |
self.client = Mistral(api_key=self.api_key) | |
self._validate_client() | |
def _validate_api_key(api_key: str) -> str: | |
if not api_key or not isinstance(api_key, str): | |
raise ValueError("Valid API key must be provided") | |
return api_key | |
def _validate_client(self) -> None: | |
try: | |
models = self.client.models.list() | |
if not models: | |
raise ValueError("No models available") | |
except Exception as e: | |
raise ValueError(f"API key validation failed: {str(e)}") | |
def _check_file_size(file_input: Union[str, bytes]) -> None: | |
if isinstance(file_input, str) and os.path.exists(file_input): | |
size = os.path.getsize(file_input) | |
elif hasattr(file_input, 'read'): | |
size = len(file_input.read()) | |
file_input.seek(0) # Reset file pointer | |
else: | |
size = len(file_input) | |
if size > MAX_FILE_SIZE: | |
raise ValueError(f"File size exceeds {MAX_FILE_SIZE/1024/1024}MB limit") | |
def _encode_image(image_path: str) -> Optional[str]: | |
try: | |
with open(image_path, "rb") as image_file: | |
return base64.b64encode(image_file.read()).decode('utf-8') | |
except Exception as e: | |
logger.error(f"Error encoding image {image_path}: {str(e)}") | |
return None | |
def _save_uploaded_file(file_input: Union[str, bytes], filename: str) -> str: | |
clean_filename = os.path.basename(filename).replace(os.sep, "_") | |
file_path = os.path.join(UPLOAD_FOLDER, f"{int(time.time())}_{clean_filename}") | |
try: | |
if isinstance(file_input, str) and file_input.startswith("http"): | |
response = requests.get(file_input, timeout=10) | |
response.raise_for_status() | |
with open(file_path, 'wb') as f: | |
f.write(response.content) | |
elif isinstance(file_input, str) and os.path.exists(file_input): | |
shutil.copy2(file_input, file_path) | |
else: | |
with open(file_path, 'wb') as f: | |
if hasattr(file_input, 'read'): | |
shutil.copyfileobj(file_input, f) | |
else: | |
f.write(file_input) | |
if not os.path.exists(file_path): | |
raise FileNotFoundError(f"Failed to save file at {file_path}") | |
return file_path | |
except Exception as e: | |
logger.error(f"Error saving file {filename}: {str(e)}") | |
raise | |
def _pdf_to_images(pdf_path: str) -> List[str]: | |
try: | |
pdf_document = fitz.open(pdf_path) | |
if pdf_document.page_count > MAX_PDF_PAGES: | |
pdf_document.close() | |
raise ValueError(f"PDF exceeds maximum page limit of {MAX_PDF_PAGES}") | |
with ThreadPoolExecutor() as executor: | |
image_paths = list(executor.map( | |
lambda i: OCRProcessor._convert_page(pdf_path, i), | |
range(pdf_document.page_count) | |
)) | |
pdf_document.close() | |
return [path for path in image_paths if path] | |
except Exception as e: | |
logger.error(f"Error converting PDF to images: {str(e)}") | |
return [] | |
def _convert_page(pdf_path: str, page_num: int) -> Optional[str]: | |
try: | |
pdf_document = fitz.open(pdf_path) | |
page = pdf_document[page_num] | |
pix = page.get_pixmap(dpi=150) | |
image_path = os.path.join(UPLOAD_FOLDER, f"page_{page_num + 1}_{int(time.time())}.png") | |
pix.save(image_path) | |
pdf_document.close() | |
return image_path | |
except Exception as e: | |
logger.error(f"Error converting page {page_num}: {str(e)}") | |
return None | |
def _call_ocr_api(self, document: Union[DocumentURLChunk, ImageURLChunk]) -> OCRResponse: | |
return self.client.ocr.process( | |
model="mistral-ocr-latest", | |
document=document, | |
include_image_base64=True | |
) | |
def _call_chat_complete(self, model: str, messages: List[Dict], **kwargs) -> Dict: | |
return self.client.chat.complete(model=model, messages=messages, **kwargs) | |
def ocr_uploaded_pdf(self, pdf_file: Union[str, bytes]) -> Tuple[str, List[str]]: | |
file_name = getattr(pdf_file, 'name', f"pdf_{int(time.time())}.pdf") | |
logger.info(f"Processing uploaded PDF: {file_name}") | |
try: | |
self._check_file_size(pdf_file) | |
pdf_path = self._save_uploaded_file(pdf_file, file_name) | |
logger.info(f"Saved PDF to: {pdf_path}") | |
if not os.path.exists(pdf_path): | |
raise FileNotFoundError(f"Saved PDF not found at: {pdf_path}") | |
image_paths = self._pdf_to_images(pdf_path) | |
with open(pdf_path, "rb") as f: | |
uploaded_file = self.client.files.upload( | |
file={"file_name": file_name, "content": f}, | |
purpose="ocr" | |
) | |
signed_url = self.client.files.get_signed_url(file_id=uploaded_file.id, expiry=TEMP_FILE_EXPIRY) | |
response = self._call_ocr_api(DocumentURLChunk(document_url=signed_url.url)) | |
return self._get_combined_markdown(response), image_paths | |
except Exception as e: | |
return self._handle_error("PDF processing", e), [] | |
def ocr_uploaded_image(self, image_file: Union[str, bytes]) -> Tuple[str, str]: | |
file_name = getattr(image_file, 'name', f"image_{int(time.time())}.jpg") | |
logger.info(f"Processing uploaded image: {file_name}") | |
try: | |
self._check_file_size(image_file) | |
image_path = self._save_uploaded_file(image_file, file_name) | |
encoded_image = self._encode_image(image_path) | |
if not encoded_image: | |
raise ValueError("Failed to encode image") | |
base64_url = f"data:image/jpeg;base64,{encoded_image}" | |
response = self._call_ocr_api(ImageURLChunk(image_url=base64_url)) | |
return self._get_combined_markdown(response), image_path | |
except Exception as e: | |
return self._handle_error("image processing", e), None | |
def document_understanding(self, doc_url: str, question: str) -> str: | |
try: | |
messages = [{"role": "user", "content": [ | |
TextChunk(text=question), | |
DocumentURLChunk(document_url=doc_url) | |
]}] | |
response = self._call_chat_complete( | |
model="mistral-small-latest", | |
messages=messages, | |
temperature=0.1 | |
) | |
return response.choices[0].message.content | |
except Exception as e: | |
return self._handle_error("document understanding", e) | |
def structured_ocr(self, image_file: Union[str, bytes]) -> Tuple[str, str]: | |
file_name = getattr(image_file, 'name', f"image_{int(time.time())}.jpg") | |
try: | |
self._check_file_size(image_file) | |
image_path = self._save_uploaded_file(image_file, file_name) | |
encoded_image = self._encode_image(image_path) | |
if not encoded_image: | |
raise ValueError("Failed to encode image") | |
base64_url = f"data:image/jpeg;base64,{encoded_image}" | |
ocr_response = self._call_ocr_api(ImageURLChunk(image_url=base64_url)) | |
markdown = self._get_combined_markdown(ocr_response) | |
chat_response = self._call_chat_complete( | |
model="pixtral-12b-latest", | |
messages=[{ | |
"role": "user", | |
"content": [ | |
ImageURLChunk(image_url=base64_url), | |
TextChunk(text=( | |
f"This is image's OCR in markdown:\n<BEGIN_IMAGE_OCR>\n{markdown}\n<END_IMAGE_OCR>.\n" | |
"Convert this into a structured JSON response with file_name, topics, languages, and ocr_contents fields" | |
)) | |
] | |
}], | |
response_format={"type": "json_object"}, | |
temperature=0.1 | |
) | |
return self._format_structured_response(image_path, json.loads(chat_response.choices[0].message.content)), image_path | |
except Exception as e: | |
return self._handle_error("structured OCR", e), None | |
def _get_combined_markdown(response: OCRResponse) -> str: | |
return "\n\n".join( | |
page.markdown for page in response.pages | |
if page.markdown.strip() | |
) or "No text detected" | |
def _handle_error(context: str, error: Exception) -> str: | |
logger.error(f"Error in {context}: {str(error)}") | |
return f"**Error in {context}:** {str(error)}" | |
def _format_structured_response(file_path: str, content: Dict) -> str: | |
languages = {lang.alpha_2: lang.name for lang in pycountry.languages if hasattr(lang, 'alpha_2')} | |
content_languages = content.get("languages", [DEFAULT_LANGUAGE]) | |
valid_langs = [l for l in content_languages if l in languages.values()] or [DEFAULT_LANGUAGE] | |
response = { | |
"file_name": Path(file_path).name, | |
"topics": content.get("topics", []), | |
"languages": valid_langs, | |
"ocr_contents": content.get("ocr_contents", {}) | |
} | |
return f"```json\n{json.dumps(response, indent=2, ensure_ascii=False)}\n```" | |
def create_interface(): | |
css = """ | |
.output-markdown {font-size: 14px; max-height: 500px; overflow-y: auto;} | |
.status {color: #666; font-style: italic;} | |
""" | |
with gr.Blocks(title="Mistral OCR App", css=css) as demo: | |
gr.Markdown("# Mistral OCR App\nUpload images or PDFs for OCR processing") | |
with gr.Row(): | |
api_key = gr.Textbox(label="Mistral API Key", type="password", placeholder="Enter your API key") | |
set_key_btn = gr.Button("Set API Key", variant="primary") | |
processor_state = gr.State() | |
status = gr.Markdown("Please enter API key", elem_classes="status") | |
def init_processor(key): | |
try: | |
processor = OCRProcessor(key) | |
return processor, "✅ API key validated successfully" | |
except Exception as e: | |
return None, f"❌ Error: {str(e)}" | |
set_key_btn.click( | |
fn=init_processor, | |
inputs=api_key, | |
outputs=[processor_state, status] | |
) | |
with gr.Tab("Image OCR"): | |
with gr.Row(): | |
image_input = gr.File( | |
label=f"Upload Image (max {MAX_FILE_SIZE/1024/1024}MB)", | |
file_types=SUPPORTED_IMAGE_TYPES | |
) | |
image_preview = gr.Image(label="Preview", height=300) | |
image_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown") | |
process_image_btn = gr.Button("Process Image", variant="primary") | |
def process_image(processor, image): | |
if not processor or not image: | |
return "Please set API key and upload an image", None | |
return processor.ocr_uploaded_image(image) | |
process_image_btn.click( | |
fn=process_image, | |
inputs=[processor_state, image_input], | |
outputs=[image_output, image_preview] | |
) | |
with gr.Tab("PDF OCR"): | |
with gr.Row(): | |
pdf_input = gr.File( | |
label=f"Upload PDF (max {MAX_FILE_SIZE/1024/1024}MB, {MAX_PDF_PAGES} pages)", | |
file_types=SUPPORTED_PDF_TYPES | |
) | |
pdf_gallery = gr.Gallery(label="PDF Pages", height=300) | |
pdf_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown") | |
process_pdf_btn = gr.Button("Process PDF", variant="primary") | |
def process_pdf(processor, pdf): | |
if not processor or not pdf: | |
return "Please set API key and upload a PDF", [] | |
return processor.ocr_uploaded_pdf(pdf) | |
process_pdf_btn.click( | |
fn=process_pdf, | |
inputs=[processor_state, pdf_input], | |
outputs=[pdf_output, pdf_gallery] | |
) | |
with gr.Tab("Structured OCR"): | |
structured_input = gr.File( | |
label=f"Upload Image for Structured OCR (max {MAX_FILE_SIZE/1024/1024}MB)", | |
file_types=SUPPORTED_IMAGE_TYPES | |
) | |
structured_output = gr.Markdown(label="Structured Result", elem_classes="output-markdown") | |
structured_preview = gr.Image(label="Preview", height=300) | |
process_structured_btn = gr.Button("Process Structured OCR", variant="primary") | |
def process_structured(processor, image): | |
if not processor or not image: | |
return "Please set API key and upload an image", None | |
return processor.structured_ocr(image) | |
process_structured_btn.click( | |
fn=process_structured, | |
inputs=[processor_state, structured_input], | |
outputs=[structured_output, structured_preview] | |
) | |
return demo | |
if __name__ == "__main__": | |
os.environ['START_TIME'] = time.strftime('%Y-%m-%d %H:%M:%S') | |
print(f"===== Application Startup at {os.environ['START_TIME']} =====") | |
create_interface().launch( | |
share=True, | |
debug=True, | |
) |