Svngoku committed on
Commit
971b317
·
verified ·
1 Parent(s): f8fae95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -132
app.py CHANGED
@@ -12,12 +12,17 @@ import tempfile
12
  from typing import Union, Dict, List
13
  from contextlib import contextmanager
14
  import requests
 
15
 
16
  # Constants
17
  DEFAULT_LANGUAGE = "English"
18
  SUPPORTED_IMAGE_TYPES = [".jpg", ".png"]
19
  SUPPORTED_PDF_TYPES = [".pdf"]
20
  TEMP_FILE_EXPIRY = 7200 # 2 hours in seconds
 
 
 
 
21
 
22
  # Configure logging
23
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -38,24 +43,73 @@ class OCRProcessor:
38
 
39
  @staticmethod
40
  def _encode_image(image_path: str) -> str:
41
- with open(image_path, "rb") as image_file:
42
- return base64.b64encode(image_file.read()).decode('utf-8')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  @staticmethod
45
  @contextmanager
46
- def _temp_file(content: bytes, suffix: str, keep_alive: bool = False) -> str:
47
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix, dir=tempfile.gettempdir())
48
  try:
49
- logger.info(f"Creating temp file: {temp_file.name}")
50
  temp_file.write(content)
51
  temp_file.close()
52
  yield temp_file.name
53
  finally:
54
- if not keep_alive and os.path.exists(temp_file.name):
55
- logger.info(f"Cleaning up temp file: {temp_file.name}")
56
  os.unlink(temp_file.name)
57
- else:
58
- logger.info(f"Keeping temp file alive: {temp_file.name}")
59
 
60
  @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
61
  def _call_ocr_api(self, document: Union[DocumentURLChunk, ImageURLChunk]) -> OCRResponse:
@@ -84,48 +138,82 @@ class OCRProcessor:
84
  return f.read()
85
  return file_input.read() if hasattr(file_input, 'read') else file_input
86
 
87
- def ocr_pdf_url(self, pdf_url: str) -> str:
88
  logger.info(f"Processing PDF URL: {pdf_url}")
89
  try:
 
 
 
 
 
 
 
 
 
 
 
 
90
  response = self._call_ocr_api(DocumentURLChunk(document_url=pdf_url))
91
- return self._get_combined_markdown(response)
92
  except Exception as e:
93
- return self._handle_error("PDF URL processing", e)
94
 
95
- def ocr_uploaded_pdf(self, pdf_file: Union[str, bytes]) -> tuple[str, str]:
96
  file_name = getattr(pdf_file, 'name', 'unknown')
97
  logger.info(f"Processing uploaded PDF: {file_name}")
98
  try:
99
- content = self._get_file_content(pdf_file)
100
- with self._temp_file(content, ".pdf", keep_alive=True) as temp_path:
101
- uploaded_file = self.client.files.upload(
102
- file={"file_name": temp_path, "content": open(temp_path, "rb")},
103
- purpose="ocr"
104
- )
105
- signed_url = self.client.files.get_signed_url(file_id=uploaded_file.id, expiry=TEMP_FILE_EXPIRY)
106
- response = self._call_ocr_api(DocumentURLChunk(document_url=signed_url.url))
107
- return self._get_combined_markdown(response), temp_path
 
 
 
 
 
 
 
108
  except Exception as e:
109
- return self._handle_error("uploaded PDF processing", e), None
110
 
111
- def ocr_image_url(self, image_url: str) -> str:
112
  logger.info(f"Processing image URL: {image_url}")
113
  try:
 
 
 
 
 
 
 
 
 
114
  response = self._call_ocr_api(ImageURLChunk(image_url=image_url))
115
- return self._get_combined_markdown(response)
116
  except Exception as e:
117
- return self._handle_error("image URL processing", e)
118
 
119
  def ocr_uploaded_image(self, image_file: Union[str, bytes]) -> tuple[str, str]:
120
  file_name = getattr(image_file, 'name', 'unknown')
121
  logger.info(f"Processing uploaded image: {file_name}")
122
  try:
123
- content = self._get_file_content(image_file)
124
- with self._temp_file(content, ".jpg", keep_alive=True) as temp_path:
125
- encoded_image = self._encode_image(temp_path)
126
- base64_url = f"data:image/jpeg;base64,{encoded_image}"
127
- response = self._call_ocr_api(ImageURLChunk(image_url=base64_url))
128
- return self._get_combined_markdown(response), temp_path
 
 
 
 
 
 
129
  except Exception as e:
130
  return self._handle_error("uploaded image processing", e), None
131
 
@@ -145,32 +233,37 @@ class OCRProcessor:
145
  file_name = getattr(image_file, 'name', 'unknown')
146
  logger.info(f"Processing structured OCR for: {file_name}")
147
  try:
148
- content = self._get_file_content(image_file)
149
- with self._temp_file(content, ".jpg", keep_alive=True) as temp_path:
150
- encoded_image = self._encode_image(temp_path)
151
- base64_url = f"data:image/jpeg;base64,{encoded_image}"
152
- ocr_response = self._call_ocr_api(ImageURLChunk(image_url=base64_url))
153
- markdown = self._get_combined_markdown(ocr_response)
154
-
155
- chat_response = self._call_chat_complete(
156
- model="pixtral-12b-latest",
157
- messages=[{
158
- "role": "user",
159
- "content": [
160
- ImageURLChunk(image_url=base64_url),
161
- TextChunk(text=(
162
- f"This is image's OCR in markdown:\n<BEGIN_IMAGE_OCR>\n{markdown}\n<END_IMAGE_OCR>.\n"
163
- "Convert this into a sensible structured json response with file_name, topics, languages, and ocr_contents fields"
164
- ))
165
- ]
166
- }],
167
- response_format={"type": "json_object"},
168
- temperature=0
169
- )
170
-
171
- response_content = chat_response.choices[0].message.content
172
- content = json.loads(response_content)
173
- return self._format_structured_response(temp_path, content), temp_path
 
 
 
 
 
174
  except Exception as e:
175
  return self._handle_error("structured OCR", e), None
176
 
@@ -194,6 +287,7 @@ class OCRProcessor:
194
  @staticmethod
195
  def _format_structured_response(file_path: str, content: Dict) -> str:
196
  languages = {lang.alpha_2: lang.name for lang in pycountry.languages if hasattr(lang, 'alpha_2')}
 
197
  content_languages = content["languages"] if "languages" in content else [DEFAULT_LANGUAGE]
198
  valid_langs = [l for l in content_languages if l in languages.values()]
199
 
@@ -206,87 +300,58 @@ class OCRProcessor:
206
  return f"```json\n{json.dumps(response, indent=4)}\n```"
207
 
208
  def create_interface():
209
- with gr.Blocks(title="Mistral OCR & Structured Output App") as demo:
210
- gr.Markdown("# Mistral OCR & Structured Output App")
211
- gr.Markdown("Enter your Mistral API key below to use the app. Extract text from PDFs and images or get structured JSON output.")
212
- gr.Markdown("**Note:** After entering your API key, click 'Set API Key' to validate and use it.")
213
-
214
- api_key_input = gr.Textbox(
215
- label="Mistral API Key",
216
- placeholder="Enter your Mistral API key here",
217
- type="password"
218
- )
219
 
220
- def initialize_processor(api_key):
221
  try:
222
- processor = OCRProcessor(api_key)
223
- return processor, "**Success:** API key set and validated!"
224
- except ValueError as e:
225
- return None, f"**Error:** {str(e)}"
226
  except Exception as e:
227
- return None, f"**Error:** Unexpected error: {str(e)}"
228
-
229
- processor_state = gr.State()
230
- api_status = gr.Markdown("API key not set. Please enter and set your key.")
231
 
232
- set_api_button = gr.Button("Set API Key")
233
- set_api_button.click(
234
- fn=initialize_processor,
235
- inputs=api_key_input,
236
- outputs=[processor_state, api_status]
237
  )
238
 
239
- tabs = [
240
- ("OCR with PDF URL", gr.Textbox, "ocr_pdf_url", "PDF URL", None, None),
241
- ("OCR with Uploaded PDF", gr.File, "ocr_uploaded_pdf", "Upload PDF", SUPPORTED_PDF_TYPES, gr.File),
242
- ("OCR with Image URL", gr.Textbox, "ocr_image_url", "Image URL", None, None),
243
- ("OCR with Uploaded Image", gr.File, "ocr_uploaded_image", "Upload Image", SUPPORTED_IMAGE_TYPES, gr.Image),
244
- ("Structured OCR", gr.File, "structured_ocr", "Upload Image", SUPPORTED_IMAGE_TYPES, gr.Image),
245
- ]
246
-
247
- for name, input_type, fn_name, label, file_types, preview_type in tabs:
248
- with gr.Tab(name):
249
- if input_type == gr.Textbox:
250
- inputs = input_type(label=label, placeholder=f"e.g., https://example.com/{label.lower().replace(' ', '')}")
251
- else:
252
- inputs = input_type(label=label, file_types=file_types)
253
-
254
- with gr.Row():
255
- output = gr.Markdown(label="Result")
256
- preview = preview_type(label="Preview") if preview_type else None
257
-
258
- button_label = name.replace("OCR with ", "").replace("Structured ", "Get Structured ")
259
-
260
- def process_with_api(processor, input_data):
261
- if not processor:
262
- return "**Error:** Please set a valid API key first.", None
263
- fn = getattr(processor, fn_name)
264
- return fn(input_data) # Returns tuple (result, preview_path)
265
-
266
- gr.Button(f"Process {button_label}").click(
267
- fn=process_with_api,
268
- inputs=[processor_state, inputs],
269
- outputs=[output, preview] if preview else [output]
270
- )
271
-
272
- with gr.Tab("Document Understanding"):
273
- doc_url = gr.Textbox(label="Document URL", placeholder="e.g., https://arxiv.org/pdf/1805.04770")
274
- question = gr.Textbox(label="Question", placeholder="e.g., What is the last sentence?")
275
- output = gr.Markdown(label="Answer")
276
-
277
- def doc_understanding_with_api(processor, url, q):
278
- if not processor:
279
- return "**Error:** Please set a valid API key first."
280
- return processor.document_understanding(url, q)
281
 
282
- gr.Button("Ask Question").click(
283
- fn=doc_understanding_with_api,
284
- inputs=[processor_state, doc_url, question],
285
- outputs=output
 
 
 
 
 
 
286
  )
287
 
288
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
- if __name__ == "__main__":
291
- print(f"===== Application Startup at {os.environ.get('START_TIME', 'Unknown')} =====")
292
- create_interface().launch(share=True, debug=True)
 
12
  from typing import Union, Dict, List
13
  from contextlib import contextmanager
14
  import requests
15
+ import shutil
16
 
17
  # Constants
18
  DEFAULT_LANGUAGE = "English"
19
  SUPPORTED_IMAGE_TYPES = [".jpg", ".png"]
20
  SUPPORTED_PDF_TYPES = [".pdf"]
21
  TEMP_FILE_EXPIRY = 7200 # 2 hours in seconds
22
+ UPLOAD_FOLDER = "uploads" # Local storage folder
23
+
24
+ # Create upload folder if it doesn't exist
25
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
26
 
27
  # Configure logging
28
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
43
 
44
  @staticmethod
45
  def _encode_image(image_path: str) -> str:
46
+ try:
47
+ with open(image_path, "rb") as image_file:
48
+ return base64.b64encode(image_file.read()).decode('utf-8')
49
+ except FileNotFoundError:
50
+ logger.error(f"Error: The file {image_path} was not found.")
51
+ return None
52
+ except Exception as e:
53
+ logger.error(f"Error encoding image: {str(e)}")
54
+ return None
55
+
56
+ @staticmethod
57
+ def _save_uploaded_file(file_input: Union[str, bytes], filename: str) -> str:
58
+ """Save uploaded file to local storage and return path"""
59
+ file_path = os.path.join(UPLOAD_FOLDER, filename)
60
+ try:
61
+ if isinstance(file_input, str):
62
+ if file_input.startswith("http"):
63
+ response = requests.get(file_input)
64
+ response.raise_for_status()
65
+ with open(file_path, 'wb') as f:
66
+ f.write(response.content)
67
+ else:
68
+ # Copy file to new location if source and destination are different
69
+ if os.path.abspath(file_input) != os.path.abspath(file_path):
70
+ shutil.copy2(file_input, file_path)
71
+ else:
72
+ return file_input # Return original path if same file
73
+ else:
74
+ with open(file_path, 'wb') as f:
75
+ if hasattr(file_input, 'read'):
76
+ shutil.copyfileobj(file_input, f)
77
+ else:
78
+ f.write(file_input)
79
+ return file_path
80
+ except Exception as e:
81
+ logger.error(f"Error saving file: {str(e)}")
82
+ return None
83
+
84
+ @staticmethod
85
+ def _pdf_to_images(pdf_path: str) -> List[str]:
86
+ """Convert PDF pages to images and return their paths"""
87
+ image_paths = []
88
+ try:
89
+ pdf_document = fitz.open(pdf_path)
90
+ for page_num in range(pdf_document.page_count):
91
+ page = pdf_document[page_num]
92
+ pix = page.get_pixmap()
93
+ image_path = os.path.join(UPLOAD_FOLDER, f"page_{page_num + 1}.png")
94
+ pix.save(image_path)
95
+ image_paths.append(image_path)
96
+ pdf_document.close()
97
+ return image_paths
98
+ except Exception as e:
99
+ logger.error(f"Error converting PDF to images: {str(e)}")
100
+ return []
101
 
