""" Gradio interface for DOLPHIN model """ import gradio as gr import json import markdown from markdown.extensions import codehilite import cv2 import numpy as np from PIL import Image from transformers import AutoProcessor, VisionEncoderDecoderModel import torch import os from utils.utils import * from utils.markdown_utils import MarkdownConverter try: from mdx_math import MathExtension MATH_EXTENSION_AVAILABLE = True except ImportError: MATH_EXTENSION_AVAILABLE = False class DOLPHIN: def __init__(self, model_id_or_path): """Initialize the Hugging Face model Args: model_id_or_path: Path to local model or Hugging Face model ID """ self.processor = AutoProcessor.from_pretrained(model_id_or_path) self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path) self.model.eval() self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model.to(self.device) if self.device == "cuda": self.model = self.model.half() self.tokenizer = self.processor.tokenizer def chat(self, prompt, image): """Process an image or batch of images with the given prompt(s) Args: prompt: Text prompt or list of prompts to guide the model image: PIL Image or list of PIL Images to process Returns: Generated text or list of texts from the model """ # Check if we're dealing with a batch is_batch = isinstance(image, list) if not is_batch: # Single image, wrap it in a list for consistent processing images = [image] prompts = [prompt] else: # Batch of images images = image prompts = prompt if isinstance(prompt, list) else [prompt] * len(images) # Prepare image batch_inputs = self.processor(images, return_tensors="pt", padding=True) batch_pixel_values = batch_inputs.pixel_values if self.device == "cuda": batch_pixel_values = batch_pixel_values.half() batch_pixel_values = batch_pixel_values.to(self.device) # Prepare prompt prompts = [f"{p} " for p in prompts] batch_prompt_inputs = self.tokenizer( prompts, add_special_tokens=False, return_tensors="pt" ) batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device) batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device) # Generate text outputs = self.model.generate( pixel_values=batch_pixel_values, decoder_input_ids=batch_prompt_ids, decoder_attention_mask=batch_attention_mask, min_length=1, max_length=4096, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id, use_cache=True, bad_words_ids=[[self.tokenizer.unk_token_id]], return_dict_in_generate=True, do_sample=False, num_beams=1, repetition_penalty=1.1, temperature=1.0 ) # Process output sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False) # Clean prompt text from output results = [] for i, sequence in enumerate(sequences): cleaned = sequence.replace(prompts[i], "").replace("", "").replace("", "").strip() results.append(cleaned) # Return a single result for single image input if not is_batch: return results[0] return results def render_markdown_with_math(markdown_content): """Convert markdown to HTML with MathJax support that works in Gradio""" import re # Convert basic markdown to HTML first html_content = markdown.markdown(markdown_content) # Create a complete HTML document with MathJax html_with_math = f"""
def render_markdown_with_math(markdown_content):
    """Convert markdown to HTML with MathJax support that works in Gradio"""
    # Convert basic markdown to HTML first
    html_content = markdown.markdown(markdown_content)

    # Create a complete HTML document with MathJax (the delimiters mirror the
    # ones configured for gr.Markdown below)
    html_with_math = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<script>
window.MathJax = {{
  tex: {{
    inlineMath: [["$", "$"], ["\\\\(", "\\\\)"]],
    displayMath: [["$$", "$$"], ["\\\\[", "\\\\]"]]
  }}
}};
</script>
<script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js" async></script>
</head>
<body>
{html_content}
</body>
</html>"""
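
# Note: the Blocks UI below renders math via gr.Markdown(latex_delimiters=...),
# so this helper is only needed as a standalone fallback when embedding the
# parsed output outside Gradio (e.g., in a gr.HTML component).
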
""" return html_with_math def process_elements(layout_results, padded_image, dims, model, max_batch_size=16, save_dir=None, image_name="gradio_session"): """Parse all document elements with parallel decoding""" layout_results = parse_layout_string(layout_results) # Store text and table elements separately text_elements = [] # Text elements table_elements = [] # Table elements figure_results = [] # Image elements (saved as files) previous_box = None reading_order = 0 # Setup output directories if save_dir is provided if save_dir: setup_output_dirs(save_dir) # Collect elements to process and group by type for bbox, label in layout_results: try: # Adjust coordinates x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates( bbox, padded_image, dims, previous_box ) # Crop and parse element cropped = padded_image[y1:y2, x1:x2] if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3: if label == "fig": # Convert cropped OpenCV image to PIL pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) # Apply margin cropping to remove white space around the figure pil_crop = crop_margin(pil_crop) # Convert to base64 for Gradio display (works better than file paths) import base64 import io buffered = io.BytesIO() pil_crop.save(buffered, format="PNG") img_base64 = base64.b64encode(buffered.getvalue()).decode() # Create data URI for direct embedding in markdown data_uri = f"data:image/png;base64,{img_base64}" figure_results.append( { "label": label, "text": data_uri, # Pass base64 directly to _handle_figure "figure_base64": data_uri, "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], "reading_order": reading_order, } ) else: # Prepare element for parsing pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) element_info = { "crop": pil_crop, "label": label, "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], "reading_order": reading_order, } # Group by type if label == "tab": table_elements.append(element_info) else: # Text elements text_elements.append(element_info) reading_order += 1 except Exception as e: print(f"Error processing bbox with label {label}: {str(e)}") continue # Initialize results list recognition_results = figure_results.copy() # Process text elements (in batches) if text_elements: text_results = process_element_batch(text_elements, model, "Read text in the image.", max_batch_size) recognition_results.extend(text_results) # Process table elements (in batches) if table_elements: table_results = process_element_batch(table_elements, model, "Parse the table in the image.", max_batch_size) recognition_results.extend(table_results) # Sort elements by reading order recognition_results.sort(key=lambda x: x.get("reading_order", 0)) return recognition_results def process_element_batch(elements, model, prompt, max_batch_size=16): """Process elements of the same type in batches""" results = [] # Determine batch size batch_size = len(elements) if max_batch_size is not None and max_batch_size > 0: batch_size = min(batch_size, max_batch_size) # Process in batches for i in range(0, len(elements), batch_size): batch_elements = elements[i:i+batch_size] crops_list = [elem["crop"] for elem in batch_elements] # Use the same prompt for all elements in the batch prompts_list = [prompt] * len(crops_list) # Batch inference batch_results = model.chat(prompts_list, crops_list) # Add results for j, result in enumerate(batch_results): elem = batch_elements[j] results.append({ "label": elem["label"], "bbox": elem["bbox"], "text": result.strip(), 
"reading_order": elem["reading_order"], }) return results # Initialize model model_path = "./hf_model" if not os.path.exists(model_path): model_path = "ByteDance/DOLPHIN" try: dolphin_model = DOLPHIN(model_path) print(f"Model loaded successfully from {model_path}") except Exception as e: print(f"Error loading model: {e}") dolphin_model = None def process_image(image, task_type): """Process uploaded image and return results in different formats""" if dolphin_model is None: return None, "Model not loaded", "Model not loaded", {"error": "Model not loaded"} if image is None: return None, "No image uploaded", "No image uploaded", {"error": "No image uploaded"} try: # Convert to PIL Image if needed if hasattr(image, 'convert'): pil_image = image.convert("RGB") else: pil_image = Image.fromarray(image).convert("RGB") if task_type == "Document Parsing": # Full document processing with two stages # Stage 1: Page-level layout and reading order parsing layout_output = dolphin_model.chat("Parse the reading order of this document.", pil_image) # Stage 2: Element-level content parsing with image extraction import tempfile import uuid # Create temporary directory for saving figures temp_dir = tempfile.mkdtemp() session_id = str(uuid.uuid4())[:8] padded_image, dims = prepare_image(pil_image) recognition_results = process_elements( layout_output, padded_image, dims, dolphin_model, max_batch_size=16, save_dir=temp_dir, image_name=f"session_{session_id}" ) # Convert to markdown try: markdown_converter = MarkdownConverter() markdown_content = markdown_converter.convert(recognition_results) except: # Fallback if markdown converter fails markdown_content = "" for element in recognition_results: if element["label"] == "tab": markdown_content += f"\n\n{element['text']}\n\n" elif element["label"] in ["para", "title", "sec", "sub_sec"]: markdown_content += f"{element['text']}\n\n" elif element["label"] == "fig": markdown_content += f"{element['text']}\n\n" # Create structured JSON output json_output = { "task_type": task_type, "layout_parsing": layout_output, "recognition_results": recognition_results, "model_info": { "device": dolphin_model.device, "model_path": model_path }, "temp_dir": temp_dir } # Return markdown content directly for Gradio's built-in LaTeX support return pil_image, markdown_content, markdown_content, json_output else: # Simple element-level processing for other tasks if task_type == "Table Extraction": prompt = "Parse the table in the image." elif task_type == "Text Reading": prompt = "Read text in the image." elif task_type == "Formula Recognition": prompt = "Read text in the image." else: prompt = "Read text in the image." 
def clear_all():
    """Clear all inputs and outputs"""
    return None, None, "", "", {}


# Create Gradio interface
with gr.Blocks(title="DOLPHIN Document AI", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🐬 DOLPHIN Document AI Interface")
    gr.Markdown("Upload an image and select a task to process with the DOLPHIN model")

    with gr.Row():
        # Column 1: Image Upload
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Upload Image")
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                height=600
            )

            task_type = gr.Dropdown(
                choices=["Document Parsing", "Table Extraction", "Text Reading", "Formula Recognition"],
                value="Document Parsing",
                label="Task Type"
            )

            with gr.Row():
                submit_btn = gr.Button("🚀 Submit", variant="primary")
                cancel_btn = gr.Button("❌ Clear", variant="secondary")

        # Column 2: Image Preview
        with gr.Column(scale=1):
            gr.Markdown("### 👁️ Image Preview")
            image_preview = gr.Image(
                type="pil",
                label="Uploaded Image",
                interactive=False,
                height=600
            )

        # Column 3: Results with Tabs
        with gr.Column(scale=1):
            gr.Markdown("### 📋 Results")
            with gr.Tabs():
                with gr.TabItem("📖 Markdown Preview"):
                    markdown_preview = gr.Markdown(
                        label="Rendered Markdown",
                        latex_delimiters=[
                            {"left": "$$", "right": "$$", "display": True},
                            {"left": "$", "right": "$", "display": False},
                            {"left": "\\(", "right": "\\)", "display": False},
                            {"left": "\\[", "right": "\\]", "display": True}
                        ],
                        container=True,
                        height=600
                    )

                with gr.TabItem("📝 Raw Markdown"):
                    raw_markdown = gr.Code(
                        label="Raw Markdown Text",
                        language="markdown",
                        container=True,
                        interactive=False,
                        lines=25
                    )

                with gr.TabItem("🔧 JSON"):
                    json_output = gr.JSON(
                        label="JSON Output",
                        height=600
                    )

    # Event handlers
    submit_btn.click(
        fn=process_image,
        inputs=[image_input, task_type],
        outputs=[image_preview, markdown_preview, raw_markdown, json_output]
    )

    cancel_btn.click(
        fn=clear_all,
        outputs=[image_input, image_preview, markdown_preview, raw_markdown, json_output]
    )

    # Auto-update preview when image is uploaded
    image_input.change(
        fn=lambda img: img if img is not None else None,
        inputs=[image_input],
        outputs=[image_preview]
    )


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )