"""
Gradio interface for DOLPHIN model
"""
import base64
import io
import os
import tempfile
import uuid

import cv2
import gradio as gr
import markdown
import torch
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel

from utils.utils import (
    crop_margin,
    parse_layout_string,
    prepare_image,
    process_coordinates,
    setup_output_dirs,
)
from utils.markdown_utils import MarkdownConverter
# Optional MathJax extension for python-markdown; the app degrades gracefully
# when mdx_math is not installed.
try:
    from mdx_math import MathExtension
    MATH_EXTENSION_AVAILABLE = True
except ImportError:
    MATH_EXTENSION_AVAILABLE = False


class DOLPHIN:
    def __init__(self, model_id_or_path):
        """Initialize the Hugging Face model.

        Args:
            model_id_or_path: Path to a local model directory or a Hugging Face model ID
        """
        self.processor = AutoProcessor.from_pretrained(model_id_or_path)
        self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path)
        self.model.eval()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        if self.device == "cuda":
            # Half precision roughly halves GPU memory use; CPU inference stays in fp32
            self.model = self.model.half()

        self.tokenizer = self.processor.tokenizer
def chat(self, prompt, image):
"""Process an image or batch of images with the given prompt(s)
Args:
prompt: Text prompt or list of prompts to guide the model
image: PIL Image or list of PIL Images to process
Returns:
Generated text or list of texts from the model
"""
# Check if we're dealing with a batch
is_batch = isinstance(image, list)
if not is_batch:
# Single image, wrap it in a list for consistent processing
images = [image]
prompts = [prompt]
else:
# Batch of images
images = image
prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
# Prepare image
batch_inputs = self.processor(images, return_tensors="pt", padding=True)
batch_pixel_values = batch_inputs.pixel_values
if self.device == "cuda":
batch_pixel_values = batch_pixel_values.half()
batch_pixel_values = batch_pixel_values.to(self.device)
# Prepare prompt
prompts = [f"<s>{p} <Answer/>" for p in prompts]
batch_prompt_inputs = self.tokenizer(
prompts,
add_special_tokens=False,
return_tensors="pt"
)
batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device)
batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device)
# Generate text
outputs = self.model.generate(
pixel_values=batch_pixel_values,
decoder_input_ids=batch_prompt_ids,
decoder_attention_mask=batch_attention_mask,
min_length=1,
max_length=4096,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
use_cache=True,
bad_words_ids=[[self.tokenizer.unk_token_id]],
return_dict_in_generate=True,
do_sample=False,
num_beams=1,
repetition_penalty=1.1,
temperature=1.0
)
# Process output
sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
# Clean prompt text from output
results = []
for i, sequence in enumerate(sequences):
cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
results.append(cleaned)
# Return a single result for single image input
if not is_batch:
return results[0]
return results
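
# Illustrative usage of DOLPHIN.chat (a sketch; the path and crop variables are
# placeholders, not part of this app):
#
#   model = DOLPHIN("ByteDance/DOLPHIN")
#   page = Image.open("page.png").convert("RGB")
#   layout = model.chat("Parse the reading order of this document.", page)  # single image
#   texts = model.chat("Read text in the image.", [crop_a, crop_b])         # batched crops
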
def render_markdown_with_math(markdown_content):
    """Convert markdown to a standalone HTML page with MathJax rendering.

    Note: the interface below currently passes raw markdown to gr.Markdown
    (which has built-in LaTeX support), so this helper is kept as an
    alternative rendering path rather than being wired in.
    """
    # Convert basic markdown to HTML first
    html_content = markdown.markdown(markdown_content)
# Create a complete HTML document with MathJax
html_with_math = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
body {{
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
line-height: 1.6;
color: #333;
max-width: 100%;
margin: 0;
padding: 20px;
}}
.math-container {{
margin: 15px 0;
}}
.display-math {{
text-align: center;
margin: 20px 0;
}}
.inline-math {{
display: inline;
}}
table {{
border-collapse: collapse;
width: 100%;
margin: 15px 0;
}}
th, td {{
border: 1px solid #ddd;
padding: 8px;
text-align: left;
}}
th {{
background-color: #f2f2f2;
}}
pre {{
background-color: #f5f5f5;
padding: 10px;
border-radius: 4px;
overflow-x: auto;
}}
code {{
background-color: #f5f5f5;
padding: 2px 4px;
border-radius: 3px;
font-family: 'Courier New', monospace;
}}
</style>
<script>
window.MathJax = {{
tex: {{
inlineMath: [['$', '$'], ['\\\\(', '\\\\)']],
displayMath: [['$$', '$$'], ['\\\\[', '\\\\]']],
processEscapes: true,
processEnvironments: true,
tags: 'ams',
autoload: {{
color: [],
colorv2: ['color']
}},
packages: {{'[+]': ['noerrors']}}
}},
options: {{
ignoreHtmlClass: 'tex2jax_ignore',
processHtmlClass: 'tex2jax_process'
}},
loader: {{
load: ['[tex]/noerrors']
}}
}};
// Function to trigger MathJax processing after content loads
function processMath() {{
if (window.MathJax && window.MathJax.typesetPromise) {{
window.MathJax.typesetPromise().catch(function (err) {{
console.log('MathJax typeset failed: ' + err.message);
}});
}}
}}
// Process math when page loads
document.addEventListener('DOMContentLoaded', function() {{
setTimeout(processMath, 100);
}});
// Also process when MathJax loads
window.addEventListener('load', function() {{
setTimeout(processMath, 200);
}});
</script>
<script type="text/javascript" id="MathJax-script" async
src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"
onload="processMath()">
</script>
</head>
<body>
<div class="tex2jax_process">
{html_content}
</div>
<script>
// Additional processing trigger
setTimeout(function() {{
if (window.MathJax && window.MathJax.typesetPromise) {{
window.MathJax.typesetPromise();
}}
}}, 500);
</script>
</body>
</html>
"""
return html_with_math
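
# Illustrative wiring for the helper above (a sketch, not used by this app):
# render the HTML into a gr.HTML component instead of gr.Markdown, e.g.
#
#   html_view = gr.HTML(label="Rendered HTML")
#   submit_btn.click(fn=render_markdown_with_math,
#                    inputs=[raw_markdown], outputs=[html_view])
#
# Caveat: script execution inside HTML injected into the page can be
# restricted, so MathJax may not run in every Gradio version.
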
def process_elements(layout_results, padded_image, dims, model, max_batch_size=16, save_dir=None, image_name="gradio_session"):
    """Parse all document elements, decoding same-type elements in batches.

    image_name is accepted for interface compatibility but is currently unused.
    """
    layout_results = parse_layout_string(layout_results)
# Store text and table elements separately
text_elements = [] # Text elements
table_elements = [] # Table elements
figure_results = [] # Image elements (saved as files)
previous_box = None
reading_order = 0
# Setup output directories if save_dir is provided
if save_dir:
setup_output_dirs(save_dir)
# Collect elements to process and group by type
for bbox, label in layout_results:
try:
# Adjust coordinates
x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
bbox, padded_image, dims, previous_box
)
# Crop and parse element
cropped = padded_image[y1:y2, x1:x2]
if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
if label == "fig":
# Convert cropped OpenCV image to PIL
pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
# Apply margin cropping to remove white space around the figure
pil_crop = crop_margin(pil_crop)
                    # Encode as base64 so the figure can be embedded directly in
                    # the markdown output (more reliable in Gradio than file paths)
                    buffered = io.BytesIO()
                    pil_crop.save(buffered, format="PNG")
                    img_base64 = base64.b64encode(buffered.getvalue()).decode()
                    data_uri = f"data:image/png;base64,{img_base64}"
figure_results.append(
{
"label": label,
"text": data_uri, # Pass base64 directly to _handle_figure
"figure_base64": data_uri,
"bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
"reading_order": reading_order,
}
)
else:
# Prepare element for parsing
pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
element_info = {
"crop": pil_crop,
"label": label,
"bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
"reading_order": reading_order,
}
# Group by type
if label == "tab":
table_elements.append(element_info)
else: # Text elements
text_elements.append(element_info)
reading_order += 1
except Exception as e:
print(f"Error processing bbox with label {label}: {str(e)}")
continue
# Initialize results list
recognition_results = figure_results.copy()
# Process text elements (in batches)
if text_elements:
text_results = process_element_batch(text_elements, model, "Read text in the image.", max_batch_size)
recognition_results.extend(text_results)
# Process table elements (in batches)
if table_elements:
table_results = process_element_batch(table_elements, model, "Parse the table in the image.", max_batch_size)
recognition_results.extend(table_results)
# Sort elements by reading order
recognition_results.sort(key=lambda x: x.get("reading_order", 0))
return recognition_results
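
# Shape of the list returned by process_elements (illustrative values only):
#
#   [{"label": "title", "bbox": [x1, y1, x2, y2], "text": "1 Introduction", "reading_order": 0},
#    {"label": "fig",   "bbox": [...], "text": "data:image/png;base64,...", "reading_order": 1},
#    {"label": "tab",   "bbox": [...], "text": "...table markup...", "reading_order": 2}]
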
def process_element_batch(elements, model, prompt, max_batch_size=16):
"""Process elements of the same type in batches"""
results = []
# Determine batch size
batch_size = len(elements)
if max_batch_size is not None and max_batch_size > 0:
batch_size = min(batch_size, max_batch_size)
# Process in batches
for i in range(0, len(elements), batch_size):
batch_elements = elements[i:i+batch_size]
crops_list = [elem["crop"] for elem in batch_elements]
# Use the same prompt for all elements in the batch
prompts_list = [prompt] * len(crops_list)
# Batch inference
batch_results = model.chat(prompts_list, crops_list)
# Add results
for j, result in enumerate(batch_results):
elem = batch_elements[j]
results.append({
"label": elem["label"],
"bbox": elem["bbox"],
"text": result.strip(),
"reading_order": elem["reading_order"],
})
return results
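
# Batching example: 40 text elements with max_batch_size=16 are decoded in three
# model.chat calls covering 16, 16, and 8 crops, all sharing the same prompt.
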
# Initialize model: prefer a local checkpoint, fall back to the Hub model ID
model_path = "./hf_model"
if not os.path.exists(model_path):
    model_path = "ByteDance/DOLPHIN"
try:
dolphin_model = DOLPHIN(model_path)
print(f"Model loaded successfully from {model_path}")
except Exception as e:
print(f"Error loading model: {e}")
dolphin_model = None
def process_image(image, task_type):
"""Process uploaded image and return results in different formats"""
if dolphin_model is None:
return None, "Model not loaded", "Model not loaded", {"error": "Model not loaded"}
if image is None:
return None, "No image uploaded", "No image uploaded", {"error": "No image uploaded"}
try:
# Convert to PIL Image if needed
if hasattr(image, 'convert'):
pil_image = image.convert("RGB")
else:
pil_image = Image.fromarray(image).convert("RGB")
if task_type == "Document Parsing":
# Full document processing with two stages
# Stage 1: Page-level layout and reading order parsing
layout_output = dolphin_model.chat("Parse the reading order of this document.", pil_image)
            # Stage 2: Element-level content parsing with figure extraction
            # Temporary directory for any intermediate figure files
            temp_dir = tempfile.mkdtemp()
            session_id = str(uuid.uuid4())[:8]
padded_image, dims = prepare_image(pil_image)
recognition_results = process_elements(
layout_output,
padded_image,
dims,
dolphin_model,
max_batch_size=16,
save_dir=temp_dir,
image_name=f"session_{session_id}"
)
            # Convert to markdown
            try:
                markdown_converter = MarkdownConverter()
                markdown_content = markdown_converter.convert(recognition_results)
            except Exception:
                # Fallback if the markdown converter fails: concatenate elements
                # in reading order, embedding figures via their data URIs
                markdown_content = ""
                for element in recognition_results:
                    if element["label"] == "tab":
                        markdown_content += f"\n\n{element['text']}\n\n"
                    elif element["label"] in ["para", "title", "sec", "sub_sec"]:
                        markdown_content += f"{element['text']}\n\n"
                    elif element["label"] == "fig":
                        markdown_content += f"![figure]({element['text']})\n\n"
# Create structured JSON output
json_output = {
"task_type": task_type,
"layout_parsing": layout_output,
"recognition_results": recognition_results,
"model_info": {
"device": dolphin_model.device,
"model_path": model_path
},
"temp_dir": temp_dir
}
# Return markdown content directly for Gradio's built-in LaTeX support
return pil_image, markdown_content, markdown_content, json_output
else:
            # Simple element-level processing for other tasks
            if task_type == "Table Extraction":
                prompt = "Parse the table in the image."
            else:
                # The same generic text-reading prompt covers both plain text
                # and formula recognition
                prompt = "Read text in the image."
# Process with model
result = dolphin_model.chat(prompt, pil_image)
# Create JSON output
json_output = {
"task_type": task_type,
"prompt": prompt,
"result": result,
"model_info": {
"device": dolphin_model.device,
"model_path": model_path
}
}
return pil_image, result, result, json_output
except Exception as e:
error_msg = f"Error processing image: {str(e)}"
return None, error_msg, error_msg, {"error": error_msg}
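
# process_image returns a 4-tuple matching the Gradio outputs wired up below:
# (preview image, markdown for the rendered tab, markdown for the raw tab, JSON dict).
# For example (a sketch; "doc.png" is a placeholder path):
#
#   img, md, raw, meta = process_image(Image.open("doc.png"), "Document Parsing")
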
def clear_all():
    """Clear all inputs and outputs (return order matches the outputs= wiring below)"""
    return None, None, "", "", {}
# Create Gradio interface
with gr.Blocks(title="DOLPHIN Document AI", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🐬 DOLPHIN Document AI Interface")
gr.Markdown("Upload an image and select a task to process with the DOLPHIN model")
with gr.Row():
# Column 1: Image Upload
with gr.Column(scale=1):
gr.Markdown("### πŸ“€ Upload Image")
image_input = gr.Image(
type="pil",
label="Upload Image",
height=600
)
task_type = gr.Dropdown(
choices=["Document Parsing", "Table Extraction", "Text Reading", "Formula Recognition"],
value="Document Parsing",
label="Task Type"
)
with gr.Row():
submit_btn = gr.Button("πŸš€ Submit", variant="primary")
cancel_btn = gr.Button("❌ Clear", variant="secondary")
# Column 2: Image Preview
with gr.Column(scale=1):
gr.Markdown("### πŸ‘οΈ Image Preview")
image_preview = gr.Image(
type="pil",
label="Uploaded Image",
interactive=False,
height=600
)
# Column 3: Results with Tabs
with gr.Column(scale=1):
gr.Markdown("### πŸ“‹ Results")
with gr.Tabs():
with gr.TabItem("πŸ“– Markdown Preview"):
markdown_preview = gr.Markdown(
label="Rendered Markdown",
latex_delimiters=[
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
{"left": "\\(", "right": "\\)", "display": False},
{"left": "\\[", "right": "\\]", "display": True}
],
container=True,
height=600
)
with gr.TabItem("πŸ“ Raw Markdown"):
raw_markdown = gr.Code(
label="Raw Markdown Text",
language="markdown",
container=True,
interactive=False,
lines=25
)
with gr.TabItem("πŸ”§ JSON"):
json_output = gr.JSON(
label="JSON Output",
height=600
)
# Event handlers
submit_btn.click(
fn=process_image,
inputs=[image_input, task_type],
outputs=[image_preview, markdown_preview, raw_markdown, json_output]
)
cancel_btn.click(
fn=clear_all,
outputs=[image_input, image_preview, markdown_preview, raw_markdown, json_output]
)
    # Mirror the uploaded image into the preview pane
    image_input.change(
        fn=lambda img: img,
        inputs=[image_input],
        outputs=[image_preview]
    )
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_error=True
)