raksama19 committed on
Commit
30311f2
·
verified ·
1 Parent(s): 2fb44a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +505 -10
app.py CHANGED
@@ -1,26 +1,521 @@
1
  """
2
- HuggingFace Spaces entry point for DOLPHIN PDF Document AI
 
3
  """
4
 
 
 
 
 
 
 
 
 
5
  import os
6
- import sys
 
 
 
 
 
7
 
8
- # Add the current directory to Python path for imports
9
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # Import and run the Gradio app
12
- from gradio_pdf_app import demo
13
 
14
  if __name__ == "__main__":
15
- # Launch the app for HuggingFace Spaces
16
  demo.launch(
17
  server_name="0.0.0.0",
18
  server_port=7860,
19
  share=False,
20
  show_error=True,
21
- enable_queue=True,
22
- max_threads=2,
23
- # Additional HF Spaces specific settings
24
  inbrowser=False,
25
  show_tips=False,
26
  quiet=True
 
1
  """
2
+ PDF Document Processing Gradio App for HuggingFace Spaces
3
+ Built on DOLPHIN model for document parsing and analysis
4
  """
5
 
6
+ import gradio as gr
7
+ import json
8
+ import markdown
9
+ import cv2
10
+ import numpy as np
11
+ from PIL import Image
12
+ from transformers import AutoProcessor, VisionEncoderDecoderModel
13
+ import torch
14
  import os
15
+ import tempfile
16
+ import uuid
17
+ import base64
18
+ import io
19
+ from utils.utils import *
20
+ from utils.markdown_utils import MarkdownConverter
21
 
22
# Optional dependency: mdx_math adds enhanced math rendering when installed.
# It is not available on standard PyPI, so its absence is tolerated and only
# recorded in the MATH_EXTENSION_AVAILABLE flag.
try:
    from mdx_math import MathExtension
except ImportError:
    MATH_EXTENSION_AVAILABLE = False
else:
    MATH_EXTENSION_AVAILABLE = True
30
+
31
+
32
class DOLPHIN:
    """Thin wrapper around a Hugging Face vision encoder-decoder model.

    Holds the processor, tokenizer and model, and exposes ``chat`` to run a
    text prompt against one image or a batch of images.
    """

    def __init__(self, model_id_or_path):
        """Initialize the Hugging Face model optimized for HF Spaces.

        Args:
            model_id_or_path: Path to local model or Hugging Face model ID.
        """
        use_cuda = torch.cuda.is_available()
        self.device = "cuda" if use_cuda else "cpu"

        self.processor = AutoProcessor.from_pretrained(model_id_or_path)
        self.model = VisionEncoderDecoderModel.from_pretrained(
            model_id_or_path,
            # Half precision only on GPU. On CPU load full precision directly
            # instead of loading float16 and upcasting afterwards, which would
            # round the weights for no benefit.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
        )
        self.model.eval()

        self.tokenizer = self.processor.tokenizer

    def chat(self, prompt, image):
        """Process an image or batch of images with the given prompt(s).

        Args:
            prompt: A single prompt string, or a list of prompts. When
                ``image`` is a list and ``prompt`` is a single string, the
                prompt is repeated for every image.
            image: A single image or a list of images accepted by the
                processor.

        Returns:
            The decoded text for a single image, or a list of decoded texts
            (one per image) when ``image`` is a list.
        """
        is_batch = isinstance(image, list)

        if is_batch:
            images = image
            prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
        else:
            images = [image]
            prompts = [prompt]

        # Prepare image
        batch_inputs = self.processor(images, return_tensors="pt", padding=True)
        batch_pixel_values = batch_inputs.pixel_values

        if torch.cuda.is_available():
            # The model runs in float16 on GPU, so the inputs must match.
            batch_pixel_values = batch_pixel_values.half().to(self.device)
        else:
            batch_pixel_values = batch_pixel_values.to(self.device)

        # Prepare prompt
        prompts = [f"<s>{p} <Answer/>" for p in prompts]
        batch_prompt_inputs = self.tokenizer(
            prompts,
            add_special_tokens=False,
            return_tensors="pt",
        )

        batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device)
        batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device)

        # Generate text with memory-efficient settings
        with torch.no_grad():
            outputs = self.model.generate(
                pixel_values=batch_pixel_values,
                decoder_input_ids=batch_prompt_ids,
                decoder_attention_mask=batch_attention_mask,
                min_length=1,
                max_length=2048,  # Reduced for memory efficiency
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                use_cache=True,
                bad_words_ids=[[self.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
                do_sample=False,
                num_beams=1,
                repetition_penalty=1.1,
                temperature=1.0,
            )

        # Process output
        sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)

        # Clean prompt text and special tokens from the decoded output
        results = []
        for i, sequence in enumerate(sequences):
            cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
            results.append(cleaned)

        if not is_batch:
            return results[0]
        return results
116
+
117
+
118
def convert_pdf_to_images_gradio(pdf_file):
    """Convert an uploaded PDF into a list of RGB PIL Images, one per page.

    Args:
        pdf_file: The uploaded PDF. May be raw ``bytes``, a filesystem path
            (``str`` / ``os.PathLike``), or any object with a ``read()``
            method returning the PDF bytes.

    Returns:
        A list of ``PIL.Image.Image`` objects in page order, rendered at
        2x zoom.

    Raises:
        Exception: If the PDF cannot be read or rendered. The message starts
            with ``"Error converting PDF: "`` and the original error is
            chained as ``__cause__``.
    """
    pdf_document = None
    try:
        import pymupdf

        # Accept the different shapes an upload can take: bytes, a path on
        # disk, or an open file-like object.
        if isinstance(pdf_file, (bytes, bytearray)):
            pdf_bytes = bytes(pdf_file)
        elif isinstance(pdf_file, (str, os.PathLike)):
            with open(pdf_file, "rb") as handle:
                pdf_bytes = handle.read()
        else:
            pdf_bytes = pdf_file.read()

        # Open PDF from bytes
        pdf_document = pymupdf.open(stream=pdf_bytes, filetype="pdf")

        # 2x zoom for better quality; the matrix is the same for every page.
        zoom = pymupdf.Matrix(2.0, 2.0)

        images = []
        for page in pdf_document:
            pix = page.get_pixmap(matrix=zoom)

            # Round-trip through PNG bytes to obtain a PIL Image
            img_data = pix.tobytes("png")
            images.append(Image.open(io.BytesIO(img_data)).convert("RGB"))

        return images

    except Exception as e:
        raise Exception(f"Error converting PDF: {str(e)}") from e
    finally:
        # Release the document even when rendering fails part-way through.
        if pdf_document is not None:
            pdf_document.close()
147
+
148
+
149
def process_pdf_document(pdf_file, model, progress=gr.Progress()):
    """Process an uploaded PDF file page by page.

    Each page is rendered to an image, its layout is parsed with the model,
    the detected elements are recognised, and the result is converted to
    markdown.

    Args:
        pdf_file: Uploaded PDF, forwarded to ``convert_pdf_to_images_gradio``.
        model: Object exposing ``chat(prompt, image)`` and a ``device``
            attribute.
        progress: Gradio progress callback, called with a fraction and a
            ``desc`` message.

    Returns:
        A tuple ``(combined_markdown, page_previews, summary_json)``. On
        failure the first item is an error message, the second is an empty
        list and the third is ``{}`` or ``{"error": message}``.
    """
    if pdf_file is None:
        return "No PDF file uploaded", [], {}

    try:
        # Convert PDF to images
        progress(0.1, desc="Converting PDF to images...")
        images = convert_pdf_to_images_gradio(pdf_file)

        if not images:
            return "Failed to convert PDF to images", [], {}

        total_pages = len(images)
        all_results = []
        page_previews = []

        for page_idx, pil_image in enumerate(images):
            # Report the fraction of pages already finished, so the bar does
            # not claim a page is done before it has been processed.
            progress(0.1 + 0.8 * page_idx / total_pages,
                     desc=f"Processing page {page_idx + 1}/{total_pages}...")

            # Stage 1: Layout parsing
            layout_output = model.chat("Parse the reading order of this document.", pil_image)

            # Stage 2: Element processing with memory optimization
            padded_image, dims = prepare_image(pil_image)
            recognition_results = process_elements_optimized(
                layout_output,
                padded_image,
                dims,
                model,
                max_batch_size=4,  # Smaller batch size for memory efficiency
            )

            # Convert to markdown; fall back to the basic generator if the
            # converter fails. Catch Exception only, so that interrupts and
            # interpreter exit are not swallowed.
            try:
                markdown_converter = MarkdownConverter()
                markdown_content = markdown_converter.convert(recognition_results)
            except Exception as convert_error:
                print(f"Markdown conversion failed on page {page_idx + 1}: {convert_error}")
                markdown_content = generate_fallback_markdown(recognition_results)

            # Store page results
            all_results.append({
                "page_number": page_idx + 1,
                "layout_output": layout_output,
                "elements": recognition_results,
                "markdown": markdown_content,
            })

            # Create page preview with results (markdown capped at 500 chars)
            if len(markdown_content) > 500:
                markdown_preview = markdown_content[:500] + "..."
            else:
                markdown_preview = markdown_content
            page_previews.append({
                "image": pil_image,
                "page_num": page_idx + 1,
                "element_count": len(recognition_results),
                "markdown_preview": markdown_preview,
            })

        progress(1.0, desc="Processing complete!")

        # Combine all markdown
        combined_markdown = "\n\n---\n\n".join(
            f"# Page {result['page_number']}\n\n{result['markdown']}"
            for result in all_results
        )

        # Create summary JSON
        summary_json = {
            "total_pages": total_pages,
            "processing_status": "completed",
            "pages": all_results,
            "model_info": {
                "device": model.device,
                "total_elements": sum(len(page["elements"]) for page in all_results),
            },
        }

        return combined_markdown, page_previews, summary_json

    except Exception as e:
        error_msg = f"Error processing PDF: {str(e)}"
        return error_msg, [], {"error": error_msg}
233
+
234
+
235
def process_elements_optimized(layout_results, padded_image, dims, model, max_batch_size=4):
    """Recognise every layout element of a page, in small batches.

    Figures are embedded directly as base64 PNG data URIs; tables and all
    other elements are cropped and sent to the model in batches.

    Args:
        layout_results: Raw layout string, parsed with ``parse_layout_string``.
        padded_image: Page image array that the element crops are sliced from.
        dims: Dimension info forwarded to ``process_coordinates``.
        model: Object exposing ``chat`` for batched recognition.
        max_batch_size: Upper bound on the number of crops per model call.

    Returns:
        A list of element dicts (``label``, ``text``, ``bbox``,
        ``reading_order``) sorted by ``reading_order``.
    """
    parsed_layout = parse_layout_string(layout_results)

    figure_results = []
    table_elements = []
    text_elements = []
    previous_box = None
    reading_order = 0

    # First pass: crop every element and sort it into figure/table/text.
    for bbox, label in parsed_layout:
        try:
            coords = process_coordinates(bbox, padded_image, dims, previous_box)
            x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = coords

            region = padded_image[y1:y2, x1:x2]
            # Ignore empty or tiny crops (3 pixels or fewer on either axis).
            if region.size > 0 and region.shape[0] > 3 and region.shape[1] > 3:
                crop_image = Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB))
                source_bbox = [orig_x1, orig_y1, orig_x2, orig_y2]

                if label == "fig":
                    # Figures are not recognised; they are inlined as a
                    # base64 data URI so the markdown can display them.
                    crop_image = crop_margin(crop_image)

                    png_buffer = io.BytesIO()
                    crop_image.save(png_buffer, format="PNG")
                    encoded = base64.b64encode(png_buffer.getvalue()).decode()
                    data_uri = f"data:image/png;base64,{encoded}"

                    figure_results.append({
                        "label": label,
                        "text": f"![Figure {reading_order}]({data_uri})",
                        "bbox": source_bbox,
                        "reading_order": reading_order,
                    })
                else:
                    pending = {
                        "crop": crop_image,
                        "label": label,
                        "bbox": source_bbox,
                        "reading_order": reading_order,
                    }
                    target = table_elements if label == "tab" else text_elements
                    target.append(pending)

            # The counter advances for every successfully handled layout
            # entry, including crops that were too small to keep.
            reading_order += 1

        except Exception as e:
            print(f"Error processing element {label}: {str(e)}")
            continue

    # Second pass: run the model over the collected crops in small batches.
    recognition_results = list(figure_results)

    if text_elements:
        recognition_results.extend(
            process_element_batch_optimized(
                text_elements, model, "Read text in the image.", max_batch_size
            )
        )

    if table_elements:
        recognition_results.extend(
            process_element_batch_optimized(
                table_elements, model, "Parse the table in the image.", max_batch_size
            )
        )

    return sorted(recognition_results, key=lambda item: item.get("reading_order", 0))
307
+
308
+
309
def process_element_batch_optimized(elements, model, prompt, max_batch_size=4):
    """Process elements in small batches for memory efficiency.

    Args:
        elements: List of dicts with ``crop``, ``label``, ``bbox`` and
            ``reading_order`` keys.
        model: Object whose ``chat(prompts, crops)`` returns one text per crop.
        prompt: Prompt applied to every crop.
        max_batch_size: Maximum number of crops per ``model.chat`` call.
            Values below 1 are treated as 1.

    Returns:
        A list of dicts with ``label``, ``bbox``, ``text`` (stripped) and
        ``reading_order``, in the same order as ``elements``. An empty input
        yields an empty list.
    """
    if not elements:
        return []

    # Clamp to at least 1: a zero step would make range() raise ValueError.
    batch_size = max(1, min(len(elements), max_batch_size))

    results = []
    for start in range(0, len(elements), batch_size):
        batch_elements = elements[start:start + batch_size]
        crops_list = [elem["crop"] for elem in batch_elements]
        prompts_list = [prompt] * len(crops_list)

        # Process batch
        batch_results = model.chat(prompts_list, crops_list)

        results.extend(
            {
                "label": elem["label"],
                "bbox": elem["bbox"],
                "text": text.strip(),
                "reading_order": elem["reading_order"],
            }
            for elem, text in zip(batch_elements, batch_results)
        )

        # Drop references to the crops before the next batch and release
        # cached GPU memory when running on CUDA.
        del crops_list, batch_elements
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return results
337
+
338
+
339
def generate_fallback_markdown(recognition_results):
    """Generate basic markdown if the converter fails.

    Tables are surrounded by blank lines; paragraphs, titles, section
    headings and figures are emitted followed by a blank line. Elements with
    any other label are skipped.

    This is the last-resort path, so it tolerates malformed elements: a
    missing ``label`` is skipped and a missing ``text`` is treated as empty
    instead of raising ``KeyError``.

    Args:
        recognition_results: Iterable of element dicts with ``label`` and
            ``text`` keys.

    Returns:
        The concatenated markdown string (empty for no elements).
    """
    plain_labels = ("para", "title", "sec", "sub_sec", "fig")

    parts = []
    for element in recognition_results:
        label = element.get("label")
        text = element.get("text", "")
        if label == "tab":
            parts.append(f"\n\n{text}\n\n")
        elif label in plain_labels:
            parts.append(f"{text}\n\n")
    return "".join(parts)
