import os import sys import gradio as gr from PIL import Image import tempfile import shutil from pathlib import Path from kraken.lib import vgsl from kraken.lib import models from kraken import serialization import logging import numpy as np import cv2 from kraken import blla, rpred from kraken.containers import BaselineLine import json from jinja2 import Environment, FileSystemLoader import base64 import io from jinja2 import Template import re import time # Configure logging logger = logging.getLogger(__name__) logging.basicConfig(level=logging.WARNING) logging.getLogger('kraken').setLevel(logging.WARNING) logging.getLogger('kraken.serialization').setLevel(logging.WARNING) logging.getLogger('kraken.blla').setLevel(logging.WARNING) logging.getLogger('kraken.lib.models').setLevel(logging.WARNING) # Constants - Use relative paths for Hugging Face MODELS_DIR = Path("models") SEG_MODELS_DIR = MODELS_DIR / "seg" REC_MODELS_DIR = MODELS_DIR / "rec" # Embedded template PAGEXML_TEMPLATE = '''{%+ macro render_line(line) +%} {% if line.boundary %} {% endif %} {% if line.baseline %} {% endif %} {% if line.text is string %} {{ line.text|e }} {% else %} {% for segment in line.recognition %} {% if segment.boundary %} {% else %} {% endif %} {% for char in segment.recognition %} {{ char.text|e }} {% endfor %} {{ segment.text|e }} {% endfor %} {%+ if line.confidences|length %}{% for segment in line.recognition %}{{ segment.text|e }}{% endfor %}{% endif +%} {% endif %} {%+ endmacro %} kraken {{ metadata.version }} {{ page.date }} {{ page.date }} {% for entity in page.entities %} {% if entity.type == "region" %} {% if loop.previtem and loop.previtem.type == 'line' %} {% endif %} {% if entity.boundary %}{% endif %} {%- for line in entity.lines -%} {{ render_line(line) }} {%- endfor %} {% else %} {% if not loop.previtem or loop.previtem.type != 'line' %} {% endif %} {{ render_line(entity) }} {% if loop.last %} {% endif %} {% endif %} {% endfor %} ''' # Create Jinja environment TEMPLATE_DIR = Path("templates") TEMPLATE_DIR.mkdir(exist_ok=True) _ENV = Environment(loader=FileSystemLoader(str(TEMPLATE_DIR))) def seg_rec_image(image_path, seg_model, rec_model, output_dir=None): try: im = Image.open(image_path) baseline_seg = blla.segment(im, model=seg_model) # Run recognition and collect full BaselineOCRRecord objects pred_it = rpred.rpred(network=rec_model, im=im, bounds=baseline_seg, pad=16) records = [record for record in pred_it] # Attach recognition results to segmentation lines for line, rec_line in zip(baseline_seg.lines, records): # Debug logging for recognition results logger.debug(f'Recognition result - Prediction: {rec_line.prediction}') logger.debug(f'Recognition result - Confidences: {rec_line.confidences}') # Ensure the line has both prediction and confidence values line.prediction = rec_line.prediction line.text = rec_line.prediction # Set text field for serialization # Store per-character confidences line.confidences = rec_line.confidences # Keep the list of confidences # Debug logging for line object logger.debug(f'Line {line.id} - Prediction: {line.prediction}') logger.debug(f'Line {line.id} - Confidences: {line.confidences}') # Construct PAGE-XML segmentation only data pagexml_seg_only = serialization.serialize(baseline_seg, image_size=im.size, template='pagexml', sub_line_segmentation=False) # Serialize with recognition results pagexml = serialization.serialize(baseline_seg, image_size=im.size, template='custom_pagexml', template_source='custom', sub_line_segmentation=False) base_name = os.path.splitext(os.path.basename(image_path))[0] if output_dir: os.makedirs(output_dir, exist_ok=True) output_path = os.path.join(output_dir, base_name + '.xml') else: output_path = os.path.splitext(image_path)[0] + '.xml' with open(output_path, 'w') as fp: fp.write(pagexml) print(f"✅ Segmented/recognized: {os.path.basename(image_path)} → {os.path.basename(output_path)}") except Exception as e: print(f"❌ Failed to process {image_path}: {e}") import traceback traceback.print_exc() # Create template files def create_templates(): """Create Jinja templates for visualization.""" # Image template with SVG for visualization image_template = """
{% for line in lines %} {% endfor %}
{% for line in lines %} {{ loop.index }}: {{ line.text }} {% if line.confidence %} ({{ "%.2f"|format(line.confidence) }}) {% endif %}
{% endfor %}
""" # Transcription template transcription_template = """
{% for line in lines %} {{ loop.index }}: {{ line.text }} {% if line.confidence %} ({{ "%.2f"|format(line.confidence) }}) {% endif %}
{% endfor %}
""" # Write templates with open(TEMPLATE_DIR / "image.html", "w") as f: f.write(image_template) with open(TEMPLATE_DIR / "transcription.html", "w") as f: f.write(transcription_template) def ensure_template_exists(): """Create the template file if it doesn't exist.""" template_path = os.path.join(os.path.dirname(__file__), 'custom_pagexml') if not os.path.exists(template_path): with open(template_path, 'w', encoding='utf-8') as f: f.write(PAGEXML_TEMPLATE) def get_model_files(directory): """Get list of .mlmodel files from directory.""" return [f for f in os.listdir(directory) if f.endswith('.mlmodel')] def load_models(): """Load all available models.""" seg_models = {} rec_models = {} # Load segmentation models for model_file in get_model_files(SEG_MODELS_DIR): try: model_path = os.path.join(SEG_MODELS_DIR, model_file) seg_models[model_file] = vgsl.TorchVGSLModel.load_model(model_path) except Exception as e: print(f"Error loading segmentation model {model_file}: {str(e)}") # Load recognition models for model_file in get_model_files(REC_MODELS_DIR): try: model_path = os.path.join(REC_MODELS_DIR, model_file) rec_models[model_file] = models.load_any(model_path) except Exception as e: print(f"Error loading recognition model {model_file}: {str(e)}") return seg_models, rec_models def process_image(image, seg_model, rec_model): """Process image and return segmentation and recognition results.""" # Run segmentation baseline_seg = blla.segment(image, model=seg_model) # Run recognition pred_it = rpred.rpred(network=rec_model, im=image, bounds=baseline_seg, pad=16) records = [record for record in pred_it] # Attach recognition results to segmentation lines for line, rec_line in zip(baseline_seg.lines, records): line.prediction = rec_line.prediction line.text = rec_line.prediction line.confidences = rec_line.confidences return baseline_seg def render_image(image, baseline_seg): """Render image with SVG overlay.""" # Convert image to base64 buffered = io.BytesIO() image.save(buffered, format="PNG") image_base64 = base64.b64encode(buffered.getvalue()).decode() # Get image dimensions width, height = image.size # Prepare lines data lines = [] for line in baseline_seg.lines: # Convert boundary points to SVG path boundary_points = [] for point in line.boundary: boundary_points.append(f"{point[0]},{point[1]}") # Convert baseline points to SVG path baseline_points = [] for point in line.baseline: baseline_points.append(f"{point[0]},{point[1]}") # Get text and determine direction text = line.text if hasattr(line, 'text') else '' # Check if text contains RTL characters (Hebrew, Arabic, etc.) rtl_chars = re.compile(r'[\u0591-\u07FF\u200F\u202B\u202E\uFB1D-\uFDFD\uFE70-\uFEFC\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF\uFB50-\uFDFF\uFE70-\uFEFF]') is_rtl = bool(rtl_chars.search(text)) lines.append({ 'boundary': boundary_points, 'baseline': baseline_points, 'text': text, 'confidence': line.confidence if hasattr(line, 'confidence') else None, 'is_rtl': is_rtl }) # Render template template = """
{% for line in lines %} {% endfor %}
{% for line in lines %}
{{ loop.index }}: {{ line.text }} {% if line.confidence %} ({{ "%.2f"|format(line.confidence) }}) {% endif %}
{% endfor %}
""" return Template(template).render( width=width, height=height, image_base64=image_base64, lines=lines ) def get_example_images(): """Get list of example images from the examples directory.""" examples_dir = Path(__file__).parent / "examples" if not examples_dir.exists(): return [] # Combine both glob patterns into a single list return [str(f) for f in list(examples_dir.glob("*.jpg")) + list(examples_dir.glob("*.png"))] def process_and_visualize(image, seg_model_name, rec_model_name, progress=gr.Progress()): try: if image is None: yield "❌ Please upload an image first.", None, None, None, None, None return yield "🔄 Starting processing...", None, None, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False) progress(0.1, desc="Loading models...") yield "📦 Loading models...", None, None, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False) seg_models, rec_models = load_models() seg_model = seg_models[seg_model_name] rec_model = rec_models[rec_model_name] progress(0.3, desc="Running Segmentation...") yield "✂️ Running segmentation...", None, None, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False) baseline_seg = blla.segment(image, model=seg_model) progress(0.6, desc="Running Recognition...") yield "🔠 Running text recognition...", None, None, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False) pred_it = rpred.rpred(network=rec_model, im=image, bounds=baseline_seg, pad=16) records = [record for record in pred_it] for line, rec_line in zip(baseline_seg.lines, records): line.prediction = rec_line.prediction line.text = rec_line.prediction line.confidences = rec_line.confidences progress(0.85, desc="Generating PageXML...") yield "📝 Generating PageXML output...", None, None, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False) with tempfile.TemporaryDirectory() as temp_dir: input_path = os.path.join(temp_dir, "temp.png") image.save(input_path) seg_rec_image(input_path, seg_model, rec_model, temp_dir) output_xml = os.path.join(temp_dir, "temp.xml") xml_content = open(output_xml, 'r', encoding='utf-8').read() if os.path.exists(output_xml) else "⚠️ Error generating XML output." progress(1.0, desc="Rendering results...") yield "✅ Done! Switch to visualization!", render_image(image, baseline_seg), xml_content, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True) except Exception as e: yield f"❌ Error: {str(e)}", None, None, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True) def main(): # Create necessary directories and templates SEG_MODELS_DIR.mkdir(parents=True, exist_ok=True) REC_MODELS_DIR.mkdir(parents=True, exist_ok=True) ensure_template_exists() create_templates() # Load available models seg_models, rec_models = load_models() if not seg_models: print("No segmentation models found in app/models/seg. Please add .mlmodel files.") return if not rec_models: print("No recognition models found in app/models/rec. Please add .mlmodel files.") return # Create Gradio interface with gr.Blocks(title="Kraken OCR on Samaritan manuscripts") as demo: gr.Markdown("# Kraken OCR on Samaritan manuscripts") gr.Markdown("Upload an image and select models to process it.") with gr.Tabs() as tabs: with gr.Tab("Upload Image") as upload_tab: with gr.Row(): with gr.Column(scale=2): image_input = gr.Image(type="pil", label="Input Image", height=400) with gr.Row(): seg_model = gr.Dropdown(choices=list(seg_models.keys()), label="Segmentation Model", value=list(seg_models.keys())[0]) rec_model = gr.Dropdown(choices=list(rec_models.keys()), label="Recognition Model", value=list(rec_models.keys())[0]) process_btn = gr.Button("Process Image") status_box = gr.Markdown("", visible=True) with gr.Column(scale=1): gr.Markdown("### Example Images") examples = gr.Gallery( get_example_images(), show_label=False, interactive=True, allow_preview=False, object_fit="cover", columns=2, height=400, elem_classes="example-gallery" ) with gr.Tab("Visualization", interactive=False) as vis_tab: visualization_output = gr.HTML(label="Visualization") with gr.Tab("PageXML", interactive=False) as xml_tab: xml_output = gr.Textbox(label="PageXML", lines=20, max_lines=50, show_copy_button=True) # Add custom CSS for the gallery gr.HTML(""" """) process_btn.click( process_and_visualize, inputs=[image_input, seg_model, rec_model], outputs=[status_box, visualization_output, xml_output, vis_tab, xml_tab, upload_tab], show_progress=True ).then( lambda: gr.Tabs(selected="Visualization"), outputs=tabs ) # Example image selection handler def select_example(evt: gr.SelectData): if not examples.value: return None selected = examples.value[evt.index] return selected["image"]["path"] examples.select( select_example, None, image_input ) demo.launch(server_name="0.0.0.0", server_port=7860) if __name__ == "__main__": main()