Svngoku commited on
Commit
0709a75
·
verified ·
1 Parent(s): 005a056

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -24
app.py CHANGED
@@ -3,16 +3,16 @@ import base64
3
  import gradio as gr
4
  from mistralai import Mistral
5
  from mistralai.models import OCRResponse
 
6
  from pathlib import Path
7
  from pydantic import BaseModel
8
  import pycountry
9
  import json
10
  import logging
11
- from tenacity import retry, stop_after_attempt, wait_fixed
12
  import tempfile
13
  from typing import Union, Dict, List
14
  from contextlib import contextmanager
15
- import requests
16
 
17
  # Constants
18
  DEFAULT_LANGUAGE = "English"
@@ -32,7 +32,7 @@ class OCRProcessor:
32
  self.client = Mistral(api_key=self.api_key)
33
  try:
34
  self.client.models.list() # Validate API key
35
- except Exception as e:
36
  raise ValueError(f"Invalid API key: {str(e)}")
37
 
38
  @staticmethod
@@ -52,33 +52,26 @@ class OCRProcessor:
52
  if os.path.exists(temp_file.name):
53
  os.unlink(temp_file.name)
54
 
55
- @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
56
  def _call_ocr_api(self, document: Dict) -> OCRResponse:
57
  try:
58
  return self.client.ocr.process(model="mistral-ocr-latest", document=document)
59
- except Exception as e:
60
  logger.error(f"OCR API call failed: {str(e)}")
61
  raise
62
 
63
- @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
64
  def _call_chat_complete(self, model: str, messages: List[Dict], **kwargs) -> Dict:
65
  try:
66
  return self.client.chat.complete(model=model, messages=messages, **kwargs)
67
- except Exception as e:
68
  logger.error(f"Chat complete API call failed: {str(e)}")
69
  raise
70
 
71
  def _get_file_content(self, file_input: Union[str, bytes]) -> bytes:
72
  if isinstance(file_input, str):
73
- if file_input.startswith("http"):
74
- # Handle URLs
75
- response = requests.get(file_input)
76
- response.raise_for_status()
77
- return response.content
78
- else:
79
- # Handle local file paths
80
- with open(file_input, "rb") as f:
81
- return f.read()
82
  return file_input.read() if hasattr(file_input, 'read') else file_input
83
 
84
  def ocr_pdf_url(self, pdf_url: str) -> str:
@@ -165,12 +158,7 @@ class OCRProcessor:
165
  temperature=0
166
  )
167
 
168
- # Ensure the response is a dictionary
169
- response_content = chat_response.choices[0].message.content
170
- if isinstance(response_content, list):
171
- response_content = response_content[0] if response_content else "{}"
172
-
173
- content = json.loads(response_content)
174
  return self._format_structured_response(temp_path, content)
175
  except Exception as e:
176
  return self._handle_error("structured OCR", e)
@@ -188,7 +176,7 @@ class OCRProcessor:
188
  def _format_structured_response(file_path: str, content: Dict) -> str:
189
  languages = {lang.alpha_2: lang.name for lang in pycountry.languages if hasattr(lang, 'alpha_2')}
190
  valid_langs = [l for l in content.get("languages", [DEFAULT_LANGUAGE]) if l in languages.values()]
191
-
192
  response = {
193
  "file_name": Path(file_path).name,
194
  "topics": content.get("topics", []),
@@ -207,7 +195,7 @@ def create_interface():
207
  placeholder="Enter your Mistral API key here",
208
  type="password"
209
  )
210
-
211
  def initialize_processor(api_key):
212
  try:
213
  processor = OCRProcessor(api_key)
 
3
  import gradio as gr
4
  from mistralai import Mistral
5
  from mistralai.models import OCRResponse
6
+ from mistralai.exceptions import MistralException
7
  from pathlib import Path
8
  from pydantic import BaseModel
9
  import pycountry
10
  import json
11
  import logging
12
+ from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
13
  import tempfile
14
  from typing import Union, Dict, List
15
  from contextlib import contextmanager
 
16
 
17
  # Constants
18
  DEFAULT_LANGUAGE = "English"
 
32
  self.client = Mistral(api_key=self.api_key)
33
  try:
34
  self.client.models.list() # Validate API key
35
+ except MistralException as e:
36
  raise ValueError(f"Invalid API key: {str(e)}")
37
 
38
  @staticmethod
 
52
  if os.path.exists(temp_file.name):
53
  os.unlink(temp_file.name)
54
 
55
+ @retry(stop=stop_after_attempt(3), wait=wait_fixed(2), retry_if_exception_type=MistralException)
56
  def _call_ocr_api(self, document: Dict) -> OCRResponse:
57
  try:
58
  return self.client.ocr.process(model="mistral-ocr-latest", document=document)
59
+ except MistralException as e:
60
  logger.error(f"OCR API call failed: {str(e)}")
61
  raise
62
 
63
+ @retry(stop=stop_after_attempt(3), wait=wait_fixed(2), retry_if_exception_type=MistralException)
64
  def _call_chat_complete(self, model: str, messages: List[Dict], **kwargs) -> Dict:
65
  try:
66
  return self.client.chat.complete(model=model, messages=messages, **kwargs)
67
+ except MistralException as e:
68
  logger.error(f"Chat complete API call failed: {str(e)}")
69
  raise
70
 
71
  def _get_file_content(self, file_input: Union[str, bytes]) -> bytes:
72
  if isinstance(file_input, str):
73
+ with open(file_input, "rb") as f:
74
+ return f.read()
 
 
 
 
 
 
 
75
  return file_input.read() if hasattr(file_input, 'read') else file_input
76
 
77
  def ocr_pdf_url(self, pdf_url: str) -> str:
 
158
  temperature=0
159
  )
160
 
161
+ content = json.loads(chat_response.choices[0].message.content if chat_response.choices else "{}")
 
 
 
 
 
162
  return self._format_structured_response(temp_path, content)
163
  except Exception as e:
164
  return self._handle_error("structured OCR", e)
 
176
  def _format_structured_response(file_path: str, content: Dict) -> str:
177
  languages = {lang.alpha_2: lang.name for lang in pycountry.languages if hasattr(lang, 'alpha_2')}
178
  valid_langs = [l for l in content.get("languages", [DEFAULT_LANGUAGE]) if l in languages.values()]
179
+
180
  response = {
181
  "file_name": Path(file_path).name,
182
  "topics": content.get("topics", []),
 
195
  placeholder="Enter your Mistral API key here",
196
  type="password"
197
  )
198
+
199
  def initialize_processor(api_key):
200
  try:
201
  processor = OCRProcessor(api_key)