Spaces:
Sleeping
Sleeping
import gradio as gr | |
from PIL import Image, ImageDraw | |
import requests | |
from io import BytesIO | |
import numpy as np | |
import json | |
import tempfile | |
import easyocr | |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
# One TrOCR checkpoint is shared by the processor (image preprocessing and
# token decoding) and the encoder-decoder model that generates token ids.
TROCR_CHECKPOINT = "microsoft/trocr-base-handwritten"
processor = TrOCRProcessor.from_pretrained(TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(TROCR_CHECKPOINT)

# EasyOCR (English) supplies the word bounding boxes; the text it also
# returns is ignored by the app in favour of TrOCR's recognition.
reader = easyocr.Reader(["en"])
def load_image(image_file, image_url, timeout=10):
    """Return the requested image as a new RGB PIL image, or None.

    An uploaded image takes precedence over the URL.

    Args:
        image_file: PIL image from the upload widget, or None.
        image_url: Optional image URL string; surrounding whitespace is
            ignored, and an empty/blank value counts as "not provided".
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A new RGB ``PIL.Image.Image``, or None when neither input is set.

    Raises:
        requests.HTTPError: if the URL answers with an error status.
        requests.RequestException: on network failure or timeout.
    """
    if image_file is not None:
        # convert() returns a copy, so later drawing never mutates the
        # uploaded image, and RGBA/L/P uploads end up in the same RGB mode
        # the URL branch produces.
        return image_file.convert("RGB")
    url = (image_url or "").strip()
    if not url:
        return None
    response = requests.get(url, timeout=timeout)
    # Surface HTTP failures clearly instead of letting PIL try to decode
    # an HTML error page.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
def _recognize_text(crop):
    """Return the text TrOCR reads from a single PIL image crop."""
    pixel_values = processor(images=crop, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    # One crop in, so take the single decoded string.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _write_json_tempfile(json_str):
    """Write ``json_str`` to a new ``.json`` temp file and return its path.

    ``delete=False`` keeps the file on disk after this function returns so
    the download component can read it; nothing here removes it later.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False, encoding="utf-8"
    ) as tmp_file:
        tmp_file.write(json_str)
    return tmp_file.name


def detect_text_trocr_json(image_file, image_url):
    """Detect word boxes with EasyOCR, read each box with TrOCR, export JSON.

    Args:
        image_file: Uploaded PIL image, or None.
        image_url: Optional image URL used when no file is uploaded.

    Returns:
        A 3-tuple ``(annotated_image, json_str, json_path)``: an RGB copy of
        the input with a red rectangle around every recognised word, the
        JSON text ``{"words": [...], "paragraphs": [...]}``, and the path of
        a temp file holding the same JSON. When no image is supplied,
        returns ``(None, "No image provided.", None)``.
    """
    image = load_image(image_file, image_url)
    if image is None:
        return None, "No image provided.", None

    # Keep two copies: `source` is only read (detection and word crops),
    # while `annotated` receives the red boxes. Drawing on the image that is
    # cropped later would leak earlier box outlines into overlapping or
    # adjacent word crops and corrupt recognition. convert() returns a
    # copy, so the caller's image is never mutated either.
    source = image.convert("RGB")
    annotated = source.copy()
    draw = ImageDraw.Draw(annotated)

    words_json = []
    for bbox, _, conf in reader.readtext(np.array(source)):
        # Reduce the box's (x, y) points to an axis-aligned rectangle;
        # float() makes the coordinates JSON-serializable.
        x_coords = [float(point[0]) for point in bbox]
        y_coords = [float(point[1]) for point in bbox]
        x_min, y_min = min(x_coords), min(y_coords)
        x_max, y_max = max(x_coords), max(y_coords)
        if x_max - x_min < 1 or y_max - y_min < 1:
            # A (near-)zero-area box yields an empty crop, which the image
            # processor presumably cannot resize; skip it instead of
            # failing the whole request.
            continue
        text = _recognize_text(source.crop((x_min, y_min, x_max, y_max)))
        draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=2)
        words_json.append({
            "text": text,
            "bbox": [x_min, y_min, x_max, y_max],
            "confidence": float(conf),
        })

    # Words double as paragraphs for simplicity (no grouping is performed).
    paragraphs_json = words_json.copy()
    output_json = {
        "words": words_json,
        "paragraphs": paragraphs_json,
    }
    json_str = json.dumps(output_json, indent=2)
    return annotated, json_str, _write_json_tempfile(json_str)
# Gradio form: an upload widget plus an optional URL field in; the annotated
# image, the JSON text and a downloadable copy of the JSON out. The output
# components are listed in the same order as the 3-tuple returned by
# detect_text_trocr_json.
iface = gr.Interface(
    detect_text_trocr_json,
    [
        gr.Image(label="Upload Image", type="pil"),
        gr.Textbox(label="Image URL (optional)"),
    ],
    [
        gr.Image(label="Annotated Image", type="pil"),
        gr.Textbox(label="Text & Bounding Boxes (JSON)"),
        gr.File(label="Download JSON"),
    ],
    description="Detect handwritten text and bounding boxes. Uses TrOCR for recognition and EasyOCR for detection.",
    title="Handwritten OCR with TrOCR + Bounding Boxes",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    iface.launch()