File size: 2,966 Bytes
a9d81c5 698149b a9d81c5 738bd96 f6be418 738bd96 a9d81c5 738bd96 a9d81c5 f6be418 a9d81c5 738bd96 f6be418 a9d81c5 738bd96 f6be418 a9d81c5 738bd96 f6be418 738bd96 a9d81c5 738bd96 f6be418 738bd96 a9d81c5 738bd96 f6be418 738bd96 a9d81c5 738bd96 e06502f 738bd96 f6be418 738bd96 a9d81c5 738bd96 f6be418 738bd96 7ea1aa1 738bd96 04f8aab 738bd96 04f8aab 738bd96 f6be418 738bd96 a9d81c5 738bd96 f6be418 738bd96 e06502f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import os
from functools import partial
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from die_model import UNetDIEModel
from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, remove_square_padding
def die_inference(image_raw, num_of_die_iterations, die_model, device):
    """
    Applies the DIE model for document enhancement on a provided image.

    Args:
        image_raw: The input document image, or ``None`` when the caller has
            no image to process (e.g. the UI button was pressed before an
            upload). Presumably a PIL image, since it is handed to
            ``cast_pil_image_to_torch_tensor_with_4_channel_dim`` after
            resizing -- confirm against ``utils``.
        num_of_die_iterations: Number of enhancement iterations. Any value
            accepted by ``int()`` (an int or a numeric string).
        die_model: Object exposing
            ``enhance_document_image(image_raw_list=..., num_of_die_iterations=...)``
            which returns an indexable result with one entry per input.
        device: Device specifier passed to ``.to(...)`` on the input tensor.

    Returns:
        Whatever ``remove_square_padding`` returns for the enhanced image
        (called with ``resize_back_to_original=True``), or ``None`` when
        ``image_raw`` is ``None``.

    Raises:
        ValueError / TypeError: If ``num_of_die_iterations`` cannot be
            converted with ``int()``.
    """
    # Nothing to enhance: return early instead of failing inside preprocessing.
    if image_raw is None:
        return None

    # convert string to int (done first so a bad value fails before the
    # more expensive preprocessing work)
    num_of_die_iterations = int(num_of_die_iterations)

    # preprocess
    # NOTE(review): 1500 is presumably the target size in pixels used by
    # resize_image -- confirm against utils.resize_image.
    image_raw_resized = resize_image(image_raw, 1500)
    image_raw_resized_square = make_image_square(image_raw_resized)
    image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(
        image_raw_resized_square
    ).to(device)

    # inference: the model takes a list of inputs; we pass one image and
    # take the single corresponding result.
    image_die = die_model.enhance_document_image(
        image_raw_list=[image_raw_resized_square_tensor],
        num_of_die_iterations=num_of_die_iterations,
    )[0]

    # postprocess: undo the square padding, using the untouched input image
    # as the reference for the original geometry.
    return remove_square_padding(
        original_image=image_raw,
        square_image=image_die,
        resize_back_to_original=True,
    )
# Introductory text rendered as Markdown next to the QR code.
# NOTE(review): the contact address in this text looks like an
# e-mail-obfuscation placeholder -- confirm the real address before publishing.
description = """
Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!
This application showcases a specialized AI model by the Artificial Intelligence group at the Alfréd Rényi Institute of Mathematics, aimed at enhancing and restoring archival document images. This model removes domain-specific noise, preserving clarity and improving OCR accuracy, particularly for aged and historical documents.
Contact: [email protected]
"""

# Location of the QR code image shown beside the description.
# NOTE(review): this looks like a placeholder path -- point it at the real asset.
QR_CODE_PATH = "path/to/qr-code.png"

with gr.Blocks() as demo:
    # Title row
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Document Image Enhancement (DIE) Model")

    # Description on the left, QR code on the right
    with gr.Row():
        with gr.Column():
            gr.Markdown(description)
        with gr.Column():
            # Displaying the QR code directly as an image in Gradio.
            # The image is purely decorative, so a missing file must not
            # prevent the whole demo from starting.
            if os.path.isfile(QR_CODE_PATH):
                gr.Image(value=Image.open(QR_CODE_PATH), label="QR Code")

    # Input controls on the left, enhanced result on the right
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Degraded Document Image")
            num_iterations = gr.Dropdown([1, 2, 3], label="Number of DIE Iterations", value=1)
            run_button = gr.Button("Enhance Image")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Enhanced Document Image")

    # Load model
    # The access token is read from the environment; when DIE_TOKEN is unset
    # this is None and is passed through to hf_hub_download unchanged.
    die_token = os.getenv("DIE_TOKEN")
    model_path = hf_hub_download(
        repo_id="gabar92/die",
        filename="2024_08_09_model_epoch_89.pt",
        # `token` replaces the deprecated `use_auth_token` keyword.
        token=die_token,
    )
    die_model = UNetDIEModel(args=model_path)
    device = "cpu"  # or "cuda" based on your setup

    # Partial function for inference: bind the model and device once so the
    # click handler only receives the two UI inputs.
    partial_die_inference = partial(die_inference, die_model=die_model, device=device)

    # Define button behavior
    run_button.click(partial_die_inference, [input_image, num_iterations], output_image)

demo.launch()
|