350
+
351
+
352
def create_page_gallery(page_previews):
    """Create an HTML gallery view of processed pages.

    Args:
        page_previews: List of dicts with ``page_num``, ``element_count`` and
            ``markdown_preview`` keys.

    Returns:
        An HTML string with one card per page, or the plain message
        ``"No pages processed yet."`` when there are no previews.
    """
    from html import escape

    if not page_previews:
        return "No pages processed yet."

    cards = []
    for preview in page_previews:
        # The preview text comes from the uploaded document and is truncated,
        # so it may contain unbalanced tags or markup. Escape it so it is
        # shown as text and cannot break or inject into the surrounding HTML.
        page_num = escape(str(preview['page_num']))
        element_count = escape(str(preview['element_count']))
        snippet = escape(str(preview['markdown_preview']))
        cards.append(f"""
        <div style='border: 1px solid #ddd; padding: 15px; border-radius: 8px;'>
            <h3>Page {page_num}</h3>
            <p><strong>Elements found:</strong> {element_count}</p>
            <div style='max-height: 200px; overflow-y: auto; background: #f5f5f5; padding: 10px; border-radius: 4px; font-size: 12px;'>
                {snippet}
            </div>
        </div>
        """)

    return (
        "<div style='display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 20px;'>"
        + "".join(cards)
        + "</div>"
    )
