Sqxww committed on
Commit
1da9e98
·
1 Parent(s): 7a6754c

add brightness

Browse files
Files changed (2) hide show
  1. app_base.py +28 -87
  2. segment_utils.py +23 -1
app_base.py CHANGED
@@ -2,15 +2,13 @@ import spaces
2
  import gradio as gr
3
  import time
4
  import torch
5
- import tempfile
6
  import os
7
  import gc
8
 
9
- from loading_utils import load_image
10
-
11
  from segment_utils import(
12
  segment_image,
13
- restore_result,
14
  )
15
  from enhance_utils import enhance_sd_image
16
  from inversion_run_base import run as base_run
@@ -20,8 +18,11 @@ DEFAULT_EDIT_PROMPT = "a person with perfect face"
20
 
21
  DEFAULT_CATEGORY = "face"
22
 
 
 
 
23
  def image_to_image(
24
- input_image_path: str,
25
  input_image_prompt: str,
26
  edit_prompt: str,
27
  seed: int,
@@ -29,35 +30,14 @@ def image_to_image(
29
  num_steps: int,
30
  start_step: int,
31
  guidance_scale: float,
32
- generate_size: int,
33
- mask_expansion: int = 50,
34
- mask_dilation: int = 2,
35
- save_quality: int = 95,
36
- enable_segment: bool = True,
37
  ):
38
- segment_category = "face"
39
  w2 = 1.0
40
  run_task_time = 0
41
  time_cost_str = ''
42
 
43
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
44
- input_image = load_image(input_image_path)
45
- icc_profile = input_image.info.get('icc_profile')
46
- run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'load_image done')
47
-
48
- if enable_segment:
49
- target_area_image, croper = segment_image(
50
- input_image,
51
- segment_category,
52
- generate_size,
53
- mask_expansion,
54
- mask_dilation,
55
- )
56
- else:
57
- target_area_image = resize_image(input_image, generate_size)
58
- croper = None
59
-
60
- run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'segment_image done')
61
 
62
  run_model = base_run
63
  try:
@@ -82,30 +62,16 @@ def image_to_image(
82
  enhanced_image = enhance_sd_image(res_image)
83
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'enhance_image done')
84
 
85
- if enable_segment:
86
- restored_image = restore_result(croper, segment_category, enhanced_image)
87
- else:
88
- restored_image = enhanced_image.resize(input_image.size)
89
- run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'restore_result done')
90
-
91
  torch.cuda.empty_cache()
92
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'cuda_empty_cache done')
93
  if os.getenv('ENABLE_GC', False):
94
  gc.collect()
95
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'gc_collect done')
96
 
97
- extension = 'png'
98
- if restored_image.mode == 'RGBA':
99
- extension = 'png'
100
- else:
101
- extension = 'webp'
102
-
103
- output_path = tempfile.mktemp(suffix=f".{extension}")
104
- restored_image.save(output_path, format=extension, quality=save_quality, icc_profile=icc_profile)
105
-
106
- run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'save_image done')
107
 
108
- return output_path, restored_image, time_cost_str
109
 
