gabar92 commited on
Commit
f6be418
·
1 Parent(s): f72a688

change code back to first version

Browse files
Files changed (1) hide show
  1. app.py +64 -56
app.py CHANGED
@@ -5,16 +5,22 @@ Small demo application to explore Gradio.
5
  import argparse
6
  import os
7
  from functools import partial
8
- import torch
9
 
10
  import gradio as gr
11
  from PIL import Image
12
  from huggingface_hub import hf_hub_download
13
 
14
  from die_model import UNetDIEModel
15
- from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, remove_square_padding
 
16
 
17
- def die_inference(image_raw, num_of_die_iterations, die_model, device):
 
 
 
 
 
 
18
  """
19
  Function to run the DIE model.
20
  :param image_raw: raw image
@@ -23,61 +29,69 @@ def die_inference(image_raw, num_of_die_iterations, die_model, device):
23
  :param device: device
24
  :return: cleaned image
25
  """
26
- # Preprocess
 
27
  image_raw_resized = resize_image(image_raw, 1500)
28
  image_raw_resized_square = make_image_square(image_raw_resized)
29
  image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(image_raw_resized_square)
30
  image_raw_resized_square_tensor = image_raw_resized_square_tensor.to(device)
31
 
32
- # Convert string to int
33
  num_of_die_iterations = int(num_of_die_iterations)
34
 
35
- # Inference
36
  image_die = die_model.enhance_document_image(
37
  image_raw_list=[image_raw_resized_square_tensor],
38
  num_of_die_iterations=num_of_die_iterations
39
  )[0]
40
 
41
- # Postprocess
42
  image_die_resized = remove_square_padding(
43
  original_image=image_raw,
44
  square_image=image_die,
45
  resize_back_to_original=True
46
  )
47
 
 
48
  return image_die_resized
49
 
 
50
  def main():
51
  """
52
  Main function to run the Gradio demo.
53
- """
54
- args = parse_arguments()
55
-
56
- description_intro = """
57
- # Welcome to the Document Image Enhancement (DIE) Model Demo!
58
- This interactive application showcases the capabilities of a specialized AI model developed by the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).
59
- """
60
-
61
- description_overview = """
62
- ## Model Overview
63
- The DIE model is crafted to enhance and restore archival and aged document images by removing various types of degradation. It tackles noise such as scribbles, bleed-through text, faded or blurred text, and other unwanted background elements. This process significantly improves the clarity of documents, which in turn enhances Optical Character Recognition (OCR) accuracy.
64
  """
65
 
66
- description_features = """
67
- ## Features
68
- - Removes 20-30 types of domain-specific noise often found in historical records.
69
- - Utilizes a U-Net-based architecture for effective and detailed text restoration.
70
- - Serves as a valuable pre-processing tool in digitization workflows, especially for archives and historical documents.
71
- """
72
 
73
- description_contact = """
74
- ## Contact Us
75
- For more information, feel free to reach out at: [email protected]
76
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  num_of_die_iterations_list = [1, 2, 3]
 
79
  die_token = os.getenv("DIE_TOKEN")
80
 
 
81
  example_image_list = [
82
  [Image.open(os.path.join(args.example_image_path, image_path))]
83
  for image_path in os.listdir(args.example_image_path)
@@ -90,48 +104,42 @@ def main():
90
  use_auth_token=die_token
91
  )
92
 
93
- args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
94
-
95
  die_model = UNetDIEModel(args=args)
96
 
97
  # Partially apply the model and device arguments to die_inference
98
  partial_die_inference = partial(die_inference, device=args.device, die_model=die_model)
99
 
100
- # Gradio Interface
101
- with gr.Blocks() as demo:
102
- gr.Markdown(description_intro)
103
- gr.Markdown(description_overview)
104
- gr.Markdown(description_features)
105
- gr.Markdown(description_contact)
106
- gr.Image(label="", value="logo/qr-code.png", height=200, width=200)
107
-
108
- # Define inputs and outputs
109
- gr.Markdown("### Upload a degraded document image and select the number of DIE iterations:")
110
- degraded_image_input = gr.Image(type="pil", label="Degraded Document Image")
111
- iterations_input = gr.Dropdown(
112
- num_of_die_iterations_list, label="Number of DIE iterations", value=1,
113
- info="Choose the number of times to apply the enhancement model."
114
- )
115
-
116
- clean_image_output = gr.Image(type="pil", label="Clean Document Image")
117
-
118
- gr.Interface(
119
- fn=partial_die_inference,
120
- inputs=[degraded_image_input, iterations_input],
121
- outputs=clean_image_output,
122
- examples=example_image_list,
123
- title="Document Image Enhancement (DIE) Model"
124
- ).launch(server_name="0.0.0.0", server_port=7860, share=True)
125
 
126
  def parse_arguments():
127
  """
128
  Parse arguments.
 
129
  """
 
130
  parser = argparse.ArgumentParser()
 
131
  parser.add_argument("--die_model_path", default="2024_08_09_model_epoch_89.pt")
 
 
132
  parser.add_argument("--example_image_path", default="example_images")
 
133
  return parser.parse_args()
134
 
 
135
  if __name__ == "__main__":
136
- main()
137
 
 
 
5
  import argparse
6
  import os
7
  from functools import partial
 
8
 
9
  import gradio as gr
10
  from PIL import Image
11
  from huggingface_hub import hf_hub_download
12
 
13
  from die_model import UNetDIEModel
14
+ from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, \
15
+ remove_square_padding
16
 
17
+
18
+ def die_inference(
19
+ image_raw,
20
+ num_of_die_iterations,
21
+ die_model,
22
+ device
23
+ ):
24
  """
25
  Function to run the DIE model.
26
  :param image_raw: raw image
 
29
  :param device: device
30
  :return: cleaned image
31
  """
32
+
33
+ # preprocess
34
  image_raw_resized = resize_image(image_raw, 1500)
35
  image_raw_resized_square = make_image_square(image_raw_resized)
36
  image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(image_raw_resized_square)
37
  image_raw_resized_square_tensor = image_raw_resized_square_tensor.to(device)
38
 
39
+ # convert string to int
40
  num_of_die_iterations = int(num_of_die_iterations)
41
 
42
+ # inference
43
  image_die = die_model.enhance_document_image(
44
  image_raw_list=[image_raw_resized_square_tensor],
45
  num_of_die_iterations=num_of_die_iterations
46
  )[0]
47
 
48
+ # postprocess
49
  image_die_resized = remove_square_padding(
50
  original_image=image_raw,
51
  square_image=image_die,
52
  resize_back_to_original=True
53
  )
54
 
55
+
56
  return image_die_resized
57
 
58
+
59
  def main():
60
  """
61
  Main function to run the Gradio demo.
62
+ :return:
 
 
 
 
 
 
 
 
 
 
63
  """
64
 
65
+ args = parse_arguments()
 
 
 
 
 
66
 
67
+ description = "Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!\n\n" \
68
+ "" \
69
+ "This interactive application showcases a specialized AI model developed by " \
70
+ "the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).\n\n" \
71
+ "" \
72
+ "Our DIE model is designed to enhance and restore archival and aged document images " \
73
+ "by removing various types of degradation, thereby making historical documents more legible " \
74
+ "and suitable for Optical Character Recognition (OCR) processing.\n\n" \
75
+ "" \
76
+ "The model effectively tackles 20-30 types of domain-specific noise found in historical records, " \
77
+ "such as scribbles, bleed-through text, faded or worn text, blurriness, textured noise, " \
78
+ "and unwanted background elements. " \
79
+ "By applying deep learning techniques, specifically a U-Net-based architecture, " \
80
+ "the model accurately cleans and clarifies text while preserving original details. " \
81
+ "This improved clarity dramatically boosts OCR accuracy, making it an ideal " \
82
+ "pre-processing tool in digitization workflows.\n\n" \
83
+ "" \
84
+ "If you’re interested in learning more about the model’s capabilities or potential applications, " \
85
+ "please contact us at: [email protected].\n" \
86
+ "<img src='logo/qr-code.png' width=200px>\n\n"
87
+
88
+ # TODO: Add a description for the Number of DIE iterations parameter!
89
 
90
  num_of_die_iterations_list = [1, 2, 3]
91
+
92
  die_token = os.getenv("DIE_TOKEN")
93
 
94
+ # Provide images alone for example display
95
  example_image_list = [
96
  [Image.open(os.path.join(args.example_image_path, image_path))]
97
  for image_path in os.listdir(args.example_image_path)
 
104
  use_auth_token=die_token
105
  )
106
 
 
 
107
  die_model = UNetDIEModel(args=args)
108
 
109
  # Partially apply the model and device arguments to die_inference
110
  partial_die_inference = partial(die_inference, device=args.device, die_model=die_model)
111
 
112
+ demo = gr.Interface(
113
+ fn=partial_die_inference,
114
+ inputs=[
115
+ gr.Image(type="pil", label="Degraded Document Image"),
116
+ gr.Dropdown(num_of_die_iterations_list, label="Number of DIE iterations", value=1),
117
+ ],
118
+ outputs=gr.Image(type="pil", label="Clean Document Image"),
119
+ title="Document Image Enhancement (DIE) model",
120
+ description=description,
121
+ examples=example_image_list
122
+ )
123
+
124
+ demo.launch(server_name="0.0.0.0", server_port=7860)
125
+
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  def parse_arguments():
128
  """
129
  Parse arguments.
130
+ :return: argument namespace
131
  """
132
+
133
  parser = argparse.ArgumentParser()
134
+
135
  parser.add_argument("--die_model_path", default="2024_08_09_model_epoch_89.pt")
136
+ parser.add_argument("--device", default="cpu")
137
+
138
  parser.add_argument("--example_image_path", default="example_images")
139
+
140
  return parser.parse_args()
141
 
142
+
143
  if __name__ == "__main__":
 
144
 
145
+ main()