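# OCR / app.py
# (Hugging Face Space page header captured with the source; kept as comments
# so the file remains valid Python.)
# rahul7star's picture
# Update app.py
# e9b8d71 verified
# raw
# history blame
# 2.89 kB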
import gradio as gr
from PIL import Image, ImageDraw
import requests
from io import BytesIO
import numpy as np
import json
import tempfile
import easyocr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# Recognition stage: the TrOCR processor and model are loaded from the same
# checkpoint, so the identifier is named once and reused.
_TROCR_CHECKPOINT = "microsoft/trocr-base-handwritten"
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)

# Detection stage: EasyOCR (English) is used for the bounding boxes.
reader = easyocr.Reader(["en"])
def load_image(image_file, image_url):
    """Return the input image as an RGB PIL image, or ``None`` if none was given.

    An uploaded image takes precedence over the URL.

    Args:
        image_file: Uploaded image (a PIL image, as produced by the
            ``gr.Image(type="pil")`` input) or ``None``.
        image_url: URL of an image to download, or ``None`` / empty string.

    Returns:
        An RGB ``PIL.Image.Image``, or ``None`` when neither input is usable.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    if image_file is not None:
        # Normalise the mode so both branches hand back the same kind of
        # image (the URL branch already converts to RGB). ``convert`` returns
        # a new image, so the caller's upload is never modified.
        return image_file.convert("RGB")

    url = (image_url or "").strip()
    if not url:
        # A blank or whitespace-only textbox means "no URL supplied".
        return None

    # A timeout keeps an unresponsive host from blocking the app forever.
    response = requests.get(url, timeout=10)
    # Fail on 4xx/5xx here rather than feeding an error page to PIL.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
def _bbox_extent(bbox):
    """Return the axis-aligned (x_min, y_min, x_max, y_max) floats enclosing *bbox*."""
    # Plain floats keep the values JSON-serialisable.
    x_coords = [float(point[0]) for point in bbox]
    y_coords = [float(point[1]) for point in bbox]
    return min(x_coords), min(y_coords), max(x_coords), max(y_coords)


def _recognize_crop(word_crop):
    """Run the TrOCR processor/model on one cropped image and return the decoded text."""
    pixel_values = processor(images=word_crop, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def detect_text_trocr_json(image_file, image_url):
    """Detect text regions with EasyOCR, recognise each with TrOCR, and report them.

    Args:
        image_file: Uploaded PIL image or ``None``.
        image_url: Image URL or ``None`` / empty string; used only when no
            file was uploaded.

    Returns:
        A 3-tuple ``(annotated_image, json_text, json_path)``:

        * ``annotated_image`` - copy of the input with a red rectangle around
          every recognised region.
        * ``json_text`` - JSON document with ``"words"`` and ``"paragraphs"``
          lists; each entry has ``"text"``, ``"bbox"``
          (``[x_min, y_min, x_max, y_max]``) and ``"confidence"``.
        * ``json_path`` - path of a temporary ``.json`` file holding the same
          document, for download.

        When no image is supplied the result is
        ``(None, "No image provided.", None)``.
    """
    image = load_image(image_file, image_url)
    if image is None:
        return None, "No image provided.", None
    if image.mode != "RGB":
        # Keep detection, recognition and drawing on one consistent mode
        # regardless of what the loader returned.
        image = image.convert("RGB")

    results = reader.readtext(np.array(image))

    # Draw on a copy: crops must come from the untouched pixels, otherwise
    # outlines drawn for earlier boxes bleed into later, overlapping crops.
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    width, height = image.size

    words_json = []
    # EasyOCR's own text (second tuple item) is ignored; TrOCR does recognition.
    for bbox, _, conf in results:
        x_min, y_min, x_max, y_max = _bbox_extent(bbox)

        # Clamp the crop window to the image so the recogniser never sees
        # padding from outside the picture.
        left = max(0, int(round(x_min)))
        top = max(0, int(round(y_min)))
        right = min(width, int(round(x_max)))
        bottom = min(height, int(round(y_max)))
        if right <= left or bottom <= top:
            # Zero-area region: nothing to recognise, so skip it.
            continue

        text = _recognize_crop(image.crop((left, top, right, bottom)))
        draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=2)
        words_json.append({
            "text": text,
            "bbox": [x_min, y_min, x_max, y_max],
            "confidence": float(conf),
        })

    # Treat words as paragraphs for simplicity.
    output_json = {
        "words": words_json,
        "paragraphs": list(words_json),
    }
    json_str = json.dumps(output_json, indent=2)

    # delete=False: the file has to outlive this call so it can be served
    # as a download.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".json", mode="w", encoding="utf-8"
    ) as tmp_file:
        tmp_file.write(json_str)

    return annotated, json_str, tmp_file.name
# Input widgets, in the positional order of detect_text_trocr_json's parameters.
_input_components = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Textbox(label="Image URL (optional)"),
]

# Output widgets, matching the function's 3-tuple return value.
_output_components = [
    gr.Image(type="pil", label="Annotated Image"),
    gr.Textbox(label="Text & Bounding Boxes (JSON)"),
    gr.File(label="Download JSON"),
]

# Gradio UI wiring the OCR function to the widgets above.
iface = gr.Interface(
    fn=detect_text_trocr_json,
    inputs=_input_components,
    outputs=_output_components,
    title="Handwritten OCR with TrOCR + Bounding Boxes",
    description="Detect handwritten text and bounding boxes. Uses TrOCR for recognition and EasyOCR for detection.",
)
# Start the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch()