Sandy2636 committed · commit 2d6f97d · Parent(s): 1d51bfe
Add application file
app.py CHANGED
@@ -15,35 +15,26 @@ OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
 
 # --- Global State (managed within Gradio's session if possible, or module-level for simplicity here) ---
 # This will be reset each time the processing function is called.
-# For a multi-user or more robust app, session state or a proper backend DB would be needed.
 processed_files_data = [] # Stores dicts for each file's details and status
 person_profiles = {} # Stores dicts for each identified person and their documents
 
 # --- Helper Functions ---
 
 def extract_json_from_text(text):
-    """
-    Extracts a JSON object from a string, trying common markdown and direct JSON.
-    """
     if not text:
         return {"error": "Empty text provided for JSON extraction."}
-
-    # Try to match ```json ... ``` code block
     match_block = re.search(r"```json\s*(\{.*?\})\s*```", text, re.DOTALL | re.IGNORECASE)
     if match_block:
         json_str = match_block.group(1)
     else:
-        # If no block, assume the text itself might be JSON or wrapped in single backticks
         text_stripped = text.strip()
         if text_stripped.startswith("`") and text_stripped.endswith("`"):
             json_str = text_stripped[1:-1]
         else:
-            json_str = text_stripped
-
+            json_str = text_stripped
     try:
         return json.loads(json_str)
     except json.JSONDecodeError as e:
-        # Fallback: Try to find the first '{' and last '}' if initial parsing fails
         try:
             first_brace = json_str.find('{')
             last_brace = json_str.rfind('}')
@@ -55,7 +46,6 @@ def extract_json_from_text(text):
         except json.JSONDecodeError as e2:
             return {"error": f"Invalid JSON structure after attempting substring: {str(e2)}", "original_text": text}
 
