gabar92 commited on
Commit
738bd96
·
1 Parent(s): a643092
Files changed (1) hide show
  1. app.py +40 -107
app.py CHANGED
@@ -1,148 +1,81 @@
1
- """
2
- Small demo application to explore Gradio.
3
- """
4
-
5
- import argparse
6
  import os
7
  from functools import partial
8
-
9
  import gradio as gr
10
  from PIL import Image
11
  from huggingface_hub import hf_hub_download
12
-
13
  from die_model import UNetDIEModel
14
- from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, \
15
- remove_square_padding
16
-
17
 
18
- def die_inference(
19
- image_raw,
20
- num_of_die_iterations,
21
- die_model,
22
- device
23
- ):
24
  """
25
- Function to run the DIE model.
26
- :param image_raw: raw image
27
- :param num_of_die_iterations: number of DIE iterations
28
- :param die_model: DIE model
29
- :param device: device
30
- :return: cleaned image
31
  """
32
-
33
  # preprocess
34
  image_raw_resized = resize_image(image_raw, 1500)
35
  image_raw_resized_square = make_image_square(image_raw_resized)
36
- image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(image_raw_resized_square)
37
- image_raw_resized_square_tensor = image_raw_resized_square_tensor.to(device)
38
-
39
  # convert string to int
40
  num_of_die_iterations = int(num_of_die_iterations)
41
-
42
  # inference
43
  image_die = die_model.enhance_document_image(
44
  image_raw_list=[image_raw_resized_square_tensor],
45
  num_of_die_iterations=num_of_die_iterations
46
  )[0]
47
-
48
  # postprocess
49
- image_die_resized = remove_square_padding(
50
  original_image=image_raw,
51
  square_image=image_die,
52
  resize_back_to_original=True
53
  )
54
 
 
 
55
 
56
- return image_die_resized
57
 
 
58
 
59
- def main():
60
- """
61
- Main function to run the Gradio demo.
62
- :return:
63
- """
64
 
65
- args = parse_arguments()
 
 
 
66
 
67
- description = """
68
- Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!\n\n
69
-
70
- This interactive application showcases a specialized AI model developed by
71
- the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).\n\n
72
-
73
- Our DIE model is designed to enhance and restore archival and aged document images
74
- by removing various types of degradation, thereby making historical documents more legible
75
- and suitable for Optical Character Recognition (OCR) processing.\n\n
76
-
77
- The model effectively tackles 20-30 types of domain-specific noise found in historical records,
78
- such as scribbles, bleed-through text, faded or worn text, blurriness, textured noise,
79
- and unwanted background elements.
80
- By applying deep learning techniques, specifically a U-Net-based architecture,
81
- the model accurately cleans and clarifies text while preserving original details.
82
- This improved clarity dramatically boosts OCR accuracy, making it an ideal
83
- pre-processing tool in digitization workflows.\n\n
84
-
85
- If you’re interested in learning more about the model’s capabilities or potential applications,
86
- please contact us at: [email protected].\n
87
-
88
- <img src="https://huggingface.co/spaces/renyi-ai/die_demo/blob/main/logo/qr-code.png">
89
- """
90
 
91
- # TODO: Add a description for the Number of DIE iterations parameter!
 
 
 
 
92
 
93
- num_of_die_iterations_list = [1, 2, 3]
 
94
 
 
95
  die_token = os.getenv("DIE_TOKEN")
96
-
97
- # Provide images alone for example display
98
- example_image_list = [
99
- [Image.open(os.path.join(args.example_image_path, image_path))]
100
- for image_path in os.listdir(args.example_image_path)
101
- ]
102
-
103
- # Load DIE model
104
- args.die_model_path = hf_hub_download(
105
  repo_id="gabar92/die",
106
- filename=args.die_model_path,
107
  use_auth_token=die_token
108
  )