372
+
373
+
374
# Initialize model: prefer a local checkpoint directory, otherwise fall back
# to the Hugging Face Hub model ID.
model_path = "./hf_model"
if not os.path.exists(model_path):
    model_path = "ByteDance/DOLPHIN"

# Loading is best-effort: on failure the app still starts, dolphin_model is
# None, and model_status carries the error text for display in the UI.
try:
    dolphin_model = DOLPHIN(model_path)
    print(f"Model loaded successfully from {model_path}")
    model_status = f"✅ Model loaded: {model_path} (Device: {dolphin_model.device})"
except Exception as e:
    print(f"Error loading model: {e}")
    dolphin_model = None
    model_status = f"❌ Model failed to load: {str(e)}"
387
+
388
+
389
def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
    """Gradio handler: run the full pipeline on an uploaded PDF.

    Args:
        pdf_file: The uploaded PDF, or ``None`` when nothing was uploaded.
        progress: Gradio progress callback forwarded to the pipeline.

    Returns:
        A 4-tuple ``(rendered_markdown, raw_markdown, summary_json,
        gallery_html)``. When the model is missing, no file was given, or
        processing raises, the text slots carry the corresponding message.
    """
    if dolphin_model is None:
        notice = "Model not loaded"
        return notice, notice, {}, notice

    if pdf_file is None:
        notice = "No PDF uploaded"
        return notice, notice, {}, notice

    try:
        # Run the page-by-page pipeline
        pipeline_output = process_pdf_document(pdf_file, dolphin_model, progress)
        combined_markdown, page_previews, summary_json = pipeline_output

        # Build the HTML overview of the processed pages
        gallery_html = create_page_gallery(page_previews)
    except Exception as e:
        error_msg = f"Error processing PDF: {str(e)}"
        return error_msg, error_msg, {"error": error_msg}, error_msg

    # The same markdown feeds both the rendered view and the raw-text view.
    return combined_markdown, combined_markdown, summary_json, gallery_html