-
 def get_ocr_prompt():
     return f"""You are an advanced OCR and information extraction AI.
 Your task is to meticulously analyze this image and extract all relevant information.
@@ -89,17 +79,13 @@ def call_openrouter_ocr(image_filepath):
     try:
         with open(image_filepath, "rb") as f:
             encoded_image = base64.b64encode(f.read()).decode("utf-8")
-
-        # Basic MIME type guessing, default to jpeg
         mime_type = "image/jpeg"
         if image_filepath.lower().endswith(".png"):
             mime_type = "image/png"
         elif image_filepath.lower().endswith(".webp"):
             mime_type = "image/webp"
-
         data_url = f"data:{mime_type};base64,{encoded_image}"
         prompt_text = get_ocr_prompt()
-
         payload = {
             "model": IMAGE_MODEL,
             "messages": [
@@ -111,26 +97,23 @@ def call_openrouter_ocr(image_filepath):
                     ]
                 }
             ],
-            "max_tokens": 3500,
+            "max_tokens": 3500,
             "temperature": 0.1,
         }
         headers = {
             "Authorization": f"Bearer {OPENROUTER_API_KEY}",
             "Content-Type": "application/json",
-            "HTTP-Referer": "https://huggingface.co/spaces/
-            "X-Title": "
+            "HTTP-Referer": "https://huggingface.co/spaces/YOUR_SPACE",
+            "X-Title": "Gradio Document Processor"
         }
-
-        response = requests.post(OPENROUTER_API_URL, headers=headers, json=payload, timeout=180) # 3 min timeout
+        response = requests.post(OPENROUTER_API_URL, headers=headers, json=payload, timeout=180)
         response.raise_for_status()
         result = response.json()
-
         if "choices" in result and result["choices"]:
             raw_content = result["choices"][0]["message"]["content"]
             return extract_json_from_text(raw_content)
         else:
             return {"error": "No 'choices' in API response from OpenRouter.", "details": result}
-
     except requests.exceptions.Timeout:
         return {"error": "API request timed out."}
     except requests.exceptions.RequestException as e:
@@ -142,44 +125,38 @@ def call_openrouter_ocr(image_filepath):
         return {"error": f"An unexpected error occurred during OCR: {str(e)}"}
 
 def extract_entities_from_ocr(ocr_json):
-    if not ocr_json or "extracted_fields" not in ocr_json or not isinstance(ocr_json
-
+    if not ocr_json or "extracted_fields" not in ocr_json or not isinstance(ocr_json.get("extracted_fields"), dict):
+        doc_type_from_ocr = "Unknown"
+        if isinstance(ocr_json, dict): # ocr_json itself might be an error dict
+            doc_type_from_ocr = ocr_json.get("document_type_detected", "Unknown (error in OCR)")
+        return {"name": None, "dob": None, "passport_no": None, "doc_type": doc_type_from_ocr}
 
     fields = ocr_json["extracted_fields"]
     doc_type = ocr_json.get("document_type_detected", "Unknown")
-
-    # Normalize potential field names (case-insensitive search)
     name_keys = ["full name", "name", "account holder name", "guest name"]
     dob_keys = ["date of birth", "dob"]
     passport_keys = ["document number", "passport number"]
-
     extracted_name = None
     for key in name_keys:
         for field_key, value in fields.items():
             if key == field_key.lower():
                 extracted_name = str(value) if value else None
                 break
-        if extracted_name:
-            break
-
+        if extracted_name: break
     extracted_dob = None
     for key in dob_keys:
         for field_key, value in fields.items():
             if key == field_key.lower():
                 extracted_dob = str(value) if value else None
                 break
-        if extracted_dob:
-            break
-
+        if extracted_dob: break
     extracted_passport_no = None
     for key in passport_keys:
         for field_key, value in fields.items():
             if key == field_key.lower():
-                extracted_passport_no = str(value).replace(" ", "").upper() if value else None
+                extracted_passport_no = str(value).replace(" ", "").upper() if value else None
                 break
-        if extracted_passport_no:
-            break
-
+        if extracted_passport_no: break
     return {
         "name": extracted_name,
         "dob": extracted_dob,
@@ -192,64 +169,42 @@ def normalize_name(name):
     return "".join(filter(str.isalnum, name)).lower()
 
 def get_person_id_and_update_profiles(doc_id, entities, current_persons_data):
-    """
-    Tries to assign a document to an existing person or creates a new one.
-    Returns a person_key.
-    Updates current_persons_data in place.
-    """
     passport_no = entities.get("passport_no")
     name = entities.get("name")
     dob = entities.get("dob")
-
-    # 1. Match by Passport Number (strongest identifier)
     if passport_no:
         for p_key, p_data in current_persons_data.items():
             if passport_no in p_data.get("passport_numbers", set()):
                 p_data["doc_ids"].add(doc_id)
-                # Update person profile with potentially new name/dob if current is missing
                 if name and not p_data.get("canonical_name"): p_data["canonical_name"] = name
                 if dob and not p_data.get("canonical_dob"): p_data["canonical_dob"] = dob
                 return p_key
-
-        new_person_key = f"person_{passport_no}" # Or more robust ID generation
+        new_person_key = f"person_{passport_no}"
         current_persons_data[new_person_key] = {
-            "canonical_name": name,
-            "canonical_dob": dob,
+            "canonical_name": name, "canonical_dob": dob,
             "names": {normalize_name(name)} if name else set(),
             "dobs": {dob} if dob else set(),
-            "passport_numbers": {passport_no},
-            "doc_ids": {doc_id},
+            "passport_numbers": {passport_no}, "doc_ids": {doc_id},
             "display_name": name or f"Person (ID: {passport_no})"
         }
         return new_person_key
-
-    # 2. Match by Normalized Name + DOB (if passport not found or not present)
     if name and dob:
         norm_name = normalize_name(name)
         composite_key_nd = f"{norm_name}_{dob}"
         for p_key, p_data in current_persons_data.items():
-            # Check if this name and dob combo has been seen for this person
             if norm_name in p_data.get("names", set()) and dob in p_data.get("dobs", set()):
                 p_data["doc_ids"].add(doc_id)
                 return p_key
-        # New person based on name and DOB
         new_person_key = f"person_{composite_key_nd}_{str(uuid.uuid4())[:4]}"
         current_persons_data[new_person_key] = {
-            "canonical_name": name,
-            "
-            "
-            "dobs": {dob},
-            "passport_numbers": set(),
-            "doc_ids": {doc_id},
+            "canonical_name": name, "canonical_dob": dob,
+            "names": {norm_name}, "dobs": {dob},
+            "passport_numbers": set(), "doc_ids": {doc_id},
             "display_name": name
         }
         return new_person_key
-
-    # 3. If only name, less reliable, create new person (could add fuzzy matching later)
     if name:
         norm_name = normalize_name(name)
-        # Check if a person with just this name exists and has no other strong identifiers yet
-        # This part can be made more robust, for now, it might create more splits
        new_person_key = f"person_{norm_name}_{str(uuid.uuid4())[:4]}"
        current_persons_data[new_person_key] = {
            "canonical_name": name, "canonical_dob": None,
@@ -257,8 +212,6 @@ def get_person_id_and_update_profiles(doc_id, entities, current_persons_data):
             "doc_ids": {doc_id}, "display_name": name
         }
         return new_person_key
-
-    # 4. Unclassifiable for now, assign a generic unique person key
     generic_person_key = f"unidentified_person_{str(uuid.uuid4())[:6]}"
     current_persons_data[generic_person_key] = {
         "canonical_name": "Unknown", "canonical_dob": None,
@@ -267,17 +220,14 @@ def get_person_id_and_update_profiles(doc_id, entities, current_persons_data):
     }
     return generic_person_key
 
-
 def format_dataframe_data(current_files_data):
-    # Headers for the dataframe
-    # "ID", "Filename", "Status", "Detected Type", "Extracted Name", "Extracted DOB", "Main ID", "Person Key"
     df_rows = []
     for f_data in current_files_data:
-        entities = f_data.get("entities") or {}
+        entities = f_data.get("entities") or {} # CORRECTED LINE HERE
         df_rows.append([
-            f_data
-            f_data
-            f_data
+            f_data.get("doc_id", "N/A")[:8],
+            f_data.get("filename", "N/A"),
+            f_data.get("status", "N/A"),
             entities.get("doc_type", "N/A"),
             entities.get("name", "N/A"),
             entities.get("dob", "N/A"),
@@ -289,37 +239,33 @@ def format_dataframe_data(current_files_data):
 def format_persons_markdown(current_persons_data, current_files_data):
     if not current_persons_data:
         return "No persons identified yet."
-
     md_parts = ["## Classified Persons & Documents\n"]
     for p_key, p_data in current_persons_data.items():
         display_name = p_data.get('display_name', p_key)
         md_parts.append(f"### Person: {display_name} (Profile Key: {p_key})")
         if p_data.get("canonical_dob"): md_parts.append(f"* DOB: {p_data['canonical_dob']}")
         if p_data.get("passport_numbers"): md_parts.append(f"* Passport(s): {', '.join(p_data['passport_numbers'])}")
-
         md_parts.append("* Documents:")
         doc_ids_for_person = p_data.get("doc_ids", set())
         if doc_ids_for_person:
             for doc_id in doc_ids_for_person:
-                # Find the filename and detected type from current_files_data
                 doc_detail = next((f for f in current_files_data if f["doc_id"] == doc_id), None)
                 if doc_detail:
-                    filename = doc_detail
-
+                    filename = doc_detail.get("filename", "Unknown File")
+                    doc_entities = doc_detail.get("entities") or {}
+                    doc_type = doc_entities.get("doc_type", "Unknown Type")
                     md_parts.append(f" - {filename} (`{doc_type}`)")
                 else:
-                    md_parts.append(f" - Document ID: {doc_id[:8]} (details
+                    md_parts.append(f" - Document ID: {doc_id[:8]} (details error)")
         else:
             md_parts.append(" - No documents currently assigned.")
         md_parts.append("\n---\n")
     return "\n".join(md_parts)
 
-# --- Main Gradio Processing Function (Generator) ---
 def process_uploaded_files(files_list, progress=gr.Progress(track_tqdm=True)):
-    global processed_files_data, person_profiles
+    global processed_files_data, person_profiles
     processed_files_data = []
     person_profiles = {}
-
     if not OPENROUTER_API_KEY:
         yield (
             [["N/A", "ERROR", "OpenRouter API Key not configured.", "N/A", "N/A", "N/A", "N/A", "N/A"]],
@@ -327,74 +273,62 @@ def process_uploaded_files(files_list, progress=gr.Progress(track_tqdm=True)):
             "{}", "API Key Missing. Processing halted."
         )
         return
-
     if not files_list:
         yield ([], "No files uploaded.", "{}", "Upload files to begin.")
         return
-
-    # Initialize processed_files_data
     for i, file_obj in enumerate(files_list):
         doc_uid = str(uuid.uuid4())
         processed_files_data.append({
             "doc_id": doc_uid,
-            "filename": os.path.basename(file_obj.name
-            "filepath": file_obj.name,
+            "filename": os.path.basename(file_obj.name if hasattr(file_obj, 'name') else f"file_{i+1}.unknown"),
+            "filepath": file_obj.name if hasattr(file_obj, 'name') else None, # file_obj itself is filepath if from gr.Files type="filepath"
             "status": "Queued",
             "ocr_json": None,
             "entities": None,
             "assigned_person_key": None
         })
-
     initial_df_data = format_dataframe_data(processed_files_data)
     initial_persons_md = format_persons_markdown(person_profiles, processed_files_data)
     yield (initial_df_data, initial_persons_md, "{}", f"Initialized. Found {len(files_list)} files.")
-
-    # Iterate and process each file
     for i, file_data_item in enumerate(progress.tqdm(processed_files_data, desc="Processing Documents")):
         current_doc_id = file_data_item["doc_id"]
         current_filename = file_data_item["filename"]
-
-
+        if not file_data_item["filepath"]: # Check if filepath is valid
+            file_data_item["status"] = "Error: Invalid file path"
+            df_data = format_dataframe_data(processed_files_data)
+            persons_md = format_persons_markdown(person_profiles, processed_files_data)
+            yield(df_data, persons_md, "{}", f"({i+1}/{len(processed_files_data)}) Error with file {current_filename}")
+            continue
+
         file_data_item["status"] = "OCR in Progress..."
         df_data = format_dataframe_data(processed_files_data)
-        persons_md = format_persons_markdown(person_profiles, processed_files_data)
+        persons_md = format_persons_markdown(person_profiles, processed_files_data)
         yield (df_data, persons_md, "{}", f"({i+1}/{len(processed_files_data)}) OCR for: {current_filename}")
-
         ocr_result = call_openrouter_ocr(file_data_item["filepath"])
-        file_data_item["ocr_json"] = ocr_result
-
+        file_data_item["ocr_json"] = ocr_result
         if "error" in ocr_result:
-            file_data_item["status"] = f"OCR Error: {ocr_result['error'][:50]}..."
+            file_data_item["status"] = f"OCR Error: {str(ocr_result['error'])[:50]}..."
             df_data = format_dataframe_data(processed_files_data)
             yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) OCR Error on {current_filename}")
-            continue
-
+            continue
         file_data_item["status"] = "OCR Done. Extracting Entities..."
         df_data = format_dataframe_data(processed_files_data)
         yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) OCR Done for {current_filename}")
-
-        # 2. Entity Extraction
         entities = extract_entities_from_ocr(ocr_result)
         file_data_item["entities"] = entities
         file_data_item["status"] = "Entities Extracted. Classifying..."
-        df_data = format_dataframe_data(processed_files_data)
+        df_data = format_dataframe_data(processed_files_data)
         yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) Entities for {current_filename}")
-
-        # 3. Person Classification / Linking
         person_key = get_person_id_and_update_profiles(current_doc_id, entities, person_profiles)
         file_data_item["assigned_person_key"] = person_key
         file_data_item["status"] = "Classified"
-
         df_data = format_dataframe_data(processed_files_data)
-        persons_md = format_persons_markdown(person_profiles, processed_files_data)
+        persons_md = format_persons_markdown(person_profiles, processed_files_data)
        yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) Classified {current_filename} -> {person_key}")
-
     final_df_data = format_dataframe_data(processed_files_data)
     final_persons_md = format_persons_markdown(person_profiles, processed_files_data)
     yield (final_df_data, final_persons_md, "{}", f"All {len(processed_files_data)} documents processed.")
 
-
-# --- Gradio UI Layout ---
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# 📄 Intelligent Document Processor & Classifier")
     gr.Markdown(
@@ -402,58 +336,56 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         "The system will perform OCR, attempt to extract key entities, and classify documents by the person they belong to.**\n"
         "Ensure `OPENROUTER_API_KEY` is set as a Secret in your Hugging Face Space."
     )
-
     if not OPENROUTER_API_KEY:
         gr.Markdown("<h3 style='color:red;'>⚠️ ERROR: `OPENROUTER_API_KEY` is not set in Space Secrets! OCR will fail.</h3>")
-
     with gr.Row():
         with gr.Column(scale=1):
-            files_input = gr.Files(label="Upload Document Images (Bulk)", file_count="multiple", type="filepath")
-            process_button = gr.Button("Process Uploaded Documents", variant="primary")
+            files_input = gr.Files(label="Upload Document Images (Bulk)", file_count="multiple", type="filepath") # Using filepath
+            process_button = gr.Button("🚀 Process Uploaded Documents", variant="primary")
             overall_status_textbox = gr.Textbox(label="Overall Progress", interactive=False, lines=1)
-
             gr.Markdown("---")
             gr.Markdown("## Document Processing Details")
-            # "ID", "Filename", "Status", "Detected Type", "Extracted Name", "Extracted DOB", "Main ID", "Person Key"
             dataframe_headers = ["Doc ID (short)", "Filename", "Status", "Detected Type", "Name", "DOB", "Passport No.", "Assigned Person Key"]
             document_status_df = gr.Dataframe(
                 headers=dataframe_headers,
-                datatype=["str"] * len(dataframe_headers),
+                datatype=["str"] * len(dataframe_headers),
                 label="Individual Document Status & Extracted Entities",
-                row_count=(
+                row_count=(1, "dynamic"), # Start with 1 row, dynamically grows
                 col_count=(len(dataframe_headers), "fixed"),
                 wrap=True
             )
-
             ocr_json_output = gr.Code(label="Selected Document OCR JSON", language="json", interactive=False)
-
             gr.Markdown("---")
             person_classification_output_md = gr.Markdown("## Classified Persons & Documents\nNo persons identified yet.")
-
-    # Event Handlers
     process_button.click(
         fn=process_uploaded_files,
         inputs=[files_input],
         outputs=[
             document_status_df,
             person_classification_output_md,
-            ocr_json_output,
+            ocr_json_output,
             overall_status_textbox
         ]
     )
-
     @document_status_df.select(inputs=None, outputs=ocr_json_output, show_progress="hidden")
     def display_selected_ocr(evt: gr.SelectData):
-        if evt.index is None or evt.index[0] is None:
-            return "{}"
-
+        if evt.index is None or evt.index[0] is None:
+            return "{}"
         selected_row_index = evt.index[0]
-
+        # Ensure processed_files_data is accessible here. If it's truly global, it should be.
+        # For safety, one might pass it or make it part of a class if this were more complex.
+        if 0 <= selected_row_index < len(processed_files_data):
            selected_doc_data = processed_files_data[selected_row_index]
-        if selected_doc_data and selected_doc_data
-
-
-
+            if selected_doc_data and selected_doc_data.get("ocr_json"):
+                # Check if ocr_json is already a dict, if not, try to parse (though it should be)
+                ocr_data_to_display = selected_doc_data["ocr_json"]
+                if isinstance(ocr_data_to_display, str): # Should not happen if stored correctly
+                    try:
+                        ocr_data_to_display = json.loads(ocr_data_to_display)
+                    except json.JSONDecodeError:
+                        return json.dumps({"error": "Stored OCR data is not valid JSON string."}, indent=2)
+                return json.dumps(ocr_data_to_display, indent=2, ensure_ascii=False)
+        return json.dumps({ "message": "No OCR data found for selected row or selection out of bounds (check if processing is complete). Current rows: " + str(len(processed_files_data))}, indent=2)
 
 if __name__ == "__main__":
-    demo.queue().launch(debug=True,share=
+    demo.queue().launch(debug=True, share=os.environ.get("GRADIO_SHARE", "true").lower() == "true")