die_demo / app.py
gabar92's picture
refactor
738bd96
raw
history blame
2.97 kB
import os
from functools import partial
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from die_model import UNetDIEModel
from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, remove_square_padding
def die_inference(image_raw, num_of_die_iterations, die_model, device):
    """
    Applies the DIE model for document enhancement on a provided image.

    Args:
        image_raw: The image to enhance. The Gradio input wired to this
            function in this file is declared with ``type="pil"``. May be
            ``None`` when the handler fires without an uploaded image.
        num_of_die_iterations: Number of DIE iterations to run. Any value
            accepted by ``int()`` works (the UI may deliver it as a string).
        die_model: Object exposing
            ``enhance_document_image(image_raw_list=..., num_of_die_iterations=...)``.
        device: Device the preprocessed tensor is moved to via ``.to(device)``.

    Returns:
        The value returned by ``remove_square_padding`` for the enhanced
        image, or ``None`` when ``image_raw`` is ``None``.
    """
    # Nothing was uploaded: skip the pipeline instead of failing inside the
    # preprocessing helpers on a None input.
    if image_raw is None:
        return None

    # preprocess
    # NOTE(review): 1500 is presumably the target size in pixels for
    # resize_image -- confirm against utils.resize_image.
    image_raw_resized = resize_image(image_raw, 1500)
    image_raw_resized_square = make_image_square(image_raw_resized)
    image_raw_resized_square_tensor = (
        cast_pil_image_to_torch_tensor_with_4_channel_dim(image_raw_resized_square).to(device)
    )

    # convert string to int
    num_of_die_iterations = int(num_of_die_iterations)

    # inference: the model takes a list of images; we pass one and keep the
    # first (only) result.
    image_die = die_model.enhance_document_image(
        image_raw_list=[image_raw_resized_square_tensor],
        num_of_die_iterations=num_of_die_iterations,
    )[0]

    # postprocess: the original image is handed over so the helper can undo
    # the squaring and, with resize_back_to_original=True, the resize.
    return remove_square_padding(
        original_image=image_raw,
        square_image=image_die,
        resize_back_to_original=True,
    )
# Introductory text rendered as Markdown next to the QR code.
# NOTE(review): the "Contact:" line below looks like an e-mail obfuscation
# placeholder rather than a real address -- confirm against the deployed app.
description = """
Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!
This application showcases a specialized AI model by the Artificial Intelligence group at the Alfréd Rényi Institute of Mathematics, aimed at enhancing and restoring archival document images. This model removes domain-specific noise, preserving clarity and improving OCR accuracy, particularly for aged and historical documents.
Contact: [email protected]
"""

with gr.Blocks() as demo:
    # Row 1: page title.
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Document Image Enhancement (DIE) Model")

    # Row 2: description on the left, QR code on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown(description)
        with gr.Column():
            # Displaying the QR code directly as an image in Gradio
            # NOTE(review): "path/to/qr-code.png" looks like a placeholder;
            # Image.open raises FileNotFoundError at startup if the file is
            # missing -- confirm the real asset path.
            gr.Image(value=Image.open("path/to/qr-code.png"), label="QR Code")

    # Row 3: inputs on the left, enhanced result on the right.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Degraded Document Image")
            num_iterations = gr.Dropdown([1, 2, 3], label="Number of DIE Iterations", value=1)
            run_button = gr.Button("Enhance Image")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Enhanced Document Image")

    # Load model
    # The access token is read from the DIE_TOKEN environment variable;
    # os.getenv returns None when it is unset.
    die_token = os.getenv("DIE_TOKEN")
    model_path = hf_hub_download(
        repo_id="gabar92/die",
        filename="2024_08_09_model_epoch_89.pt",
        # `token` is the documented parameter; `use_auth_token` is deprecated.
        token=die_token,
    )
    die_model = UNetDIEModel(args=model_path)
    device = "cpu"  # or "cuda" based on your setup

    # Partial function for inference: bind the model and device so the click
    # handler only receives the two UI inputs.
    partial_die_inference = partial(die_inference, die_model=die_model, device=device)

    # Define button behavior (must be registered inside the Blocks context).
    run_button.click(partial_die_inference, [input_image, num_iterations], output_image)

demo.launch()