# htrflow_mcp / app.py
# Last change by Gabriel: commit d6e55c9 "Update app.py" (verified), 16.4 kB.
import gradio as gr
import json
import base64
import tempfile
import os
from typing import Dict, List, Optional, Literal
from datetime import datetime
from PIL import Image, ImageDraw, ImageFont
import io
import spaces
import shutil
from pathlib import Path
from htrflow.volume.volume import Collection
from htrflow.pipeline.pipeline import Pipeline
# Hugging Face model ids shared by the built-in pipelines.
_REGION_SEGMENTATION_MODEL = "Riksarkivet/yolov9-regions-1"
_LINE_SEGMENTATION_MODEL = "Riksarkivet/yolov9-lines-within-regions-1"
_ENGLISH_RECOGNITION_MODEL = "microsoft/trocr-base-handwritten"
_SWEDISH_RECOGNITION_MODEL = "Riksarkivet/trocr-base-handwritten-hist-swe-2"


def _yolo_segmentation_step(model_id, batch_size):
    """Return a new YOLO ``Segmentation`` step config for *model_id*."""
    return {
        "step": "Segmentation",
        "settings": {
            "model": "yolo",
            "model_settings": {"model": model_id},
            "generation_settings": {"batch_size": batch_size},
        },
    }


def _trocr_recognition_step(model_id):
    """Return a new TrOCR ``TextRecognition`` step config (batch size 16) for *model_id*."""
    return {
        "step": "TextRecognition",
        "settings": {
            "model": "TrOCR",
            "model_settings": {"model": model_id},
            "generation_settings": {"batch_size": 16},
        },
    }


def _letter_pipeline(recognition_model):
    """Letter layout: line segmentation, text recognition, then ``OrderLines``."""
    return {
        "steps": [
            _yolo_segmentation_step(_LINE_SEGMENTATION_MODEL, 8),
            _trocr_recognition_step(recognition_model),
            {"step": "OrderLines"},
        ]
    }


def _spread_pipeline(recognition_model):
    """Spread layout: region then line segmentation, recognition, then
    ``ReadingOrderMarginalia`` with ``two_page`` enabled."""
    return {
        "steps": [
            _yolo_segmentation_step(_REGION_SEGMENTATION_MODEL, 4),
            _yolo_segmentation_step(_LINE_SEGMENTATION_MODEL, 8),
            _trocr_recognition_step(recognition_model),
            {"step": "ReadingOrderMarginalia", "settings": {"two_page": True}},
        ]
    }


# Built-in HTRflow pipeline configs, keyed by the ``document_type`` values the
# tools accept. Each entry is assembled from freshly created dicts, so no two
# configs share mutable sub-objects.
PIPELINE_CONFIGS = {
    "letter_english": _letter_pipeline(_ENGLISH_RECOGNITION_MODEL),
    "letter_swedish": _letter_pipeline(_SWEDISH_RECOGNITION_MODEL),
    "spread_english": _spread_pipeline(_ENGLISH_RECOGNITION_MODEL),
    "spread_swedish": _spread_pipeline(_SWEDISH_RECOGNITION_MODEL),
}
@spaces.GPU
def process_htr(image: Image.Image, document_type: Literal["letter_english", "letter_swedish", "spread_english", "spread_swedish"] = "letter_english", confidence_threshold: float = 0.8, custom_settings: Optional[str] = None) -> Dict:
    """Process handwritten text recognition on uploaded images using HTRflow pipelines.

    Args:
        image: The page image to transcribe.
        document_type: Name of a built-in pipeline in ``PIPELINE_CONFIGS``.
            Used when ``custom_settings`` is empty.
        confidence_threshold: Recognised lines with a confidence below this
            value are left out of ``results``.
        custom_settings: Optional JSON object holding a complete HTRflow
            pipeline config. When non-blank it overrides ``document_type``.

    Returns:
        A dict with ``success``. On success it also holds:
        - ``results``: the extracted text lines.
        - ``processing_state``: a JSON string for the visualization and export
          tools. It includes the pipeline config that was actually run.
        - ``metadata``.
        On failure it holds ``error`` and ``results`` set to None.
    """
    if image is None:
        return {"success": False, "error": "No image provided", "results": None}

    # Resolve the pipeline config first so bad input is reported before any
    # temporary file is created.
    if custom_settings and custom_settings.strip():
        try:
            config = json.loads(custom_settings)
        except json.JSONDecodeError:
            return {"success": False, "error": "Invalid JSON in custom_settings parameter", "results": None}
        if not isinstance(config, dict):
            return {"success": False, "error": "custom_settings must be a JSON object", "results": None}
    elif document_type in PIPELINE_CONFIGS:
        config = PIPELINE_CONFIGS[document_type]
    else:
        return {"success": False, "error": f"Unknown document_type: {document_type!r}", "results": None}

    temp_image_path = None
    try:
        # delete=False: the file is reopened by path (image.save, Collection)
        # after this handle closes; the finally block removes it.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            temp_image_path = temp_file.name
        image.save(temp_image_path, "PNG")

        collection = Collection([temp_image_path])
        pipeline = Pipeline.from_config(config)
        try:
            processed_collection = pipeline.run(collection)
        except Exception as pipeline_error:
            return {"success": False, "error": f"Pipeline execution failed: {str(pipeline_error)}", "results": None}

        img_buffer = io.BytesIO()
        image.save(img_buffer, format="PNG")
        image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

        results = extract_text_results(processed_collection, confidence_threshold)
        processing_state = {
            # json.dumps(default=str) below stores only the str() of this object.
            "processed_collection": processed_collection,
            "image_base64": image_base64,
            "image_size": image.size,
            "document_type": document_type,
            "confidence_threshold": confidence_threshold,
            "timestamp": datetime.now().isoformat(),
            # Lets later tools rebuild the same pipeline, including custom_settings.
            "pipeline_config": config,
        }
        return {
            "success": True,
            "results": results,
            "processing_state": json.dumps(processing_state, default=str),
            "metadata": {
                "total_lines": len(results.get("text_lines", [])),
                "average_confidence": results.get("average_confidence", 0),
                "document_type": document_type,
                "image_dimensions": image.size,
            },
        }
    except Exception as e:
        return {"success": False, "error": f"HTR processing failed: {str(e)}", "results": None}
    finally:
        if temp_image_path and os.path.exists(temp_image_path):
            os.unlink(temp_image_path)
def visualize_results(processing_state: str, visualization_type: Literal["overlay", "confidence_heatmap", "text_regions"] = "overlay", show_confidence: bool = True, highlight_low_confidence: bool = True, image: Optional[Image.Image] = None) -> Dict:
    """Generate interactive visualizations of HTR processing results.

    Args:
        processing_state: The ``processing_state`` JSON string returned by the
            HTR processing tool. That tool's full JSON response is accepted too.
        visualization_type: One of:
            - "overlay": outlined lines.
            - "confidence_heatmap": lines filled by confidence band.
            - "text_regions": outlined lines in a small set of colours.
        show_confidence: Overlay mode only. Print each line's confidence score.
        highlight_low_confidence: Overlay mode only. Mark low-confidence lines
            in a different colour.
        image: Optional image to draw on. Defaults to the image stored in
            ``processing_state``.

    Returns:
        On success: ``{"success": True, "visualization": {...}, "metadata": {...}}``,
        with the rendered PNG base64-encoded under
        ``visualization["image_base64"]``.
        On failure: ``success`` is False, ``error`` describes the problem and
        ``visualization`` is None.
    """
    # NOTE(review): this re-runs the full pipeline but, unlike process_htr, is
    # not decorated with @spaces.GPU — confirm whether it should be.
    temp_image_path = None
    try:
        state = json.loads(processing_state)
        # The UI asks users to paste the HTR tool's output. That output nests
        # the state as a JSON string under "processing_state"; unwrap it.
        if isinstance(state, dict) and isinstance(state.get("processing_state"), str):
            state = json.loads(state["processing_state"])

        if image is not None:
            original_image = image
        else:
            image_data = base64.b64decode(state["image_base64"])
            original_image = Image.open(io.BytesIO(image_data))

        # Prefer the exact config process_htr ran (covers custom_settings).
        # States without one fall back to the built-in config.
        config = state.get("pipeline_config") or PIPELINE_CONFIGS[state["document_type"]]

        # The state carries no reusable collection, so re-run the pipeline on
        # the image. delete=False lets the file be reopened by path.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            temp_image_path = temp_file.name
        original_image.save(temp_image_path, "PNG")

        collection = Collection([temp_image_path])
        pipeline = Pipeline.from_config(config)
        processed_collection = pipeline.run(collection)

        viz_image = create_visualization(original_image, processed_collection, visualization_type, show_confidence, highlight_low_confidence)
        img_buffer = io.BytesIO()
        viz_image.save(img_buffer, format="PNG")
        img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
        return {
            "success": True,
            "visualization": {
                "image_base64": img_base64,
                "image_format": "PNG",
                "visualization_type": visualization_type,
                "dimensions": viz_image.size,
            },
            "metadata": {"visualization_type": visualization_type},
        }
    except Exception as e:
        return {"success": False, "error": f"Visualization generation failed: {str(e)}", "visualization": None}
    finally:
        if temp_image_path and os.path.exists(temp_image_path):
            os.unlink(temp_image_path)
