Spaces:
Sleeping
Sleeping
import os | |
import sys | |
import gradio as gr | |
from PIL import Image | |
import tempfile | |
import shutil | |
from pathlib import Path | |
from kraken.lib import vgsl | |
from kraken.lib import models | |
from kraken import serialization | |
import logging | |
import numpy as np | |
import cv2 | |
from kraken import blla, rpred | |
from kraken.containers import BaselineLine | |
import json | |
from jinja2 import Environment, FileSystemLoader | |
import base64 | |
import io | |
from jinja2 import Template | |
import re | |
import time | |
# Logging: one module-level logger, root handler at WARNING, and the same
# WARNING threshold applied explicitly to each kraken logger used by this app.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.WARNING)
for _kraken_logger_name in (
    'kraken',
    'kraken.serialization',
    'kraken.blla',
    'kraken.lib.models',
):
    logging.getLogger(_kraken_logger_name).setLevel(logging.WARNING)
del _kraken_logger_name  # keep the module namespace identical to before

# Model locations - relative paths so the layout works on Hugging Face
MODELS_DIR = Path("models")
SEG_MODELS_DIR = MODELS_DIR.joinpath("seg")
REC_MODELS_DIR = MODELS_DIR.joinpath("rec")
# Embedded PAGE-XML Jinja template.
#
# The text is written to a file named 'custom_pagexml' by
# ensure_template_exists() and referenced by name in seg_rec_image() via
# serialization.serialize(template='custom_pagexml', template_source='custom').
#
# Structure of the template:
#   * macro render_line(line): emits one <TextLine> with optional <Coords> and
#     <Baseline>. When `line.text` is a string it emits a single line-level
#     <TextEquiv> (with the mean of `line.confidences` as `conf` when that
#     list is non-empty); otherwise it walks `line.recognition` and emits
#     <Word>/<Glyph> children.
#   * document body: <Metadata>, then one <TextRegion> per region entity;
#     consecutive bare lines are wrapped in a synthetic page-sized region.
#
# NOTE(review): the whitespace-control markers ({%+ ... +%}, {%- ... -%})
# make the rendered layout sensitive to edits - change with care.
PAGEXML_TEMPLATE = '''{%+ macro render_line(line) +%}
<TextLine id="{{ line.id }}" {% if line.tags and "type" in line.tags %}custom="structure {type:{{ line.tags["type"] }};}"{% endif %}>
{% if line.boundary %}
<Coords points="{% for point in line.boundary %}{{ point|join(',') }}{% if not loop.last %} {% endif %}{% endfor %}"/>
{% endif %}
{% if line.baseline %}
<Baseline points="{% for point in line.baseline %}{{ point|join(',') }}{% if not loop.last %} {% endif %}{% endfor %}"/>
{% endif %}
{% if line.text is string %}
<TextEquiv{% if line.confidences|length %} conf="{{ (line.confidences|sum / line.confidences|length)|round(4) }}"{% endif %}><Unicode>{{ line.text|e }}</Unicode></TextEquiv>
{% else %}
{% for segment in line.recognition %}
<Word id="segment_{{ segment.index }}">
{% if segment.boundary %}
<Coords points="{% for point in segment.boundary %}{{ point|join(',') }}{% if not loop.last %} {% endif %}{% endfor %}"/>
{% else %}
<Coords points="{{ segment.bbox[0] }},{{ segment.bbox[1] }} {{ segment.bbox[0] }},{{ segment.bbox[3] }} {{ segment.bbox[2] }},{{ segment.bbox[3] }} {{ segment.bbox[2] }},{{ segment.bbox[1] }}"/>
{% endif %}
{% for char in segment.recognition %}
<Glyph id="char_{{ char.index }}">
<Coords points="{% for point in char.boundary %}{{ point|join(',') }}{% if not loop.last %} {% endif %}{% endfor %}"/>
<TextEquiv conf="{{ char.confidence|round(4) }}"><Unicode>{{ char.text|e }}</Unicode></TextEquiv>
</Glyph>
{% endfor %}
<TextEquiv conf="{{ (segment.confidences|sum / segment.confidences|length)|round(4) }}"><Unicode>{{ segment.text|e }}</Unicode></TextEquiv>
</Word>
{% endfor %}
{%+ if line.confidences|length %}<TextEquiv conf="{{ (line.confidences|sum / line.confidences|length)|round(4) }}"><Unicode>{% for segment in line.recognition %}{{ segment.text|e }}{% endfor %}</Unicode></TextEquiv>{% endif +%}
{% endif %}
</TextLine>
{%+ endmacro %}
<?xml version="1.0" encoding="UTF-8"?>
<PcGts xmlns="http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15 http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15/pagecontent.xsd">
<Metadata>
<Creator>kraken {{ metadata.version }}</Creator>
<Created>{{ page.date }}</Created>
<LastChange>{{ page.date }}</LastChange>
</Metadata>
<Page imageFilename="{{ page.name }}" imageWidth="{{ page.size[0] }}" imageHeight="{{ page.size[1] }}" {% if page.base_dir %}readingDirection="{{ page.base_dir }}"{% endif %}>
{% for entity in page.entities %}
{% if entity.type == "region" %}
{% if loop.previtem and loop.previtem.type == 'line' %}
</TextRegion>
{% endif %}
<TextRegion id="{{ entity.id }}" {% if entity.tags and "type" in entity.tags %}custom="structure {type:{{ entity.tags["type"] }};}"{% endif %}>
{% if entity.boundary %}<Coords points="{% for point in entity.boundary %}{{ point|join(',') }}{% if not loop.last %} {% endif %}{% endfor %}"/>{% endif %}
{%- for line in entity.lines -%}
{{ render_line(line) }}
{%- endfor %}
</TextRegion>
{% else %}
{% if not loop.previtem or loop.previtem.type != 'line' %}
<TextRegion id="textblock_{{ loop.index }}">
<Coords points="0,0 0,{{ page.size[1] }} {{ page.size[0] }},{{ page.size[1] }} {{ page.size[0] }},0"/>
{% endif %}
{{ render_line(entity) }}
{% if loop.last %}
</TextRegion>
{% endif %}
{% endif %}
{% endfor %}
</Page>
</PcGts>'''
# Create Jinja environment
# TEMPLATE_DIR holds the HTML templates written by create_templates(); it is
# created at import time so the loader below always has a directory to read.
TEMPLATE_DIR = Path("templates")
TEMPLATE_DIR.mkdir(exist_ok=True)
# NOTE(review): no `autoescape` argument is passed here, so templates loaded
# through this environment must escape untrusted values themselves - confirm
# this is intended before rendering OCR text through _ENV.
_ENV = Environment(loader=FileSystemLoader(str(TEMPLATE_DIR)))
def seg_rec_image(image_path, seg_model, rec_model, output_dir=None):
    """Segment and recognize one image file and write the result as PAGE-XML.

    Args:
        image_path: Path of the image to process.
        seg_model: Segmentation model handed to ``blla.segment``.
        rec_model: Recognition model handed to ``rpred.rpred``.
        output_dir: Directory for the ``.xml`` file (created if missing).
            When falsy, the XML is written next to the image.

    Returns:
        None. This is a best-effort helper: any exception is caught, reported
        on stdout together with its traceback, and not re-raised.
    """
    try:
        # The context manager guarantees the image file handle is released
        # even if segmentation or recognition fails.
        with Image.open(image_path) as im:
            image_size = im.size
            baseline_seg = blla.segment(im, model=seg_model)

            # Run recognition and collect the full record objects while the
            # image is still open.
            pred_it = rpred.rpred(network=rec_model, im=im, bounds=baseline_seg, pad=16)
            records = list(pred_it)

        # Attach recognition results to the segmentation lines so that the
        # custom template can read `text` and `confidences` from each line.
        for line, rec_line in zip(baseline_seg.lines, records):
            logger.debug('Recognition result - Prediction: %s', rec_line.prediction)
            logger.debug('Recognition result - Confidences: %s', rec_line.confidences)

            line.prediction = rec_line.prediction
            line.text = rec_line.prediction  # text field used for serialization
            line.confidences = rec_line.confidences  # per-character confidences

            logger.debug('Line %s - Prediction: %s', line.id, line.prediction)
            logger.debug('Line %s - Confidences: %s', line.id, line.confidences)

        # Serialize with recognition results.
        # NOTE(review): 'custom_pagexml' is given as a bare name while
        # ensure_template_exists() writes it next to this script - presumably
        # this relies on the working directory being the script directory;
        # verify against kraken's custom-template lookup.
        pagexml = serialization.serialize(baseline_seg,
                                          image_size=image_size,
                                          template='custom_pagexml',
                                          template_source='custom',
                                          sub_line_segmentation=False)

        base_name = os.path.splitext(os.path.basename(image_path))[0]
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, base_name + '.xml')
        else:
            output_path = os.path.splitext(image_path)[0] + '.xml'

        # The template declares encoding="UTF-8" and the transcription is
        # non-ASCII, so never rely on the platform's default encoding.
        with open(output_path, 'w', encoding='utf-8') as fp:
            fp.write(pagexml)
        print(f"✅ Segmented/recognized: {os.path.basename(image_path)} → {os.path.basename(output_path)}")
    except Exception as e:
        print(f"❌ Failed to process {image_path}: {e}")
        import traceback
        traceback.print_exc()