411
+
412
+
413
def clear_all():
    """Return the reset value for each of the five cleared UI components.

    Returns:
        A 5-tuple: ``None`` followed by empty string, empty string, empty
        dict and empty string.
    """
    empty_file = None
    empty_text = ""
    empty_json = {}
    return (empty_file, empty_text, empty_text, empty_json, empty_text)
416
+
417
+
418
# Create Gradio interface optimized for HuggingFace Spaces.
# Layout: upload and controls on the left, tabbed results on the right.
with gr.Blocks(
    title="DOLPHIN PDF Document AI",
    theme=gr.themes.Soft(),
    css="""
    .main-container { max-width: 1200px; margin: 0 auto; }
    .status-box { padding: 10px; border-radius: 5px; margin: 10px 0; }
    .success { background-color: #d4edda; border: 1px solid #c3e6cb; }
    .error { background-color: #f8d7da; border: 1px solid #f5c6cb; }
    """
) as demo:
    gr.Markdown("# 🐬 DOLPHIN PDF Document AI")
    gr.Markdown(
        "Upload a PDF document and process it page by page with the DOLPHIN model. "
        "Optimized for HuggingFace Spaces deployment."
    )

    # Model status (computed once at import time when the model is loaded)
    gr.Markdown(f"**Model Status:** {model_status}")

    with gr.Row():
        # Left column: Upload and controls
        with gr.Column(scale=1):
            gr.Markdown("### 📄 Upload PDF Document")
            pdf_input = gr.File(
                file_types=[".pdf"],
                label="Select PDF File",
                height=200
            )

            with gr.Row():
                process_btn = gr.Button("🚀 Process PDF", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")

        # Right column: Results tabs
        with gr.Column(scale=2):
            gr.Markdown("### 📊 Processing Results")

            with gr.Tabs():
                with gr.TabItem("📖 Markdown Output"):
                    markdown_output = gr.Markdown(
                        label="Processed Document",
                        latex_delimiters=[
                            {"left": "$$", "right": "$$", "display": True},
                            {"left": "$", "right": "$", "display": False}
                        ],
                        height=600
                    )

                with gr.TabItem("📝 Raw Markdown"):
                    raw_markdown = gr.Code(
                        label="Raw Markdown Text",
                        language="markdown",
                        lines=25,
                        height=600
                    )

                with gr.TabItem("🔍 Page Gallery"):
                    page_gallery = gr.HTML(
                        label="Page Overview",
                        height=600
                    )

                with gr.TabItem("🔧 JSON Details"):
                    json_output = gr.JSON(
                        label="Processing Details",
                        height=600
                    )

    # Progress bar (hidden placeholder; not wired to any event here)
    progress_bar = gr.HTML(visible=False)

    # Event handlers
    process_btn.click(
        fn=process_uploaded_pdf,
        inputs=[pdf_input],
        outputs=[markdown_output, raw_markdown, json_output, page_gallery],
        show_progress=True
    )

    # clear_all returns five values, one per component listed below.
    clear_btn.click(
        fn=clear_all,
        outputs=[pdf_input, markdown_output, raw_markdown, json_output, page_gallery]
    )

    # Footer
    gr.Markdown(
        "---\n"
        "**Note:** This app is optimized for NVIDIA T4 deployment on HuggingFace Spaces. "
        "Processing time depends on document complexity and page count."
    )
509
 
 
 
510
 
511
  if __name__ == "__main__":
 
512
  demo.launch(
513
  server_name="0.0.0.0",
514
  server_port=7860,
515
  share=False,
516
  show_error=True,
517
+ enable_queue=True, # Enable queue for better performance
518
+ max_threads=2, # Limit threads for memory efficiency
 
519
  inbrowser=False,
520
  show_tips=False,
521
  quiet=True