def export_results(processing_state: str, output_formats: List[Literal["txt", "json", "alto", "page"]] = ["txt"], confidence_filter: float = 0.0) -> Dict:
    """Export HTR results to multiple formats using HTRflow's native export functionality.

    Args:
        processing_state: The ``processing_state`` JSON string returned by the
            HTR processing tool. That tool's full JSON response is accepted too.
        output_formats: HTRflow serializer names to export ("txt", "json",
            "alto", "page"). The list is only read, never mutated.
        confidence_filter: Echoed back in ``export_metadata``.
            NOTE(review): it is currently not applied to the exported content.

    Returns:
        On success: ``{"success": True, "exports": {fmt: [file, ...]},
        "export_metadata": {...}}``. Each file is
        ``{"filename": ..., "content": ...}``; files that are not valid UTF-8
        carry base64 content and ``"encoding": "base64"``.
        On failure: ``success`` is False, ``error`` describes the problem and
        ``exports`` is None.
    """
    temp_image_path = None
    try:
        state = json.loads(processing_state)
        # The UI asks users to paste the HTR tool's output. That output nests
        # the state as a JSON string under "processing_state"; unwrap it.
        if isinstance(state, dict) and isinstance(state.get("processing_state"), str):
            state = json.loads(state["processing_state"])

        image_data = base64.b64decode(state["image_base64"])
        image = Image.open(io.BytesIO(image_data))
        # Prefer the exact config process_htr ran (covers custom_settings).
        # States without one fall back to the built-in config.
        config = state.get("pipeline_config") or PIPELINE_CONFIGS[state["document_type"]]

        # The state carries no reusable collection, so re-run the pipeline on
        # the stored image. delete=False lets the file be reopened by path.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            temp_image_path = temp_file.name
        image.save(temp_image_path, "PNG")

        collection = Collection([temp_image_path])
        pipeline = Pipeline.from_config(config)
        processed_collection = pipeline.run(collection)

        exports = {}
        # TemporaryDirectory removes the serializer output even when a format
        # fails part-way; the previous rmtree ran only on the success path.
        with tempfile.TemporaryDirectory() as temp_dir:
            for fmt in output_formats:
                export_dir = Path(temp_dir) / fmt
                processed_collection.save(directory=str(export_dir), serializer=fmt)
                exports[fmt] = _read_export_files(export_dir)

        return {
            "success": True,
            "exports": exports,
            "export_metadata": {
                # Copy so callers cannot mutate the shared default list.
                "formats_generated": list(output_formats),
                "confidence_filter": confidence_filter,
                "timestamp": datetime.now().isoformat(),
            },
        }
    except Exception as e:
        return {"success": False, "error": f"Export generation failed: {str(e)}", "exports": None}
    finally:
        if temp_image_path and os.path.exists(temp_image_path):
            os.unlink(temp_image_path)


