Spaces:
Running
on
T4
Running
on
T4
File size: 3,087 Bytes
b9dea2c ee94965 49f26e0 5e6a772 ee94965 b9dea2c 82673a2 ee94965 b9dea2c 48bbaa9 ee94965 48bbaa9 b9dea2c 48bbaa9 82673a2 8ef414c ee94965 48bbaa9 ee94965 b9dea2c b3bb9ad ee94965 b3bb9ad 82673a2 48bbaa9 82673a2 b3bb9ad 82673a2 b3bb9ad 48bbaa9 c8dea56 b3bb9ad 82673a2 b3bb9ad b9dea2c 48bbaa9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import gradio as gr
import json
from PIL import Image
from surya.ocr import run_ocr
from surya.detection import batch_detection
from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.postprocessing.heatmap import draw_polys_on_image
# Load the detection and recognition models together with their processors
# once, at import time, so every request handled by the app reuses them.
det_model, det_processor = load_det_model(), load_det_processor()
rec_model, rec_processor = load_rec_model(), load_rec_processor()

# Build the mapping from language names to language codes.
# The encoding is passed explicitly: without it open() falls back to the
# platform default, which is not guaranteed to be UTF-8 and can garble or
# reject non-ASCII language names.
with open("languages.json", "r", encoding="utf-8") as file:
    languages = json.load(file)
language_dict = dict(languages.items())

# The dropdown lists the language names; the code is looked up on submit.
language_options = list(language_dict)
def ocr_function(img, lang_name):
    """Run OCR on one image and draw the resulting polygons onto it.

    Args:
        img: Image to process, or ``None`` when no image was supplied.
        lang_name: Language name; it is translated to a language code
            through ``language_dict``.

    Returns:
        A ``(image, result)`` pair. On success ``image`` is whatever
        ``draw_polys_on_image`` returns for the first prediction's
        ``"polys"`` and ``result`` is that first prediction. Otherwise
        ``image`` is the input unchanged and ``result`` is a dict holding
        an ``"error"`` message.

    Raises:
        KeyError: If ``lang_name`` is not a key of ``language_dict``.
    """
    # Without an image there is nothing to recognise; report it in the same
    # shape as the "no text" result instead of passing None to run_ocr.
    if img is None:
        return img, {"error": "No image provided"}

    lang_code = language_dict[lang_name]
    predictions = run_ocr(
        [img], [lang_code], det_model, det_processor, rec_model, rec_processor
    )

    if not predictions:
        return img, {"error": "No text detected"}

    # A single image was submitted, so only the first prediction is used.
    img_with_text = draw_polys_on_image(predictions[0]["polys"], img)
    return img_with_text, predictions[0]
def text_line_detection_function(img):
    """Detect text lines in one image and draw their polygons onto it.

    Args:
        img: Image to process, or ``None`` when no image was supplied.

    Returns:
        A ``(image, result)`` pair. On success ``image`` is whatever
        ``draw_polys_on_image`` returns for the detected ``"polygons"`` and
        ``result`` is the first element returned by ``batch_detection``.
        When ``img`` is ``None`` the pair is ``(None, {"error": ...})``,
        the same error shape ``ocr_function`` uses for its failure result.
    """
    # Guard against a missing image rather than passing None to the detector.
    if img is None:
        return img, {"error": "No image provided"}

    # Only one image is submitted, so take the single result of the batch.
    preds = batch_detection([img], det_model, det_processor)[0]
    img_with_lines = draw_polys_on_image(preds["polygons"], img)
    return img_with_lines, preds
# Build the two-tab UI: one tab for full OCR, one for text line detection.
with gr.Blocks() as app:
    gr.Markdown("# Surya OCR e Detecção de Linhas de Texto")

    with gr.Tab("OCR"):
        with gr.Column():
            ocr_input_image = gr.Image(label="Input Image for OCR", type="pil")
            ocr_language_selector = gr.Dropdown(
                label="Select Language for OCR",
                choices=language_options,
                value="English",
            )
            ocr_run_button = gr.Button("Run OCR")
        with gr.Column():
            ocr_output_image = gr.Image(
                label="OCR Output Image", type="pil", interactive=False
            )
            ocr_text_output = gr.TextArea(label="Recognized Text")

        # Event inputs must be components, not their values: Gradio reads the
        # current value of each component when the button is clicked. Passing
        # `ocr_language_selector.value` would hand over the build-time default
        # string instead of the dropdown, so the user's choice would be lost.
        ocr_run_button.click(
            fn=ocr_function,
            inputs=[ocr_input_image, ocr_language_selector],
            outputs=[ocr_output_image, ocr_text_output],
        )

    with gr.Tab("Detecção de Linhas de Texto"):
        with gr.Column():
            detection_input_image = gr.Image(
                label="Imagem de Entrada para Detecção", type="pil"
            )
            detection_run_button = gr.Button("Executar Detecção de Linhas de Texto")
        with gr.Column():
            detection_output_image = gr.Image(
                label="Imagem de Saída da Detecção", type="pil", interactive=False
            )
            detection_json_output = gr.JSON(label="Saída JSON da Detecção")

        # The detection callback takes only the image and fills both outputs.
        detection_run_button.click(
            fn=text_line_detection_function,
            inputs=detection_input_image,
            outputs=[detection_output_image, detection_json_output],
        )
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()