gabar92 committed on
Commit
c063bb0
·
1 Parent(s): 7c94ac5
Files changed (1) hide show
  1. app.py +68 -35
app.py CHANGED
@@ -1,11 +1,15 @@
 
1
  import os
2
  from functools import partial
 
3
  import gradio as gr
4
  from PIL import Image
5
  from huggingface_hub import hf_hub_download
 
6
  from die_model import UNetDIEModel
7
  from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, remove_square_padding
8
 
 
9
  def die_inference(image_raw, num_of_die_iterations, die_model, device):
10
  """
11
  Applies the DIE model for document enhancement on a provided image.
@@ -31,51 +35,80 @@ def die_inference(image_raw, num_of_die_iterations, die_model, device):
31
  resize_back_to_original=True
32
  )
33
 
34
- description = """
35
- Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!
36
-
37
- This application showcases a specialized AI model by the Artificial Intelligence group at the Alfréd Rényi Institute of Mathematics, aimed at enhancing and restoring archival document images. This model removes domain-specific noise, preserving clarity and improving OCR accuracy, particularly for aged and historical documents.
38
-
39
- Contact: [email protected]
40
-
41
- """
42
-
43
- with gr.Blocks() as demo:
44
- with gr.Row():
45
- with gr.Column():
46
- gr.Markdown("## Document Image Enhancement (DIE) Model")
47
 
48
- with gr.Row():
49
- with gr.Column():
50
- gr.Markdown(description)
51
- with gr.Column():
52
- # Displaying the QR code directly as an image in Gradio
53
- gr.Image(value=Image.open("logo/qr-code.png"), label="QR Code")
54
-
55
- with gr.Row():
56
- with gr.Column():
57
- input_image = gr.Image(type="pil", label="Upload Degraded Document Image")
58
- num_iterations = gr.Dropdown([1, 2, 3], label="Number of DIE Iterations", value=1)
59
- run_button = gr.Button("Enhance Image")
60
-
61
- with gr.Column():
62
- output_image = gr.Image(type="pil", label="Enhanced Document Image")
63
 
64
- # Load model
65
  die_token = os.getenv("DIE_TOKEN")
66
  model_path = hf_hub_download(
67
  repo_id="gabar92/die",
68
- filename="2024_08_09_model_epoch_89.pt",
69
  use_auth_token=die_token
70
  )
71
  die_model = UNetDIEModel(args=model_path)
72
- device = "cpu" # or "cuda" based on your setup
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- # Partial function for inference
 
 
 
 
75
  partial_die_inference = partial(die_inference, die_model=die_model, device=device)
76
 
77
- # Define button behavior
78
- run_button.click(partial_die_inference, [input_image, num_iterations], output_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- demo.launch()
 
81
 
 
1
+ import argparse
2
  import os
3
  from functools import partial
4
+
5
  import gradio as gr
6
  from PIL import Image
7
  from huggingface_hub import hf_hub_download
8
+
9
  from die_model import UNetDIEModel
10
  from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, remove_square_padding
11
 
12
+
13
  def die_inference(image_raw, num_of_die_iterations, die_model, device):
14
  """
15
  Applies the DIE model for document enhancement on a provided image.
 
35
  resize_back_to_original=True
36
  )
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
def main():
    """
    Main function to set up and run the Gradio demo.

    Parses the command-line arguments, downloads the DIE checkpoint from the
    Hugging Face Hub, builds the Gradio Blocks interface and launches it.
    """
    args = parse_arguments()

    # Set up model
    # os.getenv returns None when DIE_TOKEN is not set in the environment.
    die_token = os.getenv("DIE_TOKEN")
    model_path = hf_hub_download(
        repo_id="gabar92/die",
        filename=args.die_model_path,
        use_auth_token=die_token
    )
    die_model = UNetDIEModel(args=model_path)
    device = args.device

    # Prepare example images.
    # Only file paths are collected (one single-element row per example, as
    # gr.Examples expects one value per input component). Opening every entry
    # with PIL up front kept file handles open and failed on non-image files
    # or a missing directory.
    image_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")
    example_image_list = []
    if os.path.isdir(args.example_image_path):
        for file_name in sorted(os.listdir(args.example_image_path)):
            file_path = os.path.join(args.example_image_path, file_name)
            if os.path.isfile(file_path) and file_name.lower().endswith(image_suffixes):
                example_image_list.append([file_path])

    description = """
    Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!

    This application showcases a specialized AI model by the Artificial Intelligence group at the Alfréd Rényi Institute of Mathematics, aimed at enhancing and restoring archival document images. This model removes domain-specific noise, preserving clarity and improving OCR accuracy, particularly for aged and historical documents.

    Contact: [email protected]

    """

    # Partial function for inference with model and device arguments, so the
    # click handler only receives the two UI inputs.
    partial_die_inference = partial(die_inference, die_model=die_model, device=device)

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Document Image Enhancement (DIE) Model")

        with gr.Row():
            with gr.Column():
                gr.Markdown(description)
            with gr.Column():
                # Display QR code as an image in Gradio.
                # "logo/qr-code.png" is the asset path used by the previous
                # revision; "path/to/qr-code.png" was a placeholder.
                gr.Image(value=Image.open("logo/qr-code.png"), label="QR Code")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Upload Degraded Document Image")
                num_iterations = gr.Dropdown([1, 2, 3], label="Number of DIE Iterations", value=1)
                run_button = gr.Button("Enhance Image")

            with gr.Column():
                output_image = gr.Image(type="pil", label="Enhanced Document Image")

        # Offer the collected example images as clickable inputs; skipped
        # when no example image was found.
        if example_image_list:
            with gr.Row():
                gr.Examples(examples=example_image_list, inputs=[input_image])

        # Button trigger for inference
        run_button.click(partial_die_inference, [input_image, num_iterations], output_image)

    demo.launch()
98
+
99
+
100
def parse_arguments(argv=None):
    """
    Parses command-line arguments.

    :param argv: optional list of argument strings to parse; when None (the
        default) argparse reads the process command line, as before
    :return: argument namespace with the attributes ``die_model_path``,
        ``device`` and ``example_image_path``
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--die_model_path", default="2024_08_09_model_epoch_89.pt", help="Path to the DIE model checkpoint")
    parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"], help="Device to run the model on")
    parser.add_argument("--example_image_path", default="example_images", help="Path to directory with example images")
    # Passing argv through lets callers and tests supply arguments explicitly.
    return parser.parse_args(argv)
110
+
111
 
112
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
114