gabar92 committed on
Commit
e06502f
·
1 Parent(s): 5a6cddb

refactor app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -64
app.py CHANGED
@@ -11,16 +11,9 @@ from PIL import Image
11
  from huggingface_hub import hf_hub_download
12
 
13
  from die_model import UNetDIEModel
14
- from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, \
15
- remove_square_padding
16
 
17
-
18
- def die_inference(
19
- image_raw,
20
- num_of_die_iterations,
21
- die_model,
22
- device
23
- ):
24
  """
25
  Function to run the DIE model.
26
  :param image_raw: raw image
@@ -29,69 +22,61 @@ def die_inference(
29
  :param device: device
30
  :return: cleaned image
31
  """
32
-
33
- # preprocess
34
  image_raw_resized = resize_image(image_raw, 1500)
35
  image_raw_resized_square = make_image_square(image_raw_resized)
36
  image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(image_raw_resized_square)
37
  image_raw_resized_square_tensor = image_raw_resized_square_tensor.to(device)
38
 
39
- # convert string to int
40
  num_of_die_iterations = int(num_of_die_iterations)
41
 
42
- # inference
43
  image_die = die_model.enhance_document_image(
44
  image_raw_list=[image_raw_resized_square_tensor],
45
  num_of_die_iterations=num_of_die_iterations
46
  )[0]
47
 
48
- # postprocess
49
  image_die_resized = remove_square_padding(
50
  original_image=image_raw,
51
  square_image=image_die,
52
  resize_back_to_original=True
53
  )
54
 
55
-
56
  return image_die_resized
57
 
58
-
59
  def main():
60
  """
61
  Main function to run the Gradio demo.
62
- :return:
63
  """
64
-
65
  args = parse_arguments()
66
 
67
- description = "Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!\n\n" \
68
- "" \
69
- "This interactive application showcases a specialized AI model developed by " \
70
- "the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).\n\n" \
71
- "" \
72
- "Our DIE model is designed to enhance and restore archival and aged document images " \
73
- "by removing various types of degradation, thereby making historical documents more legible " \
74
- "and suitable for Optical Character Recognition (OCR) processing.\n\n" \
75
- "" \
76
- "The model effectively tackles 20-30 types of domain-specific noise found in historical records, " \
77
- "such as scribbles, bleed-through text, faded or worn text, blurriness, textured noise, " \
78
- "and unwanted background elements. " \
79
- "By applying deep learning techniques, specifically a U-Net-based architecture, " \
80
- "the model accurately cleans and clarifies text while preserving original details. " \
81
- "This improved clarity dramatically boosts OCR accuracy, making it an ideal " \
82
- "pre-processing tool in digitization workflows.\n\n" \
83
- "" \
84
- "If you’re interested in learning more about the model’s capabilities or potential applications, " \
85
- "please contact us at: [email protected].\n" \
86
- "<img src='logo/qr-code.png' width=200px>\n\n"
87
-
88
- # TODO: Add a description for the Number of DIE iterations parameter!
89
 
90
- num_of_die_iterations_list = [1, 2, 3]
 
 
 
 
 
 
 
 
 
 
91
 
 
92
  die_token = os.getenv("DIE_TOKEN")
93
 
