Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import json | |
import base64 | |
import tempfile | |
import os | |
from typing import Dict, List, Optional, Literal | |
from datetime import datetime | |
from PIL import Image, ImageDraw, ImageFont | |
import io | |
import spaces | |
import shutil | |
from pathlib import Path | |
from htrflow.volume.volume import Collection | |
from htrflow.pipeline.pipeline import Pipeline | |
# HTRflow pipeline definitions, keyed by the document_type values exposed in the UI.
# "letter_*" pipelines run a single YOLO segmentation step, TrOCR recognition and
# OrderLines; "spread_*" pipelines run a region model, then a line model, then TrOCR
# and ReadingOrderMarginalia(two_page=True). The suffix selects the TrOCR checkpoint.
_LINE_SEGMENTATION_MODEL = "Riksarkivet/yolov9-lines-within-regions-1"
_REGION_SEGMENTATION_MODEL = "Riksarkivet/yolov9-regions-1"
_TROCR_ENGLISH = "microsoft/trocr-base-handwritten"
_TROCR_SWEDISH = "Riksarkivet/trocr-base-handwritten-hist-swe-2"


def _yolo_segmentation_step(model_id, batch_size):
    """Return a new YOLO ``Segmentation`` step dict for *model_id*."""
    return {
        "step": "Segmentation",
        "settings": {
            "model": "yolo",
            "model_settings": {"model": model_id},
            "generation_settings": {"batch_size": batch_size},
        },
    }


def _trocr_recognition_step(model_id):
    """Return a new TrOCR ``TextRecognition`` step dict (batch size 16)."""
    return {
        "step": "TextRecognition",
        "settings": {
            "model": "TrOCR",
            "model_settings": {"model": model_id},
            "generation_settings": {"batch_size": 16},
        },
    }


def _letter_pipeline(trocr_model):
    """Single segmentation pass, recognition, then OrderLines."""
    return {
        "steps": [
            _yolo_segmentation_step(_LINE_SEGMENTATION_MODEL, 8),
            _trocr_recognition_step(trocr_model),
            {"step": "OrderLines"},
        ]
    }


def _spread_pipeline(trocr_model):
    """Region segmentation, line segmentation, recognition, two-page reading order."""
    return {
        "steps": [
            _yolo_segmentation_step(_REGION_SEGMENTATION_MODEL, 4),
            _yolo_segmentation_step(_LINE_SEGMENTATION_MODEL, 8),
            _trocr_recognition_step(trocr_model),
            {"step": "ReadingOrderMarginalia", "settings": {"two_page": True}},
        ]
    }


# Each factory call builds fresh dicts, so no two pipelines share mutable sub-objects.
PIPELINE_CONFIGS = {
    "letter_english": _letter_pipeline(_TROCR_ENGLISH),
    "letter_swedish": _letter_pipeline(_TROCR_SWEDISH),
    "spread_english": _spread_pipeline(_TROCR_ENGLISH),
    "spread_swedish": _spread_pipeline(_TROCR_SWEDISH),
}
def process_htr(image: Image.Image, document_type: Literal["letter_english", "letter_swedish", "spread_english", "spread_swedish"] = "spread_swedish", confidence_threshold: float = 0.8, custom_settings: Optional[str] = None) -> Dict:
    """Process handwritten text recognition on uploaded images using HTRflow pipelines.

    Args:
        image: Page image to transcribe.
        document_type: Key of the predefined pipeline to run (letter or spread layout,
            English or Swedish recognition model). Ignored when custom_settings is given.
        confidence_threshold: Minimum line confidence for a line to be included in the
            returned results.
        custom_settings: Optional HTRflow pipeline config, as a JSON object with a
            "steps" list. Overrides document_type when non-blank.

    Returns:
        Dict with "success". On success it also holds "results", "processing_state"
        (a JSON string consumed by visualize_results and export_results) and
        "metadata". On failure it holds an "error" message and "results": None.
    """
    try:
        if image is None:
            return {"success": False, "error": "No image provided", "results": None}

        # Resolve the pipeline before touching the filesystem so bad input fails fast.
        # A blank Gradio textbox arrives as "" (or whitespace) and means "no override".
        if custom_settings and custom_settings.strip():
            try:
                config = json.loads(custom_settings)
            except json.JSONDecodeError:
                return {"success": False, "error": "Invalid JSON in custom_settings parameter", "results": None}
            if not isinstance(config, dict) or not isinstance(config.get("steps"), list):
                return {"success": False, "error": "custom_settings must be a JSON object with a 'steps' list", "results": None}
        elif document_type in PIPELINE_CONFIGS:
            config = PIPELINE_CONFIGS[document_type]
        else:
            return {"success": False, "error": f"Unknown document_type: {document_type!r}", "results": None}

        # Reserve the temp path first and save inside the try/finally below, so the
        # file is removed even when image.save() itself fails.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            temp_image_path = temp_file.name
        try:
            image.save(temp_image_path, "PNG")
            collection = Collection([temp_image_path])
            pipeline = Pipeline.from_config(config)
            processed_collection = pipeline.run(collection)

            img_buffer = io.BytesIO()
            image.save(img_buffer, format="PNG")
            image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

            results = extract_text_results(processed_collection, confidence_threshold)
            processing_state = {
                "collection_data": serialize_collection_data(processed_collection),
                "image_base64": image_base64,
                "image_size": image.size,
                "document_type": document_type,
                # Record the pipeline that actually ran, so exports reproduce custom
                # pipelines instead of falling back to the document_type preset.
                "pipeline_config": config,
                "confidence_threshold": confidence_threshold,
                "timestamp": datetime.now().isoformat(),
            }
            return {
                "success": True,
                "results": results,
                "processing_state": json.dumps(processing_state),
                "metadata": {
                    "total_lines": len(results.get("text_lines", [])),
                    "average_confidence": results.get("average_confidence", 0),
                    "document_type": document_type,
                    "image_dimensions": image.size,
                },
            }
        finally:
            if os.path.exists(temp_image_path):
                os.unlink(temp_image_path)
    except Exception as e:
        # Tool boundary: report failures to the caller instead of raising.
        return {"success": False, "error": f"HTR processing failed: {str(e)}", "results": None}
def visualize_results(processing_state: str, visualization_type: Literal["overlay", "confidence_heatmap", "text_regions"] = "overlay", show_confidence: bool = True, highlight_low_confidence: bool = True, image: Optional[Image.Image] = None) -> Dict:
    """Generate interactive visualizations of HTR processing results.

    Args:
        processing_state: JSON produced by process_htr. Either its "processing_state"
            string or the whole process_htr response is accepted.
        visualization_type: "overlay", "confidence_heatmap" or "text_regions".
        show_confidence: Overlay only; print each line's confidence above its box.
        highlight_low_confidence: Overlay only; draw low-confidence boxes in a
            different colour.
        image: Optional image to draw on instead of the one embedded in
            processing_state.

    Returns:
        Dict with "success". On success it holds "visualization" (a base64 PNG plus
        its format, type and dimensions) and "metadata". On failure it holds "error"
        and "visualization": None.
    """
    try:
        state = json.loads(processing_state)
        # The UI asks users to paste the HTR tool's results. That response wraps the
        # real state as a JSON string under "processing_state", so unwrap it.
        nested_state = state.get("processing_state") if isinstance(state, dict) else None
        if isinstance(nested_state, str):
            state = json.loads(nested_state)
        if not isinstance(state, dict) or "collection_data" not in state:
            return {
                "success": False,
                "error": "Visualization generation failed: processing_state does not contain HTR results",
                "visualization": None,
            }
        collection_data = state["collection_data"]

        if image is not None:
            original_image = image
        else:
            image_data = base64.b64decode(state["image_base64"])
            original_image = Image.open(io.BytesIO(image_data))

        viz_image = create_visualization(original_image, collection_data, visualization_type, show_confidence, highlight_low_confidence)

        img_buffer = io.BytesIO()
        viz_image.save(img_buffer, format="PNG")
        img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
        return {
            "success": True,
            "visualization": {
                "image_base64": img_base64,
                "image_format": "PNG",
                "visualization_type": visualization_type,
                "dimensions": viz_image.size,
            },
            "metadata": {"total_elements": len(collection_data.get("text_elements", []))},
        }
    except Exception as e:
        # Tool boundary: report failures to the caller instead of raising.
        return {"success": False, "error": f"Visualization generation failed: {str(e)}", "visualization": None}
def export_results(processing_state: str, output_formats: List[Literal["txt", "json", "alto", "page"]] = ["txt"], confidence_filter: float = 0.0) -> Dict:
    """Export HTR results to multiple formats using HTRflow's native export functionality.

    The pipeline is re-run on the image stored in processing_state, and the
    resulting collection is serialized once per requested format.

    Args:
        processing_state: JSON produced by process_htr. Either its "processing_state"
            string or the whole process_htr response is accepted.
        output_formats: HTRflow serializers to run ("txt", "json", "alto", "page").
        confidence_filter: Echoed back in export_metadata.
            NOTE(review): it is not applied to the exported documents.

    Returns:
        Dict with "success". On success it holds "exports" (format -> list of
        {"filename", "content"}) and "export_metadata". On failure it holds "error"
        and "exports": None.
    """
    try:
        state = json.loads(processing_state)
        # Accept the whole process_htr response, which nests the state as a JSON string.
        nested_state = state.get("processing_state") if isinstance(state, dict) else None
        if isinstance(nested_state, str):
            state = json.loads(nested_state)

        # Copy so the returned metadata never aliases the shared default list.
        formats = ["txt"] if output_formats is None else list(output_formats)
        # Nothing requested: skip the costly pipeline re-run entirely.
        exports = _export_collection(state, formats) if formats else {}
        return {
            "success": True,
            "exports": exports,
            "export_metadata": {
                "formats_generated": formats,
                "confidence_filter": confidence_filter,
                "timestamp": datetime.now().isoformat(),
            },
        }
    except Exception as e:
        # Tool boundary: report failures to the caller instead of raising.
        return {"success": False, "error": f"Export generation failed: {str(e)}", "exports": None}


def _export_collection(state: Dict, formats: List[str]) -> Dict:
    """Re-run the recorded pipeline on the stored image and serialize it to each format."""
    # Prefer the config process_htr recorded (covers custom pipelines); fall back to
    # the preset for states produced before that field existed.
    config = state.get("pipeline_config") or PIPELINE_CONFIGS[state["document_type"]]
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        temp_image_path = temp_file.name
    try:
        image = Image.open(io.BytesIO(base64.b64decode(state["image_base64"])))
        image.save(temp_image_path, "PNG")
        processed_collection = Pipeline.from_config(config).run(Collection([temp_image_path]))

        exports = {}
        # TemporaryDirectory removes the export tree even if a serializer or a read fails.
        with tempfile.TemporaryDirectory() as temp_dir:
            for fmt in formats:
                export_dir = Path(temp_dir) / fmt
                processed_collection.save(directory=str(export_dir), serializer=fmt)
                export_files = []
                for root, dirs, files in os.walk(export_dir):
                    dirs.sort()  # deterministic traversal order
                    for file in sorted(files):
                        content = (Path(root) / file).read_text(encoding="utf-8")
                        export_files.append({"filename": file, "content": content})
                exports[fmt] = export_files
        return exports
    finally:
        if os.path.exists(temp_image_path):
            os.unlink(temp_image_path)
def extract_text_results(collection: "Collection", confidence_threshold: float) -> Dict:
    """Collect recognised text from every node whose confidence meets a threshold.

    Walks ``page.traverse()`` for every page in ``collection.pages``. A node is kept
    when it has a non-empty ``text`` and a ``confidence`` that is not None and is
    >= ``confidence_threshold``.

    Args:
        collection: Processed HTRflow collection.
        confidence_threshold: Minimum confidence for a node to be kept.

    Returns:
        Dict with:
            "extracted_text": kept texts, each followed by a newline.
            "text_lines": one {"text", "confidence", "bbox"} dict per kept node
                ("bbox" is None when the node has none).
            "confidence_scores": kept confidences, in order.
            "average_confidence": their mean, or 0 when nothing was kept.
    """
    text_lines = []
    for page in collection.pages:
        for node in page.traverse():
            text = getattr(node, "text", None)
            confidence = getattr(node, "confidence", None)
            # A None confidence cannot be compared with the threshold, so such nodes are skipped.
            if text and confidence is not None and confidence >= confidence_threshold:
                text_lines.append({
                    "text": text,
                    "confidence": confidence,
                    "bbox": getattr(node, "bbox", None),
                })
    scores = [line["confidence"] for line in text_lines]
    return {
        "extracted_text": "".join(line["text"] + "\n" for line in text_lines),
        "text_lines": text_lines,
        "confidence_scores": scores,
        "average_confidence": sum(scores) / len(scores) if scores else 0,
    }
def serialize_collection_data(collection: Collection) -> Dict:
    """Flatten a processed collection into a JSON-friendly list of text elements.

    Every node yielded by ``page.traverse()`` that has a non-empty ``text``
    contributes one dict with its text, its ``confidence`` (1.0 when the node has
    no such attribute) and its ``bbox`` (None when absent).

    Returns:
        ``{"text_elements": [...]}`` in traversal order.
    """
    def _as_element(node) -> Dict:
        # Build the per-node record with the same key order as before.
        return {
            "text": node.text,
            "confidence": getattr(node, "confidence", 1.0),
            "bbox": getattr(node, "bbox", None),
        }

    elements = [
        _as_element(node)
        for page in collection.pages
        for node in page.traverse()
        if hasattr(node, "text") and node.text
    ]
    return {"text_elements": elements}
def create_visualization(image, collection_data, visualization_type, show_confidence, highlight_low_confidence):
    """Render HTR results onto a new image; ``image`` itself is not modified.

    Args:
        image: PIL image the results belong to.
        collection_data: Dict whose optional "text_elements" list holds dicts with:
            "bbox": passed straight to ``ImageDraw.rectangle``; elements without
                one are skipped.
            "confidence": treated as 1.0 when missing or None.
        visualization_type: One of:
            "overlay": outlined boxes.
            "confidence_heatmap": translucent fills, red (<0.5), yellow (<0.8)
                or green.
            "text_regions": boxes outlined in a repeating four-colour cycle.
            Any other value returns an unannotated copy.
        show_confidence: Overlay only; write each score above its box.
        highlight_low_confidence: Overlay only; outline scores below 0.7 in orange
            instead of green.

    Returns:
        A new PIL image. It is RGB for the heatmap. Otherwise it keeps an RGB or
        RGBA input mode, and other modes are converted to RGB so colour tuples can
        be drawn.
    """
    try:
        font = ImageFont.truetype("arial.ttf", 12)
    except OSError:
        # Arial is often absent (e.g. on Linux hosts); fall back to Pillow's bundled font.
        font = ImageFont.load_default()

    if visualization_type == "confidence_heatmap":
        viz_image = image.convert("RGB")
        # An "RGBA" draw on an RGB image blends translucent fills into it, so each box
        # costs only its own area instead of a full-size overlay + alpha_composite.
        draw = ImageDraw.Draw(viz_image, "RGBA")
    else:
        # Grayscale/palette images reject RGB colour tuples, so normalise those to RGB.
        viz_image = image.copy() if image.mode in ("RGB", "RGBA") else image.convert("RGB")
        draw = ImageDraw.Draw(viz_image)

    region_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
    elements = [element for element in collection_data.get("text_elements", []) if element.get("bbox")]
    for index, element in enumerate(elements):
        bbox = element["bbox"]
        confidence = element.get("confidence")
        if confidence is None:
            # Same default serialize_collection_data uses for nodes without a score.
            confidence = 1.0

        if visualization_type == "overlay":
            is_low = highlight_low_confidence and confidence < 0.7
            color = (255, 165, 0) if is_low else (0, 255, 0)
            draw.rectangle(bbox, outline=color, width=2)
            if show_confidence:
                draw.text((bbox[0], bbox[1] - 15), f"{confidence:.2f}", fill=color, font=font)
        elif visualization_type == "confidence_heatmap":
            if confidence < 0.5:
                fill = (255, 0, 0, 100)
            elif confidence < 0.8:
                fill = (255, 255, 0, 100)
            else:
                fill = (0, 255, 0, 100)
            draw.rectangle(bbox, fill=fill)
        elif visualization_type == "text_regions":
            # Cycle colours by position: hash(str(...)) is salted per process
            # (PYTHONHASHSEED), which made colours change between runs.
            draw.rectangle(bbox, outline=region_colors[index % len(region_colors)], width=3)
    return viz_image
def create_htrflow_mcp_server():
    """Build the tabbed Gradio app that exposes the three HTR tools.

    Each tab is a ``gr.Interface`` around one function (process_htr,
    visualize_results, export_results), registered under an ``api_name`` equal to
    the function name.

    Returns:
        The assembled ``gr.TabbedInterface``.
    """
    state_hint = "Paste processing results from HTR tool"

    # NOTE(review): the dropdown defaults to "letter_english", while process_htr's own
    # default is "spread_swedish"; confirm which default is intended.
    htr_tab = gr.Interface(
        fn=process_htr,
        inputs=[
            gr.Image(type="pil", label="Upload Image"),
            gr.Dropdown(
                choices=["letter_english", "letter_swedish", "spread_english", "spread_swedish"],
                value="letter_english",
                label="Document Type",
            ),
            gr.Slider(0.0, 1.0, value=0.8, label="Confidence Threshold"),
            gr.Textbox(label="Custom Settings (JSON)", placeholder="Optional custom pipeline settings"),
        ],
        outputs=gr.JSON(label="Processing Results"),
        title="HTR Processing Tool",
        description="Process handwritten text using configurable HTRflow pipelines",
        api_name="process_htr",
    )

    visualization_tab = gr.Interface(
        fn=visualize_results,
        inputs=[
            gr.Textbox(label="Processing State (JSON)", placeholder=state_hint),
            gr.Dropdown(
                choices=["overlay", "confidence_heatmap", "text_regions"],
                value="overlay",
                label="Visualization Type",
            ),
            gr.Checkbox(value=True, label="Show Confidence Scores"),
            gr.Checkbox(value=True, label="Highlight Low Confidence"),
            gr.Image(type="pil", label="Image (optional)"),
        ],
        outputs=gr.JSON(label="Visualization Results"),
        title="Results Visualization Tool",
        description="Generate interactive visualizations of HTR results",
        api_name="visualize_results",
    )

    export_tab = gr.Interface(
        fn=export_results,
        inputs=[
            gr.Textbox(label="Processing State (JSON)", placeholder=state_hint),
            gr.CheckboxGroup(choices=["txt", "json", "alto", "page"], value=["txt"], label="Output Formats"),
            gr.Slider(0.0, 1.0, value=0.0, label="Confidence Filter"),
        ],
        outputs=gr.JSON(label="Export Results"),
        title="Export Tool",
        description="Export HTR results to multiple formats",
        api_name="export_results",
    )

    tabs = [htr_tab, visualization_tab, export_tab]
    tab_names = ["HTR Processing", "Results Visualization", "Export Results"]
    return gr.TabbedInterface(tabs, tab_names, title="HTRflow MCP Server")
if __name__ == "__main__":
    # Script entry point: build the tabbed app, then launch it. mcp_server=True asks
    # Gradio to also expose the tool functions over the Model Context Protocol.
    demo = create_htrflow_mcp_server()
    demo.launch(
        mcp_server=True,
    )