Svngoku commited on
Commit
220b45d
·
verified ·
1 Parent(s): 6d820d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -230
app.py CHANGED
@@ -11,243 +11,195 @@ import json
11
  import logging
12
  from tenacity import retry, stop_after_attempt, wait_fixed
13
  import tempfile
14
- from typing import Union
 
15
 
16
- # Set up logging
17
- logging.basicConfig(level=logging.INFO)
 
 
 
 
 
 
18
  logger = logging.getLogger(__name__)
19
 
20
- # Initialize Mistral client with API key
21
- api_key = os.environ.get("MISTRAL_API_KEY")
22
- if not api_key:
23
- raise ValueError("MISTRAL_API_KEY environment variable is not set. Please configure it.")
24
- client = Mistral(api_key=api_key)
 
25
 
26
- # Helper function to encode image to base64
27
- def encode_image(image_path: str) -> str:
28
- try:
29
  with open(image_path, "rb") as image_file:
30
  return base64.b64encode(image_file.read()).decode('utf-8')
31
- except Exception as e:
32
- logger.error(f"Error encoding image {image_path}: {str(e)}")
33
- raise
34
-
35
- # Retry-enabled API call helpers
36
- @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
37
- def call_ocr_api(document: dict) -> OCRResponse:
38
- return client.ocr.process(model="mistral-ocr-latest", document=document)
39
-
40
- @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
41
- def call_chat_complete(model: str, messages: list, **kwargs) -> dict:
42
- return client.chat.complete(model=model, messages=messages, **kwargs)
43
-
44
- # Helper function to get file content (handles both string paths and file-like objects)
45
- def get_file_content(file_input: Union[str, bytes]) -> bytes:
46
- if isinstance(file_input, str): # Gradio 3.x: file path
47
- with open(file_input, "rb") as f:
48
- return f.read()
49
- else: # Gradio 4.x or file-like object
50
- return file_input.read()
51
-
52
- # OCR with PDF URL
53
- def ocr_pdf_url(pdf_url: str) -> str:
54
- logger.info(f"Processing PDF URL: {pdf_url}")
55
- try:
56
- ocr_response = call_ocr_api({"type": "document_url", "document_url": pdf_url})
57
- markdown = ocr_response.pages[0].markdown if ocr_response.pages else "No text extracted or response invalid."
58
- logger.info("Successfully processed PDF URL")
59
- return markdown
60
- except Exception as e:
61
- logger.error(f"Error processing PDF URL: {str(e)}")
62
- return f"**Error:** {str(e)}"
63
-
64
- # OCR with Uploaded PDF
65
- def ocr_uploaded_pdf(pdf_file: Union[str, bytes]) -> str:
66
- logger.info(f"Processing uploaded PDF: {getattr(pdf_file, 'name', 'unknown')}")
67
- temp_path = None
68
- try:
69
- content = get_file_content(pdf_file)
70
- with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
71
- temp_file.write(content)
72
- temp_path = temp_file.name
73
- uploaded_pdf = client.files.upload(
74
- file={"file_name": temp_path, "content": open(temp_path, "rb")},
75
- purpose="ocr"
76
- )
77
- signed_url = client.files.get_signed_url(file_id=uploaded_pdf.id, expiry=7200) # 2 hours
78
- ocr_response = call_ocr_api({"type": "document_url", "document_url": signed_url.url})
79
- markdown = ocr_response.pages[0].markdown if ocr_response.pages else "No text extracted or response invalid."
80
- logger.info("Successfully processed uploaded PDF")
81
- return markdown
82
- except Exception as e:
83
- logger.error(f"Error processing uploaded PDF: {str(e)}")
84
- return f"**Error:** {str(e)}"
85
- finally:
86
- if temp_path and os.path.exists(temp_path):
87
- os.remove(temp_path)
88
-
89
- # OCR with Image URL
90
- def ocr_image_url(image_url: str) -> str:
91
- logger.info(f"Processing image URL: {image_url}")
92
- try:
93
- ocr_response = call_ocr_api({"type": "image_url", "image_url": image_url})
94
- markdown = ocr_response.pages[0].markdown if ocr_response.pages else "No text extracted or response invalid."
95
- logger.info("Successfully processed image URL")
96
- return markdown
97
- except Exception as e:
98
- logger.error(f"Error processing image URL: {str(e)}")
99
- return f"**Error:** {str(e)}"
100
-
101
- # OCR with Uploaded Image
102
- def ocr_uploaded_image(image_file: Union[str, bytes]) -> str:
103
- logger.info(f"Processing uploaded image: {getattr(image_file, 'name', 'unknown')}")
104
- temp_path = None
105
- try:
106
- content = get_file_content(image_file)
107
- with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
108
  temp_file.write(content)
109
- temp_path = temp_file.name
110
- encoded_image = encode_image(temp_path)
111
- base64_data_url = f"data:image/jpeg;base64,{encoded_image}"
112
- ocr_response = call_ocr_api({"type": "image_url", "image_url": base64_data_url})
113
- markdown = ocr_response.pages[0].markdown if ocr_response.pages else "No text extracted or response invalid."
114
- logger.info("Successfully processed uploaded image")
115
- return markdown
116
- except Exception as e:
117
- logger.error(f"Error processing uploaded image: {str(e)}")
118
- return f"**Error:** {str(e)}"
119
- finally:
120
- if temp_path and os.path.exists(temp_path):
121
- os.remove(temp_path)
122
-
123
- # Document Understanding
124
- def document_understanding(doc_url: str, question: str) -> str:
125
- logger.info(f"Processing document understanding - URL: {doc_url}, Question: {question}")
126
- try:
127
- messages = [
128
- {"role": "user", "content": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  {"type": "text", "text": question},
130
  {"type": "document_url", "document_url": doc_url}
131
- ]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  ]
133
- chat_response = call_chat_complete(model="mistral-small-latest", messages=messages)
134
- content = chat_response.choices[0].message.content if chat_response.choices else "No response received from the API."
135
- logger.info("Successfully processed document understanding")
136
- return content
137
- except Exception as e:
138
- logger.error(f"Error in document understanding: {str(e)}")
139
- return f"**Error:** {str(e)}"
140
-
141
- # Structured OCR Setup
142
- languages = {lang.alpha_2: lang.name for lang in pycountry.languages if hasattr(lang, 'alpha_2')}
143
-
144
- class LanguageMeta(Enum.__class__):
145
- def __new__(metacls, cls, bases, classdict):
146
- for code, name in languages.items():
147
- classdict[name.upper().replace(' ', '_')] = name
148
- return super().__new__(metacls, cls, bases, classdict)
149
-
150
- class Language(Enum, metaclass=LanguageMeta):
151
- pass
152
-
153
- class StructuredOCR(BaseModel):
154
- file_name: str
155
- topics: list[str]
156
- languages: list[Language]
157
- ocr_contents: dict
158
-
159
- def structured_ocr(image_file: Union[str, bytes]) -> str:
160
- logger.info(f"Processing structured OCR for image: {getattr(image_file, 'name', 'unknown')}")
161
- temp_path = None
162
- try:
163
- content = get_file_content(image_file)
164
- with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
165
- temp_file.write(content)
166
- temp_path = temp_file.name
167
- image_path = Path(temp_path)
168
- encoded_image = encode_image(temp_path)
169
- base64_data_url = f"data:image/jpeg;base64,{encoded_image}"
170
-
171
- image_response = call_ocr_api({"type": "image_url", "image_url": base64_data_url})
172
- image_ocr_markdown = image_response.pages[0].markdown if image_response.pages else "No text extracted."
173
-
174
- chat_response = call_chat_complete(
175
- model="pixtral-12b-latest",
176
- messages=[{
177
- "role": "user",
178
- "content": [
179
- {"type": "image_url", "image_url": base64_data_url},
180
- {"type": "text", "text": (
181
- f"This is the image's OCR in markdown:\n<BEGIN_IMAGE_OCR>\n{image_ocr_markdown}\n<END_IMAGE_OCR>.\n"
182
- "Convert this into a structured JSON response with the OCR contents in a sensible dictionary."
183
- )}
184
- ],
185
- }],
186
- response_format={"type": "json_object"},
187
- temperature=0
188
- )
189
-
190
- content = chat_response.choices[0].message.content if chat_response.choices else "{}"
191
- response_dict = json.loads(content)
192
-
193
- language_members = {member.value: member for member in Language}
194
- valid_languages = [l for l in response_dict.get("languages", ["English"]) if l in language_members]
195
- languages = [language_members[l] for l in valid_languages] if valid_languages else [Language.ENGLISH]
196
-
197
- structured_response = StructuredOCR(
198
- file_name=image_path.name,
199
- topics=response_dict.get("topics", []),
200
- languages=languages,
201
- ocr_contents=response_dict.get("ocr_contents", {})
202
- )
203
- logger.info("Successfully processed structured OCR")
204
- return f"```json\n{json.dumps(structured_response.dict(), indent=4)}\n```"
205
- except Exception as e:
206
- logger.error(f"Error processing structured OCR: {str(e)}")
207
- return f"**Error:** {str(e)}"
208
- finally:
209
- if temp_path and os.path.exists(temp_path):
210
- os.remove(temp_path)
211
-
212
- with gr.Blocks(title="Mistral OCR & Structured Output App") as demo:
213
- gr.Markdown("# Mistral OCR & Structured Output App")
214
- gr.Markdown("Extract text from PDFs and images, ask questions about documents, or get structured JSON output!")
215
-
216
- with gr.Tab("OCR with PDF URL"):
217
- pdf_url_input = gr.Textbox(label="PDF URL", placeholder="e.g., https://arxiv.org/pdf/2201.04234")
218
- pdf_url_output = gr.Markdown(label="OCR Result")
219
- pdf_url_button = gr.Button("Process PDF")
220
- pdf_url_button.click(ocr_pdf_url, inputs=pdf_url_input, outputs=pdf_url_output)
221
-
222
- with gr.Tab("OCR with Uploaded PDF"):
223
- pdf_file_input = gr.File(label="Upload PDF", file_types=[".pdf"])
224
- pdf_file_output = gr.Markdown(label="OCR Result")
225
- pdf_file_button = gr.Button("Process Uploaded PDF")
226
- pdf_file_button.click(ocr_uploaded_pdf, inputs=pdf_file_input, outputs=pdf_file_output)
227
-
228
- with gr.Tab("OCR with Image URL"):
229
- image_url_input = gr.Textbox(label="Image URL", placeholder="e.g., https://example.com/image.jpg")
230
- image_url_output = gr.Markdown(label="OCR Result")
231
- image_url_button = gr.Button("Process Image")
232
- image_url_button.click(ocr_image_url, inputs=image_url_input, outputs=image_url_output)
233
-
234
- with gr.Tab("OCR with Uploaded Image"):
235
- image_file_input = gr.File(label="Upload Image", file_types=[".jpg", ".png"])
236
- image_file_output = gr.Markdown(label="OCR Result")
237
- image_file_button = gr.Button("Process Uploaded Image")
238
- image_file_button.click(ocr_uploaded_image, inputs=image_file_input, outputs=image_file_output)
239
-
240
- with gr.Tab("Document Understanding"):
241
- doc_url_input = gr.Textbox(label="Document URL", placeholder="e.g., https://arxiv.org/pdf/1805.04770")
242
- question_input = gr.Textbox(label="Question", placeholder="e.g., What is the last sentence?")
243
- doc_output = gr.Markdown(label="Answer")
244
- doc_button = gr.Button("Ask Question")
245
- doc_button.click(document_understanding, inputs=[doc_url_input, question_input], outputs=doc_output)
246
-
247
- with gr.Tab("Structured OCR"):
248
- struct_image_input = gr.File(label="Upload Image", file_types=[".jpg", ".png"])
249
- struct_output = gr.Markdown(label="Structured JSON Output")
250
- struct_button = gr.Button("Get Structured Output")
251
- struct_button.click(structured_ocr, inputs=struct_image_input, outputs=struct_output)
252
-
253
- demo.launch(share=True, debug=True)
 