# Create template files
def create_templates():
    """Write the visualization Jinja templates into TEMPLATE_DIR.

    Two files are (re)written on every call:

    * ``image.html`` - the page image with an SVG overlay of line boundaries
      and baselines, next to the list of transcribed lines.
    * ``transcription.html`` - the list of transcribed lines only.

    Both templates expect a ``lines`` sequence whose items expose ``text``
    and ``confidence``; ``image.html`` additionally reads ``boundary`` and
    ``baseline`` from each line and expects ``width``, ``height`` and
    ``image_base64``.

    The files are written as UTF-8: Python resolves the ``\\uXXXX`` escapes
    inside the script's regular expression to literal non-ASCII characters,
    which a non-UTF-8 default encoding cannot represent. The OCR text is
    passed through the ``e`` filter so that characters such as ``<`` or ``&``
    cannot break the generated HTML.
    """
    # Image template with SVG for visualization
    image_template = """
<div class="visualization-container">
  <div class="image-container">
    <svg width="{{ width }}" height="{{ height }}" viewBox="0 0 {{ width }} {{ height }}">
      <image href="data:image/png;base64,{{ image_base64 }}" width="{{ width }}" height="{{ height }}"/>
      {% for line in lines %}
      <a class="textline line{{loop.index}}" onmouseover="document.querySelectorAll('.line{{loop.index}}').forEach(element => {element.classList.add('highlighted')});" onmouseout="document.querySelectorAll('*').forEach(element => {element.classList.remove('highlighted')});">
        <path class="line-boundary" d="M {{ line.boundary|join(' L ') }} Z" fill="rgba(0, 128, 255, 0.2)" stroke="none"/>
        <path class="line-baseline" d="M {{ line.baseline|join(' L ') }}" stroke="red" stroke-width="1" fill="none"/>
      </a>
      {% endfor %}
    </svg>
  </div>
  <div class="transcription-container">
    {% for line in lines %}
    <span class="textline line{{loop.index}}" onmouseover="document.querySelectorAll('.line{{loop.index}}').forEach(element => {element.classList.add('highlighted')});" onmouseout="document.querySelectorAll('*').forEach(element => {element.classList.remove('highlighted')});">
      <span class="line-number">{{ loop.index }}:</span>
      <span class="line-text">{{ line.text|e }}</span>
      {% if line.confidence %}
      <span class="line-confidence">({{ "%.2f"|format(line.confidence) }})</span>
      {% endif %}
    </span>
    <br>
    {% endfor %}
  </div>
</div>
<style>
.visualization-container {
  display: flex;
  gap: 20px;
  max-height: 1000px;
}
.image-container {
  flex: 2;
  overflow: auto;
  border: 1px solid #ddd;
  border-radius: 4px;
}
.image-container svg {
  display: block;
  width: 100%;
  height: auto;
  max-width: 100%;
}
.transcription-container {
  flex: 1;
  overflow-y: auto;
  padding: 10px;
  border: 1px solid #ddd;
  border-radius: 4px;
}
/* Synchronize scrolling between containers */
.image-container, .transcription-container {
  scroll-behavior: smooth;
}
.image-container::-webkit-scrollbar, .transcription-container::-webkit-scrollbar {
  width: 8px;
}
.image-container::-webkit-scrollbar-track, .transcription-container::-webkit-scrollbar-track {
  background: #f1f1f1;
}
.image-container::-webkit-scrollbar-thumb, .transcription-container::-webkit-scrollbar-thumb {
  background: #888;
  border-radius: 4px;
}
.image-container::-webkit-scrollbar-thumb:hover, .transcription-container::-webkit-scrollbar-thumb:hover {
  background: #555;
}
.textline {
  padding: 5px;
  cursor: pointer;
  display: inline-block;
  unicode-bidi: bidi-override;
}
.textline:hover,
.textline.highlighted {
  background-color: rgba(0, 128, 255, 0.1);
}
.textline:hover .line-boundary,
.textline.highlighted .line-boundary {
  fill: rgba(0, 255, 255, 0.3);
}
.textline:hover .line-baseline,
.textline.highlighted .line-baseline {
  stroke: yellow;
}
.line-number {
  color: #666;
  margin-right: 5px;
}
.line-confidence {
  color: #888;
  font-size: 0.9em;
  margin-left: 5px;
}
/* RTL text support */
.textline[dir="rtl"] {
  text-align: right;
}
.textline[dir="ltr"] {
  text-align: left;
}
</style>
<script>
// Synchronize scrolling between containers
const imageContainer = document.querySelector('.image-container');
const textContainer = document.querySelector('.transcription-container');
function syncScroll(source, target) {
  const ratio = target.scrollHeight / source.scrollHeight;
  target.scrollTop = source.scrollTop * ratio;
}
imageContainer.addEventListener('scroll', () => syncScroll(imageContainer, textContainer));
textContainer.addEventListener('scroll', () => syncScroll(textContainer, imageContainer));
// Function to detect text direction
function detectTextDirection(text) {
  const rtlChars = /[\u0591-\u07FF\u200F\u202B\u202E\uFB1D-\uFDFD\uFE70-\uFEFC]/;
  return rtlChars.test(text) ? 'rtl' : 'ltr';
}
// Add direction attribute to text lines
function updateTextDirections() {
  document.querySelectorAll('.textline').forEach(line => {
    const text = line.textContent;
    line.setAttribute('dir', detectTextDirection(text));
  });
}
// Update text directions when visualization changes
const observer = new MutationObserver(updateTextDirections);
observer.observe(document.body, { childList: true, subtree: true });
</script>
"""
    # Transcription template
    transcription_template = """
<div class="transcription-container" style="max-height: 600px; overflow-y: auto;">
  {% for line in lines %}
  <span class="textline line{{loop.index}}" onmouseover="document.querySelectorAll('.line{{loop.index}}').forEach(element => {element.classList.add('highlighted')});" onmouseout="document.querySelectorAll('*').forEach(element => {element.classList.remove('highlighted')});">
    <span class="line-number">{{ loop.index }}:</span>
    <span class="line-text">{{ line.text|e }}</span>
    {% if line.confidence %}
    <span class="line-confidence">({{ "%.2f"|format(line.confidence) }})</span>
    {% endif %}
  </span>
  <br>
  {% endfor %}
</div>
<style>
.textline {
  padding: 5px;
  cursor: pointer;
  display: inline-block;
}
.textline:hover,
.textline.highlighted {
  background-color: rgba(0, 128, 255, 0.1);
}
.line-number {
  color: #666;
  margin-right: 5px;
}
.line-confidence {
  color: #888;
  font-size: 0.9em;
  margin-left: 5px;
}
</style>
"""
    # Write templates with an explicit encoding (see docstring).
    with open(TEMPLATE_DIR / "image.html", "w", encoding="utf-8") as f:
        f.write(image_template)
    with open(TEMPLATE_DIR / "transcription.html", "w", encoding="utf-8") as f:
        f.write(transcription_template)
