Svngoku committed on
Commit
f8fae95
·
verified ·
1 Parent(s): 0253cad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -53
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import base64
3
  import gradio as gr
4
- from mistralai import Mistral
5
  from mistralai.models import OCRResponse
6
  from pathlib import Path
7
  import pycountry
@@ -30,7 +30,9 @@ class OCRProcessor:
30
  self.api_key = api_key
31
  self.client = Mistral(api_key=self.api_key)
32
  try:
33
- self.client.models.list() # Validate API key
 
 
34
  except Exception as e:
35
  raise ValueError(f"Invalid API key: {str(e)}")
36
 
@@ -41,20 +43,24 @@ class OCRProcessor:
41
 
42
  @staticmethod
43
  @contextmanager
44
- def _temp_file(content: bytes, suffix: str) -> str:
45
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
46
  try:
 
47
  temp_file.write(content)
48
  temp_file.close()
49
  yield temp_file.name
50
  finally:
51
- if os.path.exists(temp_file.name):
 
52
  os.unlink(temp_file.name)
 
 
53
 
54
  @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
55
- def _call_ocr_api(self, document: Dict) -> OCRResponse:
56
  try:
57
- return self.client.ocr.process(model="mistral-ocr-latest", document=document)
58
  except Exception as e:
59
  logger.error(f"OCR API call failed: {str(e)}")
60
  raise
@@ -70,12 +76,10 @@ class OCRProcessor:
70
  def _get_file_content(self, file_input: Union[str, bytes]) -> bytes:
71
  if isinstance(file_input, str):
72
  if file_input.startswith("http"):
73
- # Handle URLs
74
  response = requests.get(file_input)
75
  response.raise_for_status()
76
  return response.content
77
  else:
78
- # Handle local file paths
79
  with open(file_input, "rb") as f:
80
  return f.read()
81
  return file_input.read() if hasattr(file_input, 'read') else file_input
@@ -83,81 +87,81 @@ class OCRProcessor:
83
  def ocr_pdf_url(self, pdf_url: str) -> str:
84
  logger.info(f"Processing PDF URL: {pdf_url}")
85
  try:
86
- response = self._call_ocr_api({"type": "document_url", "document_url": pdf_url})
87
- return self._extract_markdown(response)
88
  except Exception as e:
89
  return self._handle_error("PDF URL processing", e)
90
 
91
- def ocr_uploaded_pdf(self, pdf_file: Union[str, bytes]) -> str:
92
  file_name = getattr(pdf_file, 'name', 'unknown')
93
  logger.info(f"Processing uploaded PDF: {file_name}")
94
  try:
95
  content = self._get_file_content(pdf_file)
96
- with self._temp_file(content, ".pdf") as temp_path:
97
  uploaded_file = self.client.files.upload(
98
  file={"file_name": temp_path, "content": open(temp_path, "rb")},
99
  purpose="ocr"
100
  )
101
  signed_url = self.client.files.get_signed_url(file_id=uploaded_file.id, expiry=TEMP_FILE_EXPIRY)
102
- response = self._call_ocr_api({"type": "document_url", "document_url": signed_url.url})
103
- return self._extract_markdown(response)
104
  except Exception as e:
105
- return self._handle_error("uploaded PDF processing", e)
106
 
107
  def ocr_image_url(self, image_url: str) -> str:
108
  logger.info(f"Processing image URL: {image_url}")
109
  try:
110
- response = self._call_ocr_api({"type": "image_url", "image_url": image_url})
111
- return self._extract_markdown(response)
112
  except Exception as e:
113
  return self._handle_error("image URL processing", e)
114
 
115
- def ocr_uploaded_image(self, image_file: Union[str, bytes]) -> str:
116
  file_name = getattr(image_file, 'name', 'unknown')
117
  logger.info(f"Processing uploaded image: {file_name}")
118
  try:
119
  content = self._get_file_content(image_file)
120
- with self._temp_file(content, ".jpg") as temp_path:
121
  encoded_image = self._encode_image(temp_path)
122
  base64_url = f"data:image/jpeg;base64,{encoded_image}"