11
  import logging
12
  from tenacity import retry, stop_after_attempt, wait_fixed
13
  import tempfile
14
+ from typing import Union, Optional, Dict, List
15
+ from contextlib import contextmanager
16
 
17
+ # Constants
18
+ DEFAULT_LANGUAGE = "English"
19
+ SUPPORTED_IMAGE_TYPES = [".jpg", ".png"]
20
+ SUPPORTED_PDF_TYPES = [".pdf"]
21
+ TEMP_FILE_EXPIRY = 7200 # 2 hours in seconds
22
+
23
+ # Configure logging
24
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
25
  logger = logging.getLogger(__name__)
26
 
27
+ class OCRProcessor:
28
+ def __init__(self):
29
+ self.api_key = os.environ.get("MISTRAL_API_KEY")
30
+ if not self.api_key:
31
+ raise ValueError("MISTRAL_API_KEY environment variable is not set")
32
+ self.client = Mistral(api_key=self.api_key)
33
 
34
+ @staticmethod
35
+ def _encode_image(image_path: str) -> str:
 
36
  with open(image_path, "rb") as image_file:
37
  return base64.b64encode(image_file.read()).decode('utf-8')
38
+
39
+ @staticmethod
40
+ @contextmanager
41
+ def _temp_file(content: bytes, suffix: str) -> str:
42
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
43
+ try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  temp_file.write(content)
45
+ temp_file.close()
46
+ yield temp_file.name
47
+ finally:
48
+ if os.path.exists(temp_file.name):
49
+ os.unlink(temp_file.name)
50
+
51
+ @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
52
+ def _call_ocr_api(self, document: Dict) -> OCRResponse:
53
+ return self.client.ocr.process(model="mistral-ocr-latest", document=document)
54
+
55
+ @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
56
+ def _call_chat_complete(self, model: str, messages: List[Dict], **kwargs) -> Dict:
57
+ return self.client.chat.complete(model=model, messages=messages, **kwargs)
58
+
59
+ def _get_file_content(self, file_input: Union[str, bytes]) -> bytes:
60
+ if isinstance(file_input, str):
61
+ with open(file_input, "rb") as f:
62
+ return f.read()
63
+ return file_input.read() if hasattr(file_input, 'read') else file_input
64
+
65
+ def ocr_pdf_url(self, pdf_url: str) -> str:
66
+ logger.info(f"Processing PDF URL: {pdf_url}")
67
+ try:
68
+ response = self._call_ocr_api({"type": "document_url", "document_url": pdf_url})
69
+ return self._extract_markdown(response)
70
+ except Exception as e:
71
+ return self._handle_error("PDF URL processing", e)
72
+
73
+ def ocr_uploaded_pdf(self, pdf_file: Union[str, bytes]) -> str:
74
+ file_name = getattr(pdf_file, 'name', 'unknown')
75
+ logger.info(f"Processing uploaded PDF: {file_name}")
76
+ try:
77
+ content = self._get_file_content(pdf_file)
78
+ with self._temp_file(content, ".pdf") as temp_path:
79
+ uploaded_file = self.client.files.upload(
80
+ file={"file_name": temp_path, "content": open(temp_path, "rb")},
81
+ purpose="ocr"
82
+ )
83
+ signed_url = self.client.files.get_signed_url(file_id=uploaded_file.id, expiry=TEMP_FILE_EXPIRY)
84
+ response = self._call_ocr_api({"type": "document_url", "document_url": signed_url.url})
85
+ return self._extract_markdown(response)
86
+ except Exception as e:
87
+ return self._handle_error("uploaded PDF processing", e)
88
+
89
+ def ocr_image_url(self, image_url: str) -> str:
90
+ logger.info(f"Processing image URL: {image_url}")
91
+ try:
92
+ response = self._call_ocr_api({"type": "image_url", "image_url": image_url})
93
+ return self._extract_markdown(response)
94
+ except Exception as e:
95
+ return self._handle_error("image URL processing", e)
96
+
97
+ def ocr_uploaded_image(self, image_file: Union[str, bytes]) -> str:
98
+ file_name = getattr(image_file, 'name', 'unknown')
99
+ logger.info(f"Processing uploaded image: {file_name}")
100
+ try:
101
+ content = self._get_file_content(image_file)
102
+ with self._temp_file(content, ".jpg") as temp_path:
103
+ encoded_image = self._encode_image(temp_path)
104
+ base64_url = f"data:image/jpeg;base64,{encoded_image}"
105
+ response = self._call_ocr_api({"type": "image_url", "image_url": base64_url})
106
+ return self._extract_markdown(response)
107
+ except Exception as e:
108
+ return self._handle_error("uploaded image processing", e)
109
+
110
+ def document_understanding(self, doc_url: str, question: str) -> str:
111
+ logger.info(f"Document understanding - URL: {doc_url}, Question: {question}")
112
+ try:
113
+ messages = [{"role": "user", "content": [
114
  {"type": "text", "text": question},
115
  {"type": "document_url", "document_url": doc_url}
116
+ ]}]
117
+ response = self._call_chat_complete(model="mistral-small-latest", messages=messages)
118
+ return response.choices[0].message.content if response.choices else "No response received"
119
+ except Exception as e:
120
+ return self._handle_error("document understanding", e)
121
+
122
+ def structured_ocr(self, image_file: Union[str, bytes]) -> str:
123
+ file_name = getattr(image_file, 'name', 'unknown')
124
+ logger.info(f"Processing structured OCR for: {file_name}")
125
+ try:
126
+ content = self._get_file_content(image_file)
127
+ with self._temp_file(content, ".jpg") as temp_path:
128
+ encoded_image = self._encode_image(temp_path)
129
+ base64_url = f"data:image/jpeg;base64,{encoded_image}"
130
+ ocr_response = self._call_ocr_api({"type": "image_url", "image_url": base64_url})
131
+ markdown = self._extract_markdown(ocr_response)
132
+
133
+ chat_response = self._call_chat_complete(
134
+ model="pixtral-12b-latest",
135
+ messages=[{
136
+ "role": "user",
137
+ "content": [
138
+ {"type": "image_url", "image_url": base64_url},
139
+ {"type": "text", "text": (
140
+ f"OCR result:\n<BEGIN_IMAGE_OCR>\n{markdown}\n<END_IMAGE_OCR>\n"
141
+ "Convert to structured JSON with file_name, topics, languages, and ocr_contents"
142
+ )}
143
+ ]
144
+ }],
145
+ response_format={"type": "json_object"},
146
+ temperature=0
147
+ )
148
+
149
+ content = json.loads(chat_response.choices[0].message.content if chat_response.choices else "{}")
150
+ return self._format_structured_response(temp_path, content)
151
+ except Exception as e:
152
+ return self._handle_error("structured OCR", e)
153
+
154
+ @staticmethod
155
+ def _extract_markdown(response: OCRResponse) -> str:
156
+ return response.pages[0].markdown if response.pages else "No text extracted"
157
+
158
+ @staticmethod
159
+ def _handle_error(context: str, error: Exception) -> str:
160
+ logger.error(f"Error in {context}: {str(error)}")
161
+ return f"**Error:** {str(error)}"
162
+
163
+ @staticmethod
164
+ def _format_structured_response(file_path: str, content: Dict) -> str:
165
+ languages = {lang.alpha_2: lang.name for lang in pycountry.languages if hasattr(lang, 'alpha_2')}
166
+ valid_langs = [l for l in content.get("languages", [DEFAULT_LANGUAGE]) if l in languages.values()]
167
+
168
+ response = {
169
+ "file_name": Path(file_path).name,
170
+ "topics": content.get("topics", []),
171
+ "languages": valid_langs or [DEFAULT_LANGUAGE],
172
+ "ocr_contents": content.get("ocr_contents", {})
173
+ }
174
+ return f"```json\n{json.dumps(response, indent=4)}\n```"
175
+
176
+ def create_interface():
177
+ processor = OCRProcessor()
178
+ with gr.Blocks(title="Mistral OCR & Structured Output App") as demo:
179
+ gr.Markdown("# Mistral OCR & Structured Output App")
180
+ gr.Markdown("Extract text from PDFs and images or get structured JSON output")
181
+
182
+ tabs = [
183
+ ("OCR with PDF URL", gr.Textbox, processor.ocr_pdf_url, "PDF URL"),
184
+ ("OCR with Uploaded PDF", gr.File, processor.ocr_uploaded_pdf, "Upload PDF", SUPPORTED_PDF_TYPES),
185
+ ("OCR with Image URL", gr.Textbox, processor.ocr_image_url, "Image URL"),
186
+ ("OCR with Uploaded Image", gr.File, processor.ocr_uploaded_image, "Upload Image", SUPPORTED_IMAGE_TYPES),
187
+ ("Structured OCR", gr.File, processor.structured_ocr, "Upload Image", SUPPORTED_IMAGE_TYPES),
188
  ]
189
+
190
+ for name, input_type, fn, label, *file_types in tabs:
191
+ with gr.Tab(name):
192
+ inputs = input_type(label=label, file_types=file_types or None)
193
+ output = gr.Markdown(label="Result")
194
+ gr.Button(f"Process {name.split(' with ')[1]}").click(fn, inputs=inputs, outputs=output)
195
+
196
+ with gr.Tab("Document Understanding"):
197
+ doc_url = gr.Textbox(label="Document URL")
198
+ question = gr.Textbox(label="Question")
199
+ output = gr.Markdown(label="Answer")
200
+ gr.Button("Ask Question").click(processor.document_understanding, inputs=[doc_url, question], outputs=output)
201
+
202
+ return demo
203
+
204
+ if __name__ == "__main__":
205
+ create_interface().launch(share=True, debug=True)