94
- # Provide images alone for example display
95
  example_image_list = [
96
  [Image.open(os.path.join(args.example_image_path, image_path))]
97
  for image_path in os.listdir(args.example_image_path)
@@ -109,37 +94,41 @@ def main():
109
  # Partially apply the model and device arguments to die_inference
110
  partial_die_inference = partial(die_inference, device=args.device, die_model=die_model)
111
 
112
- demo = gr.Interface(
113
- fn=partial_die_inference,
114
- inputs=[
115
- gr.Image(type="pil", label="Degraded Document Image"),
116
- gr.Dropdown(num_of_die_iterations_list, label="Number of DIE iterations", value=1),
117
- ],
118
- outputs=gr.Image(type="pil", label="Clean Document Image"),
119
- title="Document Image Enhancement (DIE) model",
120
- description=description,
121
- examples=example_image_list
122
- )
123
-
124
- demo.launch(server_name="0.0.0.0", server_port=7860)
125
-
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  def parse_arguments():
128
  """
129
  Parse arguments.
130
- :return: argument namespace
131
  """
132
-
133
  parser = argparse.ArgumentParser()
134
-
135
  parser.add_argument("--die_model_path", default="2024_08_09_model_epoch_89.pt")
136
- parser.add_argument("--device", default="cpu")
137
-
138
  parser.add_argument("--example_image_path", default="example_images")
139
-
140
  return parser.parse_args()
141
 
142
-
143
  if __name__ == "__main__":
144
-
145
  main()
 
 
11
  from huggingface_hub import hf_hub_download
12
 
13
  from die_model import UNetDIEModel
14
+ from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, remove_square_padding
 
15
 
16
+ def die_inference(image_raw, num_of_die_iterations, die_model, device):
 
 
 
 
 
 
17
  """
18
  Function to run the DIE model.
19
  :param image_raw: raw image
 
22
  :param device: device
23
  :return: cleaned image
24
  """
25
+ # Preprocess
 
26
  image_raw_resized = resize_image(image_raw, 1500)
27
  image_raw_resized_square = make_image_square(image_raw_resized)
28
  image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(image_raw_resized_square)
29
  image_raw_resized_square_tensor = image_raw_resized_square_tensor.to(device)
30
 
31
+ # Convert string to int
32
  num_of_die_iterations = int(num_of_die_iterations)
33
 
34
+ # Inference
35
  image_die = die_model.enhance_document_image(
36
  image_raw_list=[image_raw_resized_square_tensor],
37
  num_of_die_iterations=num_of_die_iterations
38
  )[0]
39
 
40
+ # Postprocess
41
  image_die_resized = remove_square_padding(
42
  original_image=image_raw,
43
  square_image=image_die,
44
  resize_back_to_original=True
45
  )
46
 
 
47
  return image_die_resized
48
 
 
49
  def main():
50
  """
51
  Main function to run the Gradio demo.
 
52
  """
 
53
  args = parse_arguments()
54
 
55
+ description_intro = """
56
+ # Welcome to the Document Image Enhancement (DIE) Model Demo!
57
+ This interactive application showcases the capabilities of a specialized AI model developed by the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).
58
+ """
59
+
60
+ description_overview = """
61
+ ## Model Overview
62
+ The DIE model is crafted to enhance and restore archival and aged document images by removing various types of degradation. It tackles noise such as scribbles, bleed-through text, faded or blurred text, and other unwanted background elements. This process significantly improves the clarity of documents, which in turn enhances Optical Character Recognition (OCR) accuracy.
63
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ description_features = """
66
+ ## Features
67
+ - Removes 20-30 types of domain-specific noise often found in historical records.
68
+ - Utilizes a U-Net-based architecture for effective and detailed text restoration.
69
+ - Serves as a valuable pre-processing tool in digitization workflows, especially for archives and historical documents.
70
+ """
71
+
72
+ description_contact = """
73
+ ## Contact Us
74
+ For more information, feel free to reach out at: [email protected]
75
+ """
76
 
77
+ num_of_die_iterations_list = [1, 2, 3]
78
  die_token = os.getenv("DIE_TOKEN")
79
 
 
80
  example_image_list = [
81
  [Image.open(os.path.join(args.example_image_path, image_path))]
82
  for image_path in os.listdir(args.example_image_path)
 
94
  # Partially apply the model and device arguments to die_inference
95
  partial_die_inference = partial(die_inference, device=args.device, die_model=die_model)
96
 
97
+ # Gradio Interface
98
+ with gr.Blocks() as demo:
99
+ gr.Markdown(description_intro)
100
+ gr.Markdown(description_overview)
101
+ gr.Markdown(description_features)
102
+ gr.Markdown(description_contact)
103
+ gr.Image(label="", value="logo/qr-code.png", height=200, width=200)
104
+
105
+ # Define inputs and outputs
106
+ gr.Markdown("### Upload a degraded document image and select the number of DIE iterations:")
107
+ degraded_image_input = gr.Image(type="pil", label="Degraded Document Image")
108
+ iterations_input = gr.Dropdown(
109
+ num_of_die_iterations_list, label="Number of DIE iterations", value=1,
110
+ info="Choose the number of times to apply the enhancement model."
111
+ )
112
+
113
+ clean_image_output = gr.Image(type="pil", label="Clean Document Image")
114
+
115
+ gr.Interface(
116
+ fn=partial_die_inference,
117
+ inputs=[degraded_image_input, iterations_input],
118
+ outputs=clean_image_output,
119
+ examples=example_image_list,
120
+ title="Document Image Enhancement (DIE) Model"
121
+ ).launch(server_name="0.0.0.0", server_port=7860)
122
 
123
  def parse_arguments():
124
  """
125
  Parse arguments.
 
126
  """
 
127
  parser = argparse.ArgumentParser()
 
128
  parser.add_argument("--die_model_path", default="2024_08_09_model_epoch_89.pt")
 
 
129
  parser.add_argument("--example_image_path", default="example_images")
 
130
  return parser.parse_args()
131
 
 
132
  if __name__ == "__main__":
 
133
  main()
134
+