123
- response = self._call_ocr_api({"type": "image_url", "image_url": base64_url})
124
- return self._extract_markdown(response)
125
  except Exception as e:
126
- return self._handle_error("uploaded image processing", e)
127
 
128
  def document_understanding(self, doc_url: str, question: str) -> str:
129
  logger.info(f"Document understanding - URL: {doc_url}, Question: {question}")
130
  try:
131
  messages = [{"role": "user", "content": [
132
- {"type": "text", "text": question},
133
- {"type": "document_url", "document_url": doc_url}
134
  ]}]
135
  response = self._call_chat_complete(model="mistral-small-latest", messages=messages)
136
  return response.choices[0].message.content if response.choices else "No response received"
137
  except Exception as e:
138
  return self._handle_error("document understanding", e)
139
 
140
- def structured_ocr(self, image_file: Union[str, bytes]) -> str:
141
  file_name = getattr(image_file, 'name', 'unknown')
142
  logger.info(f"Processing structured OCR for: {file_name}")
143
  try:
144
  content = self._get_file_content(image_file)
145
- with self._temp_file(content, ".jpg") as temp_path:
146
  encoded_image = self._encode_image(temp_path)
147
  base64_url = f"data:image/jpeg;base64,{encoded_image}"
148
- ocr_response = self._call_ocr_api({"type": "image_url", "image_url": base64_url})
149
- markdown = self._extract_markdown(ocr_response)
150
 
151
  chat_response = self._call_chat_complete(
152
  model="pixtral-12b-latest",
153
  messages=[{
154
  "role": "user",
155
  "content": [
156
- {"type": "image_url", "image_url": base64_url},
157
- {"type": "text", "text": (
158
- f"OCR result:\n<BEGIN_IMAGE_OCR>\n{markdown}\n<END_IMAGE_OCR>\n"
159
- "Convert to structured JSON with file_name, topics, languages, and ocr_contents"
160
- )}
161
  ]
162
  }],
163
  response_format={"type": "json_object"},
@@ -166,13 +170,21 @@ class OCRProcessor:
166
 
167
  response_content = chat_response.choices[0].message.content
168
  content = json.loads(response_content)
169
- return self._format_structured_response(temp_path, content)
170
  except Exception as e:
171
- return self._handle_error("structured OCR", e)
172
-
173
- @staticmethod
174
- def _extract_markdown(response: OCRResponse) -> str:
175
- return response.pages[0].markdown if response.pages else "No text extracted"
 
 
 
 
 
 
 
 
176
 
177
  @staticmethod
178
  def _handle_error(context: str, error: Exception) -> str:
@@ -182,13 +194,14 @@ class OCRProcessor:
182
  @staticmethod
183
  def _format_structured_response(file_path: str, content: Dict) -> str:
184
  languages = {lang.alpha_2: lang.name for lang in pycountry.languages if hasattr(lang, 'alpha_2')}
185
- valid_langs = [l for l in content.get("languages", [DEFAULT_LANGUAGE]) if l in languages.values()]
 
186
 
187
  response = {
188
  "file_name": Path(file_path).name,
189
- "topics": content.get("topics", []),
190
  "languages": valid_langs or [DEFAULT_LANGUAGE],
191
- "ocr_contents": content.get("ocr_contents", {})
192
  }
193
  return f"```json\n{json.dumps(response, indent=4)}\n```"
194
 
@@ -224,32 +237,36 @@ def create_interface():
224
  )
225
 
226
  tabs = [
227
- ("OCR with PDF URL", gr.Textbox, "ocr_pdf_url", "PDF URL", None),
228
- ("OCR with Uploaded PDF", gr.File, "ocr_uploaded_pdf", "Upload PDF", SUPPORTED_PDF_TYPES),
229
- ("OCR with Image URL", gr.Textbox, "ocr_image_url", "Image URL", None),
230
- ("OCR with Uploaded Image", gr.File, "ocr_uploaded_image", "Upload Image", SUPPORTED_IMAGE_TYPES),
231
- ("Structured OCR", gr.File, "structured_ocr", "Upload Image", SUPPORTED_IMAGE_TYPES),
232
  ]