def ensure_template_exists():
    """Write PAGEXML_TEMPLATE to 'custom_pagexml' beside this script, once.

    An already existing file is left untouched.
    """
    destination = Path(os.path.dirname(__file__)) / 'custom_pagexml'
    if destination.exists():
        return
    destination.write_text(PAGEXML_TEMPLATE, encoding='utf-8')
def get_model_files(directory):
    """Return the names of the ``.mlmodel`` files found in *directory*.

    Args:
        directory: Directory to scan (``str`` or path-like).

    Returns:
        A list of file names (not full paths), sorted alphabetically so that
        the order - and therefore the default choice built from it - is
        deterministic. An empty list is returned when the directory does not
        exist.
    """
    try:
        entries = os.listdir(directory)
    except FileNotFoundError:
        # A missing model directory simply means "no models available".
        return []
    return sorted(name for name in entries if name.endswith('.mlmodel'))
def load_models():
    """Load every segmentation and recognition model found on disk.

    Returns:
        A ``(seg_models, rec_models)`` pair of dicts keyed by model file
        name. A model that fails to load is reported on stdout and skipped.
    """
    segmenters = {}
    recognizers = {}

    # Segmentation models live in SEG_MODELS_DIR
    for filename in get_model_files(SEG_MODELS_DIR):
        try:
            location = os.path.join(SEG_MODELS_DIR, filename)
            segmenters[filename] = vgsl.TorchVGSLModel.load_model(location)
        except Exception as err:
            print(f"Error loading segmentation model {filename}: {err!s}")

    # Recognition models live in REC_MODELS_DIR
    for filename in get_model_files(REC_MODELS_DIR):
        try:
            location = os.path.join(REC_MODELS_DIR, filename)
            recognizers[filename] = models.load_any(location)
        except Exception as err:
            print(f"Error loading recognition model {filename}: {err!s}")

    return segmenters, recognizers
def process_image(image, seg_model, rec_model):
    """Segment *image*, recognize its lines and return the segmentation.

    Each line of the returned segmentation is annotated with ``prediction``,
    ``text`` and ``confidences`` taken from the matching recognition record.
    """
    # Segmentation pass
    segmentation = blla.segment(image, model=seg_model)

    # Recognition pass, fully materialized before results are attached
    recognized = list(
        rpred.rpred(network=rec_model, im=image, bounds=segmentation, pad=16)
    )

    # Pair lines with records positionally and copy the results across
    for seg_line, record in zip(segmentation.lines, recognized):
        seg_line.prediction = record.prediction
        seg_line.text = record.prediction
        seg_line.confidences = record.confidences

    return segmentation
def render_image(image, baseline_seg):
    """Render the page image with an SVG line overlay and its transcription.

    Args:
        image: Image object providing ``save(fp, format=...)`` and ``size``
            (a PIL image in this application).
        baseline_seg: Segmentation whose ``lines`` expose ``boundary`` and
            ``baseline`` point sequences and, after recognition, ``text``
            and ``confidences``.

    Returns:
        An HTML string containing the embedded image, one highlightable SVG
        path pair per line, and the list of transcribed lines.

    Notes:
        * The displayed confidence is ``line.confidence`` when that attribute
          exists, otherwise the mean of ``line.confidences`` (the attribute
          the recognition step in this file actually sets).
        * A missing or ``None`` text is rendered as an empty string.
        * The transcription is HTML-escaped because the template is rendered
          without autoescaping.
    """
    # Convert image to base64
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    image_base64 = base64.b64encode(buffered.getvalue()).decode()

    # Get image dimensions
    width, height = image.size

    # RTL detection (Hebrew, Arabic, etc.); compiled once, not once per line.
    rtl_chars = re.compile(r'[\u0591-\u07FF\u200F\u202B\u202E\uFB1D-\uFDFD\uFE70-\uFEFC\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF\uFB50-\uFDFF\uFE70-\uFEFF]')

    # Prepare lines data
    lines = []
    for line in baseline_seg.lines:
        # "x,y" strings, joined with ' L ' in the template to form SVG paths
        boundary_points = [f"{point[0]},{point[1]}" for point in (line.boundary or ())]
        baseline_points = [f"{point[0]},{point[1]}" for point in (line.baseline or ())]

        # `text` may be absent or None when a line has no recognition result
        text = getattr(line, 'text', None) or ''

        # Prefer an explicit per-line confidence; otherwise average the
        # per-character confidences attached during recognition.
        confidence = getattr(line, 'confidence', None)
        if confidence is None:
            char_confidences = getattr(line, 'confidences', None)
            if char_confidences:
                confidence = sum(char_confidences) / len(char_confidences)

        lines.append({
            'boundary': boundary_points,
            'baseline': baseline_points,
            'text': text,
            'confidence': confidence,
            'is_rtl': bool(rtl_chars.search(text)),
        })

    # Render template
    template = """
<div class="visualization-container">
  <div class="image-container">
    <svg width="{{ width }}" height="{{ height }}" viewBox="0 0 {{ width }} {{ height }}">
      <image href="data:image/png;base64,{{ image_base64 }}" width="{{ width }}" height="{{ height }}"/>
      {% for line in lines %}
      <a class="textline line{{loop.index}}" onmouseover="document.querySelectorAll('.line{{loop.index}}').forEach(element => {element.classList.add('highlighted')});" onmouseout="document.querySelectorAll('*').forEach(element => {element.classList.remove('highlighted')});">
        <path class="line-boundary" d="M {{ line.boundary|join(' L ') }} Z" fill="rgba(0, 128, 255, 0.2)" stroke="none"/>
        <path class="line-baseline" d="M {{ line.baseline|join(' L ') }}" stroke="red" stroke-width="1" fill="none"/>
      </a>
      {% endfor %}
    </svg>
  </div>
  <div class="transcription-container">
    {% for line in lines %}
    <div class="textline-container {% if line.is_rtl %}rtl{% else %}ltr{% endif %}">
      <span class="textline line{{loop.index}}" onmouseover="document.querySelectorAll('.line{{loop.index}}').forEach(element => {element.classList.add('highlighted')});" onmouseout="document.querySelectorAll('*').forEach(element => {element.classList.remove('highlighted')});">
        <span class="line-number">{{ loop.index }}:</span>
        <span class="line-text">{{ line.text|e }}</span>
        {% if line.confidence %}
        <span class="line-confidence">({{ "%.2f"|format(line.confidence) }})</span>
        {% endif %}
      </span>
    </div>
    {% endfor %}
  </div>
</div>
<style>
.visualization-container {
  display: flex;
  gap: 20px;
  max-height: 1000px;
}
.image-container {
  flex: 2;
  overflow: auto;
  border: 1px solid #ddd;
  border-radius: 4px;
}
.image-container svg {
  display: block;
  width: 100%;
  height: auto;
  max-width: 100%;
}
.transcription-container {
  flex: 1;
  overflow-y: auto;
  padding: 10px;
  border: 1px solid #ddd;
  border-radius: 4px;
}
/* Synchronize scrolling between containers */
.image-container, .transcription-container {
  scroll-behavior: smooth;
}
.image-container::-webkit-scrollbar, .transcription-container::-webkit-scrollbar {
  width: 8px;
}
.image-container::-webkit-scrollbar-track, .transcription-container::-webkit-scrollbar-track {
  background: #f1f1f1;
}
.image-container::-webkit-scrollbar-thumb, .transcription-container::-webkit-scrollbar-thumb {
  background: #888;
  border-radius: 4px;
}
.image-container::-webkit-scrollbar-thumb:hover, .transcription-container::-webkit-scrollbar-thumb:hover {
  background: #555;
}
.textline-container {
  padding: 5px;
  margin: 2px 0;
  border-radius: 4px;
}
.textline-container.rtl {
  direction: rtl;
  text-align: right;
}
.textline-container.ltr {
  direction: ltr;
  text-align: left;
}
.textline {
  cursor: pointer;
  display: inline-block;
  width: 100%;
}
.textline:hover,
.textline.highlighted {
  background-color: rgba(0, 128, 255, 0.1);
}
.textline:hover .line-boundary,
.textline.highlighted .line-boundary {
  fill: rgba(0, 255, 255, 0.3);
}
.textline:hover .line-baseline,
.textline.highlighted .line-baseline {
  stroke: yellow;
}
.line-number {
  color: #666;
  margin-right: 5px;
}
.line-text {
  unicode-bidi: bidi-override;
}
.line-confidence {
  color: #888;
  font-size: 0.9em;
  margin-left: 5px;
}
</style>
<script>
// Synchronize scrolling between containers
const imageContainer = document.querySelector('.image-container');
const textContainer = document.querySelector('.transcription-container');
function syncScroll(source, target) {
  const ratio = target.scrollHeight / source.scrollHeight;
  target.scrollTop = source.scrollTop * ratio;
}
imageContainer.addEventListener('scroll', () => syncScroll(imageContainer, textContainer));
textContainer.addEventListener('scroll', () => syncScroll(textContainer, imageContainer));
</script>
"""
    return Template(template).render(
        width=width,
        height=height,
        image_base64=image_base64,
        lines=lines
    )
def get_example_images():
    """Return the example image paths found beside this script.

    Looks in the ``examples`` directory next to this file and returns the
    ``*.jpg`` matches followed by the ``*.png`` matches, as strings. An empty
    list is returned when the directory does not exist.
    """
    gallery_dir = Path(__file__).parent / "examples"
    if gallery_dir.exists():
        # JPEG matches first, then PNG matches, flattened into one list
        matches = [*gallery_dir.glob("*.jpg"), *gallery_dir.glob("*.png")]
        return [str(match) for match in matches]
    return []
def process_and_visualize(image, seg_model_name, rec_model_name, progress=gr.Progress()):
    """Run the OCR pipeline for the UI, streaming status updates.

    This is a generator used as a Gradio event handler. Every yielded value
    is a 6-tuple matching the handler's outputs: status text, visualization
    HTML, PageXML text, and one update for each of the three tabs.

    Args:
        image: Uploaded image, or ``None`` when nothing was uploaded.
        seg_model_name: Key of the segmentation model in ``load_models()``.
        rec_model_name: Key of the recognition model in ``load_models()``.
        progress: Gradio progress tracker.

    Yields:
        Intermediate tuples with the tabs locked while work is in progress,
        then a final tuple carrying the results with the tabs unlocked. Any
        exception is logged and reported as a status message instead of
        being raised.
    """
    def tab_updates(enabled):
        # One fresh update per tab output (visualization, PageXML, upload).
        return tuple(gr.update(interactive=enabled) for _ in range(3))

    try:
        if image is None:
            yield "❌ Please upload an image first.", None, None, None, None, None
            return

        yield ("🔄 Starting processing...", None, None, *tab_updates(False))

        progress(0.1, desc="Loading models...")
        yield ("📦 Loading models...", None, None, *tab_updates(False))
        seg_models, rec_models = load_models()
        seg_model = seg_models[seg_model_name]
        rec_model = rec_models[rec_model_name]

        progress(0.3, desc="Running Segmentation...")
        yield ("✂️ Running segmentation...", None, None, *tab_updates(False))
        baseline_seg = blla.segment(image, model=seg_model)

        progress(0.6, desc="Running Recognition...")
        yield ("🔠 Running text recognition...", None, None, *tab_updates(False))
        pred_it = rpred.rpred(network=rec_model, im=image, bounds=baseline_seg, pad=16)
        records = list(pred_it)
        for line, rec_line in zip(baseline_seg.lines, records):
            line.prediction = rec_line.prediction
            line.text = rec_line.prediction
            line.confidences = rec_line.confidences

        progress(0.85, desc="Generating PageXML...")
        yield ("📝 Generating PageXML output...", None, None, *tab_updates(False))
        # NOTE(review): seg_rec_image() segments and recognizes the saved copy
        # again, so the whole pipeline runs twice per request. Serializing
        # `baseline_seg` directly would avoid that - left unchanged here
        # because it would alter the XML's image file name.
        with tempfile.TemporaryDirectory() as temp_dir:
            input_path = os.path.join(temp_dir, "temp.png")
            image.save(input_path)
            seg_rec_image(input_path, seg_model, rec_model, temp_dir)
            output_xml = os.path.join(temp_dir, "temp.xml")
            if os.path.exists(output_xml):
                # Read inside a context manager so the handle is closed
                # before the temporary directory is removed.
                with open(output_xml, 'r', encoding='utf-8') as xml_file:
                    xml_content = xml_file.read()
            else:
                xml_content = "⚠️ Error generating XML output."

        progress(1.0, desc="Rendering results...")
        yield ("✅ Done! Switch to visualization!", render_image(image, baseline_seg), xml_content, *tab_updates(True))
    except Exception as e:
        # Keep the traceback in the log; the UI only shows the message.
        logger.exception("Processing failed")
        yield (f"❌ Error: {str(e)}", None, None, *tab_updates(True))