109
-
110
- die_model = UNetDIEModel(args=args)
111
-
112
- # Partially apply the model and device arguments to die_inference
113
- partial_die_inference = partial(die_inference, device=args.device, die_model=die_model)
114
-
115
- demo = gr.Interface(
116
- fn=partial_die_inference,
117
- inputs=[
118
- gr.Image(type="pil", label="Degraded Document Image"),
119
- gr.Dropdown(num_of_die_iterations_list, label="Number of DIE iterations", value=1),
120
- ],
121
- outputs=gr.Image(type="pil", label="Clean Document Image"),
122
- title="Document Image Enhancement (DIE) model",
123
- description=description,
124
- examples=example_image_list
125
- )
126
-
127
- demo.launch(server_name="0.0.0.0", server_port=7860)
128
-
129
-
130
- def parse_arguments():
131
- """
132
- Parse arguments.
133
- :return: argument namespace
134
- """
135
-
136
- parser = argparse.ArgumentParser()
137
-
138
- parser.add_argument("--die_model_path", default="2024_08_09_model_epoch_89.pt")
139
- parser.add_argument("--device", default="cpu")
140
-
141
- parser.add_argument("--example_image_path", default="example_images")
142
 
143
- return parser.parse_args()
 
144
 
 
 
145
 
146
- if __name__ == "__main__":
147
 
148
- main()
 
 
 
 
 
 
1
  import os
2
  from functools import partial
 
3
  import gradio as gr
4
  from PIL import Image
5
  from huggingface_hub import hf_hub_download
 
6
  from die_model import UNetDIEModel
7
+ from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, remove_square_padding
 
 
8
 
9
+ def die_inference(image_raw, num_of_die_iterations, die_model, device):
 
 
 
 
 
10
  """
11
+ Applies the DIE model for document enhancement on a provided image.
 
 
 
 
 
12
  """
 
13
  # preprocess
14
  image_raw_resized = resize_image(image_raw, 1500)
15
  image_raw_resized_square = make_image_square(image_raw_resized)
16
+ image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(image_raw_resized_square).to(device)
17
+
 
18
  # convert string to int
19
  num_of_die_iterations = int(num_of_die_iterations)
20
+
21
  # inference
22
  image_die = die_model.enhance_document_image(
23
  image_raw_list=[image_raw_resized_square_tensor],
24
  num_of_die_iterations=num_of_die_iterations
25
  )[0]
26
+
27
  # postprocess
28
+ return remove_square_padding(
29
  original_image=image_raw,
30
  square_image=image_die,
31
  resize_back_to_original=True
32
  )
33
 
34
+ description = """
35
+ Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!
36
 
37
+ This application showcases a specialized AI model by the Artificial Intelligence group at the Alfréd Rényi Institute of Mathematics, aimed at enhancing and restoring archival document images. This model removes domain-specific noise, preserving clarity and improving OCR accuracy, particularly for aged and historical documents.
38
 
39
+ Contact: [email protected]
40
 
41
+ """
 
 
 
 
42
 
43
+ with gr.Blocks() as demo:
44
+ with gr.Row():
45
+ with gr.Column():
46
+ gr.Markdown("## Document Image Enhancement (DIE) Model")
47
 
48
+ with gr.Row():
49
+ with gr.Column():
50
+ gr.Markdown(description)
51
+ with gr.Column():
52
+ # Displaying the QR code directly as an image in Gradio
53
+ gr.Image(value=Image.open("path/to/qr-code.png"), label="QR Code")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ with gr.Row():
56
+ with gr.Column():
57
+ input_image = gr.Image(type="pil", label="Upload Degraded Document Image")
58
+ num_iterations = gr.Dropdown([1, 2, 3], label="Number of DIE Iterations", value=1)
59
+ run_button = gr.Button("Enhance Image")
60
 
61
+ with gr.Column():
62
+ output_image = gr.Image(type="pil", label="Enhanced Document Image")
63
 
64
+ # Load model
65
  die_token = os.getenv("DIE_TOKEN")
66
+ model_path = hf_hub_download(
 
 
 
 
 
 
 
 
67
  repo_id="gabar92/die",
68
+ filename="2024_08_09_model_epoch_89.pt",
69
  use_auth_token=die_token
70
  )
71
+ die_model = UNetDIEModel(args=model_path)
72
+ device = "cpu" # or "cuda" based on your setup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
+ # Partial function for inference
75
+ partial_die_inference = partial(die_inference, die_model=die_model, device=device)
76
 
77
+ # Define button behavior
78
+ run_button.click(partial_die_inference, [input_image, num_iterations], output_image)
79
 
80
+ demo.launch()
81