233
 
234
- for name, input_type, fn_name, label, file_types in tabs:
235
  with gr.Tab(name):
236
  if input_type == gr.Textbox:
237
  inputs = input_type(label=label, placeholder=f"e.g., https://example.com/{label.lower().replace(' ', '')}")
238
  else:
239
  inputs = input_type(label=label, file_types=file_types)
240
- output = gr.Markdown(label="Result")
 
 
 
 
241
  button_label = name.replace("OCR with ", "").replace("Structured ", "Get Structured ")
242
 
243
  def process_with_api(processor, input_data):
244
  if not processor:
245
- return "**Error:** Please set a valid API key first."
246
  fn = getattr(processor, fn_name)
247
- return fn(input_data)
248
 
249
  gr.Button(f"Process {button_label}").click(
250
  fn=process_with_api,
251
  inputs=[processor_state, inputs],
252
- outputs=output
253
  )
254
 
255
  with gr.Tab("Document Understanding"):
 
1
  import os
2
  import base64
3
  import gradio as gr
4
+ from mistralai import Mistral, DocumentURLChunk, ImageURLChunk, TextChunk
5
  from mistralai.models import OCRResponse
6
  from pathlib import Path
7
  import pycountry
 
30
  self.api_key = api_key
31
  self.client = Mistral(api_key=self.api_key)
32
  try:
33
+ models = self.client.models.list() # Validate API key
34
+ if not models:
35
+ raise ValueError("No models available")
36
  except Exception as e:
37
  raise ValueError(f"Invalid API key: {str(e)}")
38
 
 
43
 
44
  @staticmethod
45
  @contextmanager
46
+ def _temp_file(content: bytes, suffix: str, keep_alive: bool = False) -> str:
47
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix, dir=tempfile.gettempdir())
48
  try:
49
+ logger.info(f"Creating temp file: {temp_file.name}")
50
  temp_file.write(content)
51
  temp_file.close()
52
  yield temp_file.name
53
  finally:
54
+ if not keep_alive and os.path.exists(temp_file.name):
55
+ logger.info(f"Cleaning up temp file: {temp_file.name}")
56
  os.unlink(temp_file.name)
57
+ else:
58
+ logger.info(f"Keeping temp file alive: {temp_file.name}")
59
 
60
  @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
61
+ def _call_ocr_api(self, document: Union[DocumentURLChunk, ImageURLChunk]) -> OCRResponse:
62
  try:
63
+ return self.client.ocr.process(model="mistral-ocr-latest", document=document, include_image_base64=True)
64
  except Exception as e:
65
  logger.error(f"OCR API call failed: {str(e)}")
66
  raise
 
76
  def _get_file_content(self, file_input: Union[str, bytes]) -> bytes:
77
  if isinstance(file_input, str):
78
  if file_input.startswith("http"):
 
79
  response = requests.get(file_input)
80
  response.raise_for_status()
81
  return response.content
82
  else:
 
83
  with open(file_input, "rb") as f:
84
  return f.read()
85
  return file_input.read() if hasattr(file_input, 'read') else file_input
 
87
  def ocr_pdf_url(self, pdf_url: str) -> str:
88
  logger.info(f"Processing PDF URL: {pdf_url}")
89
  try:
90
+ response = self._call_ocr_api(DocumentURLChunk(document_url=pdf_url))
91
+ return self._get_combined_markdown(response)
92
  except Exception as e:
93
  return self._handle_error("PDF URL processing", e)
94
 