110
  def get_time_cost(
111
  run_task_time,
@@ -134,49 +100,16 @@ def resize_image(image, target_size = 1024):
134
  w = target_size
135
  return image.resize((w, h))
136
 
137
-
138
- def infer(
139
- input_image_path: str,
140
- input_image_prompt: str,
141
- edit_prompt: str,
142
- seed: int,
143
- w1: float,
144
- num_steps: int,
145
- start_step: int,
146
- guidance_scale: float,
147
- generate_size: int,
148
- mask_expansion: int = 50,
149
- mask_dilation: int = 2,
150
- save_quality: int = 95,
151
- enable_segment: bool = True,
152
- ):
153
- return image_to_image(
154
- input_image_path,
155
- input_image_prompt,
156
- edit_prompt,
157
- seed,
158
- w1,
159
- num_steps,
160
- start_step,
161
- guidance_scale,
162
- generate_size,
163
- mask_expansion,
164
- mask_dilation,
165
- save_quality,
166
- enable_segment
167
- )
168
-
169
- infer = spaces.GPU(infer)
170
-
171
  def create_demo() -> gr.Blocks:
172
 
173
  with gr.Blocks() as demo:
 
174
  with gr.Row():
175
  with gr.Column():
176
  input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
177
  edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
 
178
  with gr.Accordion("Advanced Options", open=False):
179
- enable_segment = gr.Checkbox(label="Enable Segment", value=True)
180
  mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
181
  mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
182
  save_quality = gr.Slider(minimum=1, maximum=100, value=95, step=1, label="Save Quality")
@@ -192,18 +125,26 @@ def create_demo() -> gr.Blocks:
192
 
193
  with gr.Row():
194
  with gr.Column():
195
- input_image_path = gr.Image(label="Input Image", type="filepath", interactive=True)
 
196
  with gr.Column():
197
  restored_image = gr.Image(label="Restored Image", format="png", type="pil", interactive=False)
198
  download_path = gr.File(label="Download the output image", interactive=False)
 
199
  generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
200
-
201
  g_btn.click(
202
- fn=infer,
203
- inputs=[input_image_path, input_image_prompt, edit_prompt,seed,w1, num_steps, start_step, guidance_scale, generate_size, mask_expansion, mask_dilation, save_quality, enable_segment],
204
- outputs=[download_path, restored_image, generated_cost],
 
 
 
 
 
 
 
 
205
  )
206
-
207
-
208
 
209
  return demo
 
2
  import gradio as gr
3
  import time
4
  import torch
 
5
  import os
6
  import gc
7
 
8
+ from PIL import Image, ImageEnhance
 
9
  from segment_utils import(
10
  segment_image,
11
+ restore_result_and_save,
12
  )
13
  from enhance_utils import enhance_sd_image
14
  from inversion_run_base import run as base_run
 
18
 
19
  DEFAULT_CATEGORY = "face"
20
 
21
+ @spaces.GPU(duration=10)
22
+ @torch.inference_mode()
23
+ @torch.no_grad()
24
  def image_to_image(
25
+ input_image: Image,
26
  input_image_prompt: str,
27
  edit_prompt: str,
28
  seed: int,
 
30
  num_steps: int,
31
  start_step: int,
32
  guidance_scale: float,
33
+ brightness: float = 1.0,
 
 
 
 
34
  ):
 
35
  w2 = 1.0
36
  run_task_time = 0
37
  time_cost_str = ''
38
 
39
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
40
+ target_area_image = input_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  run_model = base_run
43
  try:
 
62
  enhanced_image = enhance_sd_image(res_image)
63
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'enhance_image done')
64
 
 
 
 
 
 
 
65
  torch.cuda.empty_cache()
66
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'cuda_empty_cache done')
67
  if os.getenv('ENABLE_GC', False):
68
  gc.collect()
69
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'gc_collect done')
70
 
71
+ enhancer = ImageEnhance.Brightness(enhanced_image)
72
+ enhanced_image = enhancer.enhance(brightness)
 
 
 
 
 
 
 
 
73
 
74
+ return enhanced_image, time_cost_str
75
 