def main():
    """Prepare directories and templates, build the Gradio UI and launch it.

    Returns early (after printing a hint) when no segmentation or no
    recognition model could be loaded. Otherwise the app is served on
    0.0.0.0:7860 and this call blocks inside ``demo.launch``.
    """
    # Create necessary directories and templates
    SEG_MODELS_DIR.mkdir(parents=True, exist_ok=True)
    REC_MODELS_DIR.mkdir(parents=True, exist_ok=True)
    ensure_template_exists()
    create_templates()

    # Load available models; report the directories that were really scanned
    seg_models, rec_models = load_models()
    if not seg_models:
        print(f"No segmentation models found in {SEG_MODELS_DIR}. Please add .mlmodel files.")
        return
    if not rec_models:
        print(f"No recognition models found in {REC_MODELS_DIR}. Please add .mlmodel files.")
        return

    seg_choices = list(seg_models)
    rec_choices = list(rec_models)

    # Create Gradio interface
    with gr.Blocks(title="Kraken OCR on Samaritan manuscripts") as demo:
        gr.Markdown("# Kraken OCR on Samaritan manuscripts")
        gr.Markdown("Upload an image and select models to process it.")

        # Each tab gets an explicit id: gr.Tabs(selected=...) further down
        # selects a tab by id, not by its label.
        with gr.Tabs() as tabs:
            with gr.Tab("Upload Image", id="Upload Image") as upload_tab:
                with gr.Row():
                    with gr.Column(scale=2):
                        image_input = gr.Image(type="pil", label="Input Image", height=400)
                        with gr.Row():
                            seg_model = gr.Dropdown(choices=seg_choices, label="Segmentation Model", value=seg_choices[0])
                            rec_model = gr.Dropdown(choices=rec_choices, label="Recognition Model", value=rec_choices[0])
                        process_btn = gr.Button("Process Image")
                        status_box = gr.Markdown("", visible=True)
                    with gr.Column(scale=1):
                        gr.Markdown("### Example Images")
                        examples = gr.Gallery(
                            get_example_images(),
                            show_label=False,
                            interactive=True,
                            allow_preview=False,
                            object_fit="cover",
                            columns=2,
                            height=400,
                            elem_classes="example-gallery"
                        )
            with gr.Tab("Visualization", id="Visualization", interactive=False) as vis_tab:
                visualization_output = gr.HTML(label="Visualization")
            with gr.Tab("PageXML", id="PageXML", interactive=False) as xml_tab:
                xml_output = gr.Textbox(label="PageXML", lines=20, max_lines=50, show_copy_button=True)

        # Add custom CSS for the gallery
        gr.HTML("""
<style>
.example-gallery {
  overflow-y: auto !important;
  max-height: 400px !important;
}
.example-gallery img {
  width: 100% !important;
  height: 150px !important;
  object-fit: cover !important;
  border-radius: 4px !important;
  cursor: pointer !important;
  transition: transform 0.2s !important;
}
.example-gallery img:hover {
  transform: scale(1.05) !important;
}
</style>
""")

        # Run the pipeline, then switch to the visualization tab
        process_btn.click(
            process_and_visualize,
            inputs=[image_input, seg_model, rec_model],
            outputs=[status_box, visualization_output, xml_output, vis_tab, xml_tab, upload_tab],
            show_progress=True
        ).then(
            lambda: gr.Tabs(selected="Visualization"),
            outputs=tabs
        )

        # Example image selection handler
        def select_example(evt: gr.SelectData):
            # Copy the clicked gallery item into the image input.
            # NOTE(review): relies on gallery items being dicts shaped like
            # {"image": {"path": ...}} - confirm for the pinned Gradio version.
            if not examples.value:
                return None
            selected = examples.value[evt.index]
            return selected["image"]["path"]

        examples.select(
            select_example,
            None,
            image_input
        )

    demo.launch(server_name="0.0.0.0", server_port=7860)
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()