95
def ocr_uploaded_pdf(self, pdf_file: Union[str, bytes]) -> tuple[str, Union[str, None]]:
    """Run OCR on an uploaded PDF and return ``(markdown, preview_path)``.

    The content is written to a temporary ``.pdf`` file, uploaded through
    the client's files API with purpose ``"ocr"``, and the OCR endpoint is
    then called with a signed URL for that upload.

    Args:
        pdf_file: Anything accepted by ``_get_file_content``.

    Returns:
        A tuple of the combined markdown and the path of the temporary PDF
        (kept on disk for the caller). On failure, the tuple is the result
        of ``_handle_error`` and ``None``.
    """
    file_name = getattr(pdf_file, 'name', 'unknown')
    logger.info(f"Processing uploaded PDF: {file_name}")
    try:
        content = self._get_file_content(pdf_file)
        with self._temp_file(content, ".pdf", keep_alive=True) as temp_path:
            # Open the file in a context manager so the handle is released
            # as soon as the upload call returns, instead of being leaked.
            with open(temp_path, "rb") as pdf_handle:
                uploaded_file = self.client.files.upload(
                    file={"file_name": temp_path, "content": pdf_handle},
                    purpose="ocr"
                )
            signed_url = self.client.files.get_signed_url(file_id=uploaded_file.id, expiry=TEMP_FILE_EXPIRY)
            response = self._call_ocr_api(DocumentURLChunk(document_url=signed_url.url))
            return self._get_combined_markdown(response), temp_path
    except Exception as e:
        return self._handle_error("uploaded PDF processing", e), None
110
 
111
  def ocr_image_url(self, image_url: str) -> str:
112
  logger.info(f"Processing image URL: {image_url}")
113
  try:
114
+ response = self._call_ocr_api(ImageURLChunk(image_url=image_url))
115
+ return self._get_combined_markdown(response)
116
  except Exception as e:
117
  return self._handle_error("image URL processing", e)
118
 
119
def ocr_uploaded_image(self, image_file: Union[str, bytes]) -> tuple[str, str]:
    """Run OCR on an uploaded image and return ``(markdown, preview_path)``.

    The content is written to a temporary ``.jpg`` file, encoded with
    ``_encode_image`` and sent to the OCR endpoint as a
    ``data:image/jpeg;base64,...`` URL. The temporary file is kept on disk
    and its path is returned alongside the markdown. On failure, the result
    of ``_handle_error`` is returned together with ``None``.
    """
    source_name = getattr(image_file, 'name', 'unknown')
    logger.info(f"Processing uploaded image: {source_name}")
    try:
        image_bytes = self._get_file_content(image_file)
        with self._temp_file(image_bytes, ".jpg", keep_alive=True) as temp_path:
            data_url = f"data:image/jpeg;base64,{self._encode_image(temp_path)}"
            ocr_result = self._call_ocr_api(ImageURLChunk(image_url=data_url))
            combined = self._get_combined_markdown(ocr_result)
            return combined, temp_path
    except Exception as exc:
        return self._handle_error("uploaded image processing", exc), None
131
 
132
  def document_understanding(self, doc_url: str, question: str) -> str:
133
  logger.info(f"Document understanding - URL: {doc_url}, Question: {question}")
134
  try:
135
  messages = [{"role": "user", "content": [
136
+ TextChunk(text=question),
137
+ DocumentURLChunk(document_url=doc_url)
138
  ]}]
139
  response = self._call_chat_complete(model="mistral-small-latest", messages=messages)
140
  return response.choices[0].message.content if response.choices else "No response received"
141
  except Exception as e:
142
  return self._handle_error("document understanding", e)
143
 
144
+ def structured_ocr(self, image_file: Union[str, bytes]) -> tuple[str, str]:
145
  file_name = getattr(image_file, 'name', 'unknown')
146
  logger.info(f"Processing structured OCR for: {file_name}")
147
  try:
148
  content = self._get_file_content(image_file)
149
+ with self._temp_file(content, ".jpg", keep_alive=True) as temp_path:
150
  encoded_image = self._encode_image(temp_path)
151
  base64_url = f"data:image/jpeg;base64,{encoded_image}"
152
+ ocr_response = self._call_ocr_api(ImageURLChunk(image_url=base64_url))
153
+ markdown = self._get_combined_markdown(ocr_response)
154
 
155
  chat_response = self._call_chat_complete(
156
  model="pixtral-12b-latest",
157
  messages=[{
158
  "role": "user",
159
  "content": [
160
+ ImageURLChunk(image_url=base64_url),
161
+ TextChunk(text=(
162
+ f"This is image's OCR in markdown:\n<BEGIN_IMAGE_OCR>\n{markdown}\n<END_IMAGE_OCR>.\n"
163
+ "Convert this into a sensible structured json response with file_name, topics, languages, and ocr_contents fields"
164
+ ))
165
  ]
166
  }],
167
  response_format={"type": "json_object"},
 
170
 
171
  response_content = chat_response.choices[0].message.content
172
  content = json.loads(response_content)
173
+ return self._format_structured_response(temp_path, content), temp_path
174
  except Exception as e:
175
+ return self._handle_error("structured OCR", e), None
176
+
177
+ def _get_combined_markdown(self, response: OCRResponse) -> str:
178
+ markdowns = []
179
+ for page in response.pages:
180
+ image_data = {}
181
+ for img in page.images:
182
+ image_data[img.id] = img.image_base64
183
+ markdown = page.markdown
184
+ for img_name, base64_str in image_data.items():
185
+ markdown = markdown.replace(f"![{img_name}]({img_name})", f"![{img_name}]({base64_str})")
186
+ markdowns.append(markdown)
187
+ return "\n\n".join(markdowns)
188
 
189
  @staticmethod
190
  def _handle_error(context: str, error: Exception) -> str:
 
194
  @staticmethod
195
  def _format_structured_response(file_path: str, content: Dict) -> str:
196
  languages = {lang.alpha_2: lang.name for lang in pycountry.languages if hasattr(lang, 'alpha_2')}
197
+ content_languages = content["languages"] if "languages" in content else [DEFAULT_LANGUAGE]
198
+ valid_langs = [l for l in content_languages if l in languages.values()]
199
 
200
  response = {
201
  "file_name": Path(file_path).name,
202
+ "topics": content["topics"] if "topics" in content else [],
203
  "languages": valid_langs or [DEFAULT_LANGUAGE],
204
+ "ocr_contents": content["ocr_contents"] if "ocr_contents" in content else {}
205
  }
206
  return f"```json\n{json.dumps(response, indent=4)}\n```"
207
 
 
237
  )
238
 
239
  tabs = [
240
+ ("OCR with PDF URL", gr.Textbox, "ocr_pdf_url", "PDF URL", None, None),
241
+ ("OCR with Uploaded PDF", gr.File, "ocr_uploaded_pdf", "Upload PDF", SUPPORTED_PDF_TYPES, gr.File),
242
+ ("OCR with Image URL", gr.Textbox, "ocr_image_url", "Image URL", None, None),
243
+ ("OCR with Uploaded Image", gr.File, "ocr_uploaded_image", "Upload Image", SUPPORTED_IMAGE_TYPES, gr.Image),
244
+ ("Structured OCR", gr.File, "structured_ocr", "Upload Image", SUPPORTED_IMAGE_TYPES, gr.Image),
245
  ]
246
 
247
+ for name, input_type, fn_name, label, file_types, preview_type in tabs:
248
  with gr.Tab(name):
249
  if input_type == gr.Textbox:
250
  inputs = input_type(label=label, placeholder=f"e.g., https://example.com/{label.lower().replace(' ', '')}")
251
  else:
252
  inputs = input_type(label=label, file_types=file_types)
253
+
254
+ with gr.Row():
255
+ output = gr.Markdown(label="Result")
256
+ preview = preview_type(label="Preview") if preview_type else None
257
+
258
  button_label = name.replace("OCR with ", "").replace("Structured ", "Get Structured ")
259
 
260
  def process_with_api(processor, input_data):
261
  if not processor:
262
+ return "**Error:** Please set a valid API key first.", None
263
  fn = getattr(processor, fn_name)
264
+ return fn(input_data) # Returns tuple (result, preview_path)
265
 
266
  gr.Button(f"Process {button_label}").click(
267
  fn=process_with_api,
268
  inputs=[processor_state, inputs],
269
+ outputs=[output, preview] if preview else [output]
270
  )
271
 
272
  with gr.Tab("Document Understanding"):