76
  def get_time_cost(
77
  run_task_time,
 
100
  w = target_size
101
  return image.resize((w, h))
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def create_demo() -> gr.Blocks:
104
 
105
  with gr.Blocks() as demo:
106
+ cropper = gr.State()
107
  with gr.Row():
108
  with gr.Column():
109
  input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
110
  edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
111
+ brightness = gr.Slider(minimum=0, maximum=2, value=1.0, step=0.1, label="Brightness")
112
  with gr.Accordion("Advanced Options", open=False):
 
113
  mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
114
  mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
115
  save_quality = gr.Slider(minimum=1, maximum=100, value=95, step=1, label="Save Quality")
 
125
 
126
  with gr.Row():
127
  with gr.Column():
128
+ input_image = gr.Image(label="Input Image", type="pil", interactive=True)
129
+ origin_area_image = gr.Image(label="Origin Area Image", format="png", type="pil", interactive=False)
130
  with gr.Column():
131
  restored_image = gr.Image(label="Restored Image", format="png", type="pil", interactive=False)
132
  download_path = gr.File(label="Download the output image", interactive=False)
133
+ enhanced_image = gr.Image(label="Enhanced Image", format="png", type="pil", interactive=False)
134
  generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
135
+
136
  g_btn.click(
137
+ fn=segment_image,
138
+ inputs=[input_image, DEFAULT_CATEGORY, generate_size, mask_expansion, mask_dilation],
139
+ outputs=[origin_area_image, cropper],
140
+ ).success(
141
+ fn=image_to_image,
142
+ inputs=[origin_area_image, input_image_prompt, edit_prompt,seed,w1, num_steps, start_step, guidance_scale],
143
+ outputs=[enhanced_image, generated_cost],
144
+ ).success(
145
+ fn=restore_result_and_save,
146
+ inputs=[cropper, DEFAULT_CATEGORY, enhanced_image, save_quality],
147
+ outputs=[restored_image, download_path],
148
  )
 
 
149
 
150
  return demo
segment_utils.py CHANGED
@@ -1,6 +1,6 @@
1
  import numpy as np
2
  import mediapipe as mp
3
- import uuid
4
 
5
  from PIL import Image
6
  from scipy.ndimage import binary_dilation
@@ -22,6 +22,28 @@ def restore_result(croper, category, generated_image):
22
 
23
  return restored_image
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def segment_image(input_image, category, input_size, mask_expansion, mask_dilation):
26
  mask_size = int(input_size)
27
  mask_expansion = int(mask_expansion)
 
1
  import numpy as np
2
  import mediapipe as mp
3
+ import tempfile
4
 
5
  from PIL import Image
6
  from scipy.ndimage import binary_dilation
 
22
 
23
  return restored_image
24
 
25
def restore_result_and_save(croper, category, generated_image, save_quality=95):
    """Paste the generated image back into the original and save it to disk.

    The generated image is resized to the cropper's square, cropped to the
    region recorded on the cropper, and pasted onto a copy of
    ``croper.input_image`` using the mask returned by
    ``get_restore_mask_image``. The composite is then written to a newly
    created temporary file.

    Args:
        croper: Object exposing ``square_length``, ``square_start_x``,
            ``square_start_y``, ``square_end_x``, ``square_end_y``,
            ``origin_start_x``, ``origin_start_y`` and ``input_image``.
        category: Passed through unchanged to ``get_restore_mask_image``.
        generated_image: Image to merge back; must support ``resize`` and
            ``crop`` (presumably a PIL image -- confirm against callers).
        save_quality: Forwarded as the ``quality`` argument of ``save``.

    Returns:
        A ``(restored_image, output_path)`` tuple: the composited image and
        the path of the file it was written to.
    """
    square_length = croper.square_length
    generated_image = generated_image.resize((square_length, square_length))

    crop_box = (
        croper.square_start_x,
        croper.square_start_y,
        croper.square_end_x,
        croper.square_end_y,
    )
    cropped_generated_image = generated_image.crop(crop_box)
    cropped_square_mask_image = get_restore_mask_image(croper, category, cropped_generated_image)

    # Work on a copy so the image stored on the cropper is left untouched.
    restored_image = croper.input_image.copy()
    restored_image.paste(
        cropped_generated_image,
        (croper.origin_start_x, croper.origin_start_y),
        cropped_square_mask_image,
    )

    # RGBA results are written as PNG, every other mode as WebP.
    extension = 'png' if restored_image.mode == 'RGBA' else 'webp'

    icc_profile = croper.input_image.info.get('icc_profile')

    # tempfile.mktemp() only returns a name, so another process could create
    # that path before we write to it. NamedTemporaryFile creates the file
    # atomically; delete=False keeps it on disk for the caller to use.
    with tempfile.NamedTemporaryFile(suffix=f".{extension}", delete=False) as output_file:
        output_path = output_file.name
    restored_image.save(output_path, format=extension, quality=save_quality, icc_profile=icc_profile)

    return restored_image, output_path
46
+
47
  def segment_image(input_image, category, input_size, mask_expansion, mask_dilation):
48
  mask_size = int(input_size)
49
  mask_expansion = int(mask_expansion)