File size: 2,759 Bytes
0ce48e6
 
cbe0e5f
aa9e906
cbe0e5f
 
0ce48e6
aa9e906
cbe0e5f
 
0ce48e6
 
cbe0e5f
 
 
 
 
 
 
 
 
 
0ce48e6
cbe0e5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ce48e6
 
edbb0bf
cbe0e5f
 
 
 
 
0ce48e6
cbe0e5f
 
 
 
 
 
 
0ce48e6
cbe0e5f
 
 
 
 
 
 
 
aa9e906
cbe0e5f
 
 
 
edbb0bf
aa9e906
cbe0e5f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image, ImageOps
import numpy as np
import io
import base64

# Load the captcha OCR model and the TrOCR processor once, at import time;
# the request handlers below use these module-level objects.
name = "chanelcolgate/trocr-base-printed_captcha_ocr"
model = VisionEncoderDecoderModel.from_pretrained(name)
# NOTE(review): the processor is loaded from "microsoft/trocr-base-printed",
# not from `name` — presumably the captcha checkpoint keeps the base model's
# preprocessing and tokenizer; confirm.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")


def prepare_image(pil_image):
    """Return an RGB version of *pil_image* with transparency flattened onto white.

    Args:
        pil_image: A PIL image in any mode.

    Returns:
        An RGB image. Images with an alpha band ("RGBA", "LA") and palette
        images that declare a transparent colour are composited onto a white
        background; every other mode is simply converted to RGB.
    """
    # Palette images keep their transparency in ``info["transparency"]``
    # instead of an alpha band, and a direct convert("RGB") drops it. Promote
    # them to RGBA first so they take the white-background path below.
    if pil_image.mode == "P" and "transparency" in pil_image.info:
        pil_image = pil_image.convert("RGBA")

    if pil_image.mode in ("RGBA", "LA"):
        background = Image.new("RGB", pil_image.size, (255, 255, 255))
        # The last band of RGBA/LA is alpha; using it as the paste mask lets
        # white show through wherever the source is transparent.
        background.paste(pil_image, mask=pil_image.split()[-1])
        return background
    return pil_image.convert("RGB")


def process_image(image):
    """Run captcha OCR on an image supplied by the upload widget.

    Args:
        image: Pixel data as a NumPy array (the upload component is created
            with ``type="numpy"``), or ``None`` when no image was provided.

    Returns:
        A ``(image_clean, generated_text)`` tuple: the RGB image that was
        actually fed to the model and the decoded text. Returns
        ``(None, "")`` when ``image`` is ``None``.
    """
    # Clicking Submit before choosing a file delivers None; Image.fromarray
    # would fail on it, so clear both outputs instead of raising.
    if image is None:
        return None, ""

    pil_image = Image.fromarray(image)
    image_clean = prepare_image(pil_image)

    pixel_values = processor(image_clean, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    # A single image goes in, so the decoded batch has exactly one entry.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return image_clean, generated_text


def process_base64(base64_str):
    """Run captcha OCR on a base64-encoded image.

    Args:
        base64_str: Base64 text, optionally preceded by a data-URL prefix
            such as ``data:image/png;base64,``. May be empty or ``None``.

    Returns:
        A ``(image_clean, generated_text)`` tuple: the RGB image that was
        fed to the model and the decoded text. Returns ``(None, "")`` when
        there is no payload to decode.

    Raises:
        binascii.Error: If the payload is not valid base64.
        PIL.UnidentifiedImageError: If the decoded bytes are not an image.
    """
    if not base64_str:
        return None, ""

    # Drop the data-URL prefix (data:image/png;base64,...) if present.
    if "," in base64_str:
        base64_str = base64_str.split(",", 1)[1]

    # Remove whitespace/newlines introduced by pasting so the padding
    # computation below sees only the payload characters.
    payload = "".join(base64_str.split())
    if not payload:
        return None, ""

    # Base64 copied from URLs or JavaScript often lacks trailing "=";
    # restore it so b64decode does not fail with "Incorrect padding".
    payload += "=" * (-len(payload) % 4)

    image_data = base64.b64decode(payload)
    image = Image.open(io.BytesIO(image_data))
    image_clean = prepare_image(image)

    pixel_values = processor(image_clean, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    # A single image goes in, so the decoded batch has exactly one entry.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return image_clean, generated_text


# Build the Gradio UI: two tabs (image upload and base64 paste), each wired to
# its own handler, plus a list of example captcha images.
with gr.Blocks() as demo:
    gr.Markdown("## Captcha OCR Demo")

    # Tab 1: OCR an uploaded image. The input is delivered to process_image
    # as a NumPy array (type="numpy").
    with gr.Tab("Upload image"):
        with gr.Row():
            image_input = gr.Image(type="numpy", label="Upload Image")
            image_output = gr.Image(type="pil", label="Processed Image")
        text_output = gr.Textbox(label="OCR Output")
        image_button = gr.Button("Submit")
        # process_image returns (image, text), matching the two outputs.
        image_button.click(fn=process_image, inputs=image_input, outputs=[image_output, text_output])

    # Tab 2: OCR an image pasted as base64 text (with or without a data-URL
    # prefix, as the placeholder suggests).
    with gr.Tab("Paste base64"):
        with gr.Row():
            base64_input = gr.Textbox(label="Paste base64 here", lines=5, placeholder="data:image/png;base64,...")
        with gr.Row():
            base64_output_img = gr.Image(type="pil", label="Processed Image")
            base64_output_txt = gr.Textbox(label="OCR Output")
        base64_button = gr.Button("Submit")
        # process_base64 returns (image, text), matching the two outputs.
        base64_button.click(fn=process_base64, inputs=base64_input, outputs=[base64_output_img, base64_output_txt])

    # Example inputs for the upload widget: examples/captcha-0.png through
    # examples/captcha-9.png, resolved relative to the working directory.
    gr.Examples(
        examples=[f"examples/captcha-{i}.png" for i in range(10)],
        inputs=image_input
    )

# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()