def _read_export_files(export_dir):
    """Return every file under *export_dir* as UTF-8 text, or base64 when it is not UTF-8."""
    export_files = []
    for root, dirs, files in os.walk(export_dir):
        dirs.sort()  # walk subdirectories in a deterministic order
        for file_name in sorted(files):
            file_path = os.path.join(root, file_name)
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()
                export_files.append({"filename": file_name, "content": content})
            except UnicodeDecodeError:
                with open(file_path, "rb") as f:
                    content = base64.b64encode(f.read()).decode("utf-8")
                export_files.append({"filename": file_name, "content": content, "encoding": "base64"})
    return export_files
def extract_text_results(collection: "Collection", confidence_threshold: float) -> Dict:
    """Collect recognised text lines whose confidence meets a threshold.

    Every node yielded by ``page.traverse()`` for each page in
    ``collection.pages`` is inspected. A node is kept when it has non-empty
    ``text`` and a non-None ``confidence`` that is >= ``confidence_threshold``.

    Args:
        collection: A processed HTRflow collection.
        confidence_threshold: Minimum confidence a node needs to be included.

    Returns:
        A dict with:
        - ``extracted_text``: the kept lines, each followed by a newline.
        - ``text_lines``: ``{"text", "confidence", "bbox"}`` for each kept line.
          ``bbox`` is None when the node has none.
        - ``confidence_scores``: the kept confidences.
        - ``average_confidence``: their mean, or 0 when nothing was kept.
    """
    text_lines = []
    for page in collection.pages:
        for node in page.traverse():
            text = getattr(node, "text", None)
            confidence = getattr(node, "confidence", None)
            # Skip nodes without text or without a score; comparing None with
            # a float would raise TypeError.
            if text and confidence is not None and confidence >= confidence_threshold:
                text_lines.append({
                    "text": text,
                    "confidence": confidence,
                    "bbox": getattr(node, "bbox", None),
                })

    scores = [line["confidence"] for line in text_lines]
    return {
        # join avoids the quadratic cost of repeated string concatenation.
        "extracted_text": "".join(f"{line['text']}\n" for line in text_lines),
        "text_lines": text_lines,
        "confidence_scores": scores,
        "average_confidence": sum(scores) / len(scores) if scores else 0,
    }
def create_visualization(image, collection, visualization_type, show_confidence, highlight_low_confidence):
    """Draw HTR results from *collection* onto a copy of *image*.

    Only nodes (from ``page.traverse()``) with a truthy ``bbox`` and ``text``
    are drawn. Nodes whose ``confidence`` is missing or None are treated as
    confidence 1.0.

    Args:
        image: PIL image the collection was produced from; it is not modified.
        collection: Processed HTRflow collection.
        visualization_type: One of the following; any other value returns an
            unannotated copy.
            - "overlay": outlined boxes.
            - "confidence_heatmap": translucent fill, red < 0.5 <= yellow < 0.8 <= green.
            - "text_regions": outlined boxes in four cycling colours.
        show_confidence: Overlay mode only. Write each box's confidence just above it.
        highlight_low_confidence: Overlay mode only. Outline boxes with
            confidence below 0.7 in orange instead of green.

    Returns:
        A new RGB image.
    """
    # Draw in RGB so the RGB colour tuples below work for any input mode
    # (palette, greyscale, ...). convert() always returns a new image.
    viz_image = image.convert("RGB")
    draw = ImageDraw.Draw(viz_image)
    try:
        font = ImageFont.truetype("arial.ttf", 12)
    except (OSError, ImportError):
        # Arial is often missing on Linux hosts; use PIL's built-in font instead.
        font = ImageFont.load_default()

    region_palette = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
    region_index = 0

    # Heatmap boxes go onto one shared transparent layer that is composited
    # once, instead of allocating and compositing a full-size layer per line.
    # Where boxes overlap, the last box drawn determines the layer's colour.
    heatmap_layer = None
    heatmap_draw = None
    if visualization_type == "confidence_heatmap":
        heatmap_layer = Image.new("RGBA", viz_image.size, (0, 0, 0, 0))
        heatmap_draw = ImageDraw.Draw(heatmap_layer)

    for page in collection.pages:
        for node in page.traverse():
            bbox = getattr(node, "bbox", None)
            text = getattr(node, "text", None)
            if not (bbox and text):
                continue
            confidence = getattr(node, "confidence", None)
            if confidence is None:
                confidence = 1.0

            if visualization_type == "overlay":
                color = (255, 165, 0) if highlight_low_confidence and confidence < 0.7 else (0, 255, 0)
                draw.rectangle(bbox, outline=color, width=2)
                if show_confidence:
                    # Clamp so labels of boxes at the top edge stay on the canvas.
                    label_y = max(0, bbox[1] - 15)
                    draw.text((bbox[0], label_y), f"{confidence:.2f}", fill=color, font=font)
            elif visualization_type == "confidence_heatmap":
                if confidence < 0.5:
                    fill = (255, 0, 0, 100)
                elif confidence < 0.8:
                    fill = (255, 255, 0, 100)
                else:
                    fill = (0, 255, 0, 100)
                heatmap_draw.rectangle(bbox, fill=fill)
            elif visualization_type == "text_regions":
                # A running index (rather than hash(), which is randomised per
                # process for str) keeps colours reproducible between runs.
                color = region_palette[region_index % len(region_palette)]
                region_index += 1
                draw.rectangle(bbox, outline=color, width=3)

    if heatmap_layer is not None:
        viz_image = Image.alpha_composite(viz_image.convert("RGBA"), heatmap_layer).convert("RGB")
    return viz_image
def create_htrflow_mcp_server():
    """Build the three-tab Gradio app for HTRflow.

    The tabs wrap ``process_htr``, ``visualize_results`` and ``export_results``
    respectively. Each tab sets an explicit ``api_name`` matching its function.
    """
    processing_tab = gr.Interface(
        fn=process_htr,
        inputs=[
            gr.Image(type="pil", label="Upload Image"),
            gr.Dropdown(choices=["letter_english", "letter_swedish", "spread_english", "spread_swedish"], value="letter_english", label="Document Type"),
            gr.Slider(0.0, 1.0, value=0.8, label="Confidence Threshold"),
            gr.Textbox(label="Custom Settings (JSON)", placeholder="Optional custom pipeline settings"),
        ],
        outputs=gr.JSON(label="Processing Results"),
        title="HTR Processing Tool",
        description="Process handwritten text using configurable HTRflow pipelines",
        api_name="process_htr",
    )

    visualization_tab = gr.Interface(
        fn=visualize_results,
        inputs=[
            gr.Textbox(label="Processing State (JSON)", placeholder="Paste processing results from HTR tool"),
            gr.Dropdown(choices=["overlay", "confidence_heatmap", "text_regions"], value="overlay", label="Visualization Type"),
            gr.Checkbox(value=True, label="Show Confidence Scores"),
            gr.Checkbox(value=True, label="Highlight Low Confidence"),
            gr.Image(type="pil", label="Image (optional)"),
        ],
        outputs=gr.JSON(label="Visualization Results"),
        title="Results Visualization Tool",
        description="Generate interactive visualizations of HTR results",
        api_name="visualize_results",
    )

    export_tab = gr.Interface(
        fn=export_results,
        inputs=[
            gr.Textbox(label="Processing State (JSON)", placeholder="Paste processing results from HTR tool"),
            gr.CheckboxGroup(choices=["txt", "json", "alto", "page"], value=["txt"], label="Output Formats"),
            gr.Slider(0.0, 1.0, value=0.0, label="Confidence Filter"),
        ],
        outputs=gr.JSON(label="Export Results"),
        title="Export Tool",
        description="Export HTR results to multiple formats",
        api_name="export_results",
    )

    tab_labels = ["HTR Processing", "Results Visualization", "Export Results"]
    return gr.TabbedInterface(
        [processing_tab, visualization_tab, export_tab],
        tab_labels,
        title="HTRflow MCP Server",
    )


if __name__ == "__main__":
    # mcp_server=True asks Gradio to serve the app's functions as MCP tools too.
    demo = create_htrflow_mcp_server()
    demo.launch(mcp_server=True)