102
  @staticmethod
103
  @contextmanager
104
+ def _temp_file(content: bytes, suffix: str) -> str:
105
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
106
  try:
 
107
  temp_file.write(content)
108
  temp_file.close()
109
  yield temp_file.name
110
  finally:
111
+ if os.path.exists(temp_file.name):
 
112
  os.unlink(temp_file.name)
 
 
113
 
114
  @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
115
  def _call_ocr_api(self, document: Union[DocumentURLChunk, ImageURLChunk]) -> OCRResponse:
 
138
  return f.read()
139
  return file_input.read() if hasattr(file_input, 'read') else file_input
140
 
141
+ def ocr_pdf_url(self, pdf_url: str) -> tuple[str, List[str]]:
142
  logger.info(f"Processing PDF URL: {pdf_url}")
143
  try:
144
+ # Download and save PDF
145
+ response = requests.get(pdf_url)
146
+ response.raise_for_status()
147
+ filename = pdf_url.split('/')[-1]
148
+ pdf_path = self._save_uploaded_file(response.content, filename)
149
+ if not pdf_path:
150
+ return self._handle_error("PDF saving", Exception("Failed to save PDF")), []
151
+
152
+ # Convert PDF to images for visualization
153
+ image_paths = self._pdf_to_images(pdf_path)
154
+
155
+ # Process with OCR
156
  response = self._call_ocr_api(DocumentURLChunk(document_url=pdf_url))
157
+ return self._get_combined_markdown(response), image_paths
158
  except Exception as e:
159
+ return self._handle_error("PDF URL processing", e), []
160
 
161
+ def ocr_uploaded_pdf(self, pdf_file: Union[str, bytes]) -> tuple[str, List[str]]:
162
  file_name = getattr(pdf_file, 'name', 'unknown')
163
  logger.info(f"Processing uploaded PDF: {file_name}")
164
  try:
165
+ # Save uploaded PDF
166
+ pdf_path = self._save_uploaded_file(pdf_file, file_name)
167
+ if not pdf_path:
168
+ return self._handle_error("PDF saving", Exception("Failed to save PDF")), []
169
+
170
+ # Convert PDF to images for visualization
171
+ image_paths = self._pdf_to_images(pdf_path)
172
+
173
+ # Process with OCR
174
+ uploaded_file = self.client.files.upload(
175
+ file={"file_name": pdf_path, "content": open(pdf_path, "rb")},
176
+ purpose="ocr"
177
+ )
178
+ signed_url = self.client.files.get_signed_url(file_id=uploaded_file.id, expiry=TEMP_FILE_EXPIRY)
179
+ response = self._call_ocr_api(DocumentURLChunk(document_url=signed_url.url))
180
+ return self._get_combined_markdown(response), image_paths
181
  except Exception as e:
182
+ return self._handle_error("uploaded PDF processing", e), []
183
 
184
+ def ocr_image_url(self, image_url: str) -> tuple[str, str]:
185
  logger.info(f"Processing image URL: {image_url}")
186
  try:
187
+ # Download and save image
188
+ response = requests.get(image_url)
189
+ response.raise_for_status()
190
+ filename = image_url.split('/')[-1]
191
+ image_path = self._save_uploaded_file(response.content, filename)
192
+ if not image_path:
193
+ return self._handle_error("image saving", Exception("Failed to save image")), None
194
+
195
+ # Process with OCR
196
  response = self._call_ocr_api(ImageURLChunk(image_url=image_url))
197
+ return self._get_combined_markdown(response), image_path
198
  except Exception as e:
199
+ return self._handle_error("image URL processing", e), None
200
 
201
  def ocr_uploaded_image(self, image_file: Union[str, bytes]) -> tuple[str, str]:
202
  file_name = getattr(image_file, 'name', 'unknown')
203
  logger.info(f"Processing uploaded image: {file_name}")
204
  try:
205
+ # Save uploaded image
206
+ image_path = self._save_uploaded_file(image_file, file_name)
207
+ if not image_path:
208
+ return self._handle_error("image saving", Exception("Failed to save image")), None
209
+
210
+ # Process with OCR
211
+ encoded_image = self._encode_image(image_path)
212
+ if encoded_image is None:
213
+ return self._handle_error("image encoding", Exception("Failed to encode image")), None
214
+ base64_url = f"data:image/jpeg;base64,{encoded_image}"
215
+ response = self._call_ocr_api(ImageURLChunk(image_url=base64_url))
216
+ return self._get_combined_markdown(response), image_path
217
  except Exception as e:
218
  return self._handle_error("uploaded image processing", e), None
219
 
 
233
  file_name = getattr(image_file, 'name', 'unknown')
234
  logger.info(f"Processing structured OCR for: {file_name}")
235
  try:
236
+ # Save uploaded image
237
+ image_path = self._save_uploaded_file(image_file, file_name)
238
+ if not image_path:
239
+ return self._handle_error("image saving", Exception("Failed to save image")), None
240
+
241
+ encoded_image = self._encode_image(image_path)
242
+ if encoded_image is None:
243
+ return self._handle_error("image encoding", Exception("Failed to encode image")), None
244
+ base64_url = f"data:image/jpeg;base64,{encoded_image}"
245
+ ocr_response = self._call_ocr_api(ImageURLChunk(image_url=base64_url))
246
+ markdown = self._get_combined_markdown(ocr_response)
247
+
248
+ chat_response = self._call_chat_complete(
249
+ model="pixtral-12b-latest",
250
+ messages=[{
251
+ "role": "user",
252
+ "content": [
253
+ ImageURLChunk(image_url=base64_url),
254
+ TextChunk(text=(
255
+ f"This is image's OCR in markdown:\n<BEGIN_IMAGE_OCR>\n{markdown}\n<END_IMAGE_OCR>.\n"
256
+ "Convert this into a sensible structured json response with file_name, topics, languages, and ocr_contents fields"
257
+ ))
258
+ ]
259
+ }],
260
+ response_format={"type": "json_object"},
261
+ temperature=0
262
+ )
263
+
264
+ response_content = chat_response.choices[0].message.content
265
+ content = json.loads(response_content)
266
+ return self._format_structured_response(image_path, content), image_path
267
  except Exception as e:
268
  return self._handle_error("structured OCR", e), None
269
 
 
287
  @staticmethod
288
  def _format_structured_response(file_path: str, content: Dict) -> str:
289
  languages = {lang.alpha_2: lang.name for lang in pycountry.languages if hasattr(lang, 'alpha_2')}
290
+ # Handle languages as a list instead of using .get()
291
  content_languages = content["languages"] if "languages" in content else [DEFAULT_LANGUAGE]
292
  valid_langs = [l for l in content_languages if l in languages.values()]
293
 
 
300
  return f"```json\n{json.dumps(response, indent=4)}\n```"
301
 
302
  def create_interface():
303
+ with gr.Blocks(title="Mistral OCR App") as demo:
304
+ gr.Markdown("# Mistral OCR App")
305
+
306
+ api_key = gr.Textbox(label="API Key", type="password")
307
+ processor_state = gr.State()
308
+ status = gr.Markdown()
 
 
 
 
309
 
310
+ def init_processor(key):
311
  try:
312
+ processor = OCRProcessor(key)
313
+ return processor, "API key validated!"
 
 
314
  except Exception as e:
315
+ return None, f"Error: {str(e)}"
 
 
 
316
 
317
+ gr.Button("Set API Key").click(
318
+ fn=init_processor,
319
+ inputs=api_key,
320
+ outputs=[processor_state, status]
 
321
  )
322
 
323
+ with gr.Tab("Image OCR"):
324
+ image_input = gr.File(label="Upload Image", file_types=SUPPORTED_IMAGE_TYPES)
325
+ image_preview = gr.Image(label="Image Preview")
326
+ image_output = gr.Markdown()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
 
328
+ def process_image(processor, image):
329
+ if not processor:
330
+ return "Please set API key first", None
331
+ ocr_result, image_path = processor.ocr_uploaded_image(image)
332
+ return ocr_result, image_path
333
+
334
+ gr.Button("Process Image").click(
335
+ fn=process_image,
336
+ inputs=[processor_state, image_input],
337
+ outputs=[image_output, image_preview]
338
  )
339
 
340
+ with gr.Tab("PDF OCR"):
341
+ pdf_input = gr.File(label="Upload PDF", file_types=SUPPORTED_PDF_TYPES)
342
+ pdf_gallery = gr.Gallery(label="PDF Pages")
343
+ pdf_output = gr.Markdown()
344
+
345
+ def process_pdf(processor, pdf):
346
+ if not processor:
347
+ return "Please set API key first", None
348
+ ocr_result, image_paths = processor.ocr_uploaded_pdf(pdf)
349
+ return ocr_result, image_paths
350
+
351
+ gr.Button("Process PDF").click(
352
+ fn=process_pdf,
353
+ inputs=[processor_state, pdf_input],
354
+ outputs=[pdf_output, pdf_gallery]
355
+ )
356
 
357
+ return demo