tuan2308 commited on
Commit
052a51e
·
verified ·
1 Parent(s): 06a3e99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -135
app.py CHANGED
@@ -1,68 +1,149 @@
1
- import spaces
2
- import gradio as gr
3
- import cv2
4
- import numpy
5
  import os
6
  import random
 
 
 
 
 
7
  from basicsr.archs.rrdbnet_arch import RRDBNet
8
  from basicsr.utils.download_util import load_file_from_url
9
-
10
  from realesrgan import RealESRGANer
11
  from realesrgan.archs.srvgg_arch import SRVGGNetCompact
12
 
13
-
 
 
14
  last_file = None
15
  img_mode = "RGBA"
16
 
 
 
 
 
 
 
 
 
 
 
17
  @spaces.GPU
18
- def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
19
- """Real-ESRGAN function to restore (and upscale) images.
20
  """
21
- if not img:
22
- return
 
 
 
23
 
24
- # Define model parameters
25
- if model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
26
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
27
- netscale = 4
28
- file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
29
- elif model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
31
  netscale = 4
32
- file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
33
- elif model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
 
 
34
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
35
  netscale = 4
36
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
37
- elif model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
38
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
39
  netscale = 2
40
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
41
- elif model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
42
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
43
  netscale = 4
44
  file_url = [
45
  'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
46
  'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
47
  ]
 
 
48
 
49
- # Determine model paths
50
  model_path = os.path.join('weights', model_name + '.pth')
51
  if not os.path.isfile(model_path):
52
  ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
53
  for url in file_url:
54
- # model_path will be updated
55
- model_path = load_file_from_url(
56
- url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
57
 
58
- # Use dni to control the denoise strength
59
  dni_weight = None
60
  if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
61
  wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
62
  model_path = [model_path, wdn_model_path]
63
  dni_weight = [denoise_strength, 1 - denoise_strength]
64
 
65
- # Restorer Class
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  upsampler = RealESRGANer(
67
  scale=netscale,
68
  model_path=model_path,
@@ -71,127 +152,87 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
71
  tile=0,
72
  tile_pad=10,
73
  pre_pad=10,
74
- half=False,
75
- gpu_id=None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- # Use GFPGAN for face enhancement
79
- if face_enhance:
80
- from gfpgan import GFPGANer
81
- face_enhancer = GFPGANer(
82
- model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
83
- upscale=outscale,
84
- arch='clean',
85
- channel_multiplier=2,
86
- bg_upsampler=upsampler)
87
-
88
- # Convert the input PIL image to cv2 image, so that it can be processed by realesrgan
89
  cv_img = numpy.array(img)
90
- img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
91
 
92
- # Apply restoration
93
  try:
94
  if face_enhance:
95
- _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
 
 
 
96
  else:
97
- output, _ = upsampler.enhance(img, outscale=outscale)
98
  except RuntimeError as error:
 
99
  print('Error', error)
100
- print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
101
  else:
102
- # Save restored image and return it to the output Image component
103
- if img_mode == 'RGBA': # RGBA images should be saved in png format
104
- extension = 'png'
105
- else:
106
- extension = 'jpg'
107
-
108
  out_filename = f"output_{rnd_string(8)}.{extension}"
109
  cv2.imwrite(out_filename, output)
110
  global last_file
111
  last_file = out_filename
112
  return out_filename
113
 
114
-
115
- def rnd_string(x):
116
- """Returns a string of 'x' random characters
117
- """
118
- characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
119
- result = "".join((random.choice(characters)) for i in range(x))
120
- return result
121
-
122
-
123
- def reset():
124
- """Resets the Image components of the Gradio interface and deletes
125
- the last processed image
126
- """
127
- global last_file
128
- if last_file:
129
- print(f"Deleting {last_file} ...")
130
- os.remove(last_file)
131
- last_file = None
132
- return gr.update(value=None), gr.update(value=None)
133
-
134
-
135
- def has_transparency(img):
136
- """This function works by first checking to see if a "transparency" property is defined
137
- in the image's info -- if so, we return "True". Then, if the image is using indexed colors
138
- (such as in GIFs), it gets the index of the transparent color in the palette
139
- (img.info.get("transparency", -1)) and checks if it's used anywhere in the canvas
140
- (img.getcolors()). If the image is in RGBA mode, then presumably it has transparency in
141
- it, but it double-checks by getting the minimum and maximum values of every color channel
142
- (img.getextrema()), and checks if the alpha channel's smallest value falls below 255.
143
- https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
144
- """
145
- if img.info.get("transparency", None) is not None:
146
- return True
147
- if img.mode == "P":
148
- transparent = img.info.get("transparency", -1)
149
- for _, index in img.getcolors():
150
- if index == transparent:
151
- return True
152
- elif img.mode == "RGBA":
153
- extrema = img.getextrema()
154
- if extrema[3][0] < 255:
155
- return True
156
- return False
157
-
158
-
159
- def image_properties(img):
160
- """Returns the dimensions (width and height) and color mode of the input image and
161
- also sets the global img_mode variable to be used by the realesrgan function
162
- """
163
- global img_mode
164
- if img:
165
- if has_transparency(img):
166
- img_mode = "RGBA"
167
- else:
168
- img_mode = "RGB"
169
- properties = f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
170
- return properties
171
-
172
-
173
  def main():
174
- # Gradio Interface
175
  with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
176
-
177
- gr.Markdown(
178
- """ Image Upscaler
179
- """
180
- )
181
 
182
  with gr.Accordion("Upscaling option"):
183
  with gr.Row():
184
- model_name = gr.Dropdown(label="Upscaler model",
185
- choices=["RealESRGAN_x4plus", "RealESRNet_x4plus", "RealESRGAN_x4plus_anime_6B",
186
- "RealESRGAN_x2plus", "realesr-general-x4v3"],
187
- value="RealESRGAN_x4plus_anime_6B", show_label=True)
188
- denoise_strength = gr.Slider(label="Denoise Strength",
189
- minimum=0, maximum=1, step=0.1, value=0.5)
190
- outscale = gr.Slider(label="Resolution upscale",
191
- minimum=1, maximum=6, step=1, value=4, show_label=True)
192
- face_enhance = gr.Checkbox(label="Face Enhancement (GFPGAN)",
 
 
193
  )
194
-
 
 
 
195
  with gr.Row():
196
  with gr.Group():
197
  input_image = gr.Image(label="Input Image", type="pil", image_mode="RGBA")
@@ -201,22 +242,15 @@ def main():
201
  reset_btn = gr.Button("Remove images")
202
  restore_btn = gr.Button("Upscale")
203
 
204
- # Event listeners:
205
  input_image.change(fn=image_properties, inputs=input_image, outputs=input_image_properties)
206
  restore_btn.click(fn=realesrgan,
207
  inputs=[input_image, model_name, denoise_strength, face_enhance, outscale],
208
  outputs=output_image)
209
  reset_btn.click(fn=reset, inputs=[], outputs=[output_image, input_image])
210
- # reset_btn.click(None, inputs=[], outputs=[input_image], _js="() => (null)\n")
211
- # Undocumented method to clear a component's value using Javascript
212
-
213
- gr.Markdown(
214
- """
215
- """
216
- )
217
-
218
- demo.launch()
219
 
 
220
 
221
  if __name__ == "__main__":
 
 
222
  main()
 
 
 
 
 
1
  import os
2
  import random
3
+ import cv2
4
+ import numpy
5
+ import gradio as gr
6
+ import spaces
7
+
8
  from basicsr.archs.rrdbnet_arch import RRDBNet
9
  from basicsr.utils.download_util import load_file_from_url
 
10
  from realesrgan import RealESRGANer
11
  from realesrgan.archs.srvgg_arch import SRVGGNetCompact
12
 
13
+ # --------------------
14
+ # Global (CPU-only data; KHÔNG chạm CUDA ở đây)
15
+ # --------------------
16
  last_file = None
17
  img_mode = "RGBA"
18
 
19
+ DEVICE = "cpu" # set trong gpu_startup()
20
+ USE_HALF = False # set trong gpu_startup()
21
+
22
+ # cache cho các upsampler đã khởi tạo
23
+ UPSAMPLER_CACHE = {} # key: (model_name, denoise_strength, DEVICE, USE_HALF)
24
+ GFPGAN_FACE_ENHANCER = {} # key: (outscale, DEVICE, USE_HALF)
25
+
26
+ # --------------------
27
+ # ZeroGPU: cấp GPU ngay khi khởi động
28
+ # --------------------
29
  @spaces.GPU
30
+ def gpu_startup():
 
31
  """
32
+ Hàm này chạy ngay khi Space bật trên ZeroGPU.
33
+ Chỉ ở đây mới 'đụng' tới torch/cuda.
34
+ """
35
+ global DEVICE, USE_HALF
36
+ import torch
37
 
38
+ has_cuda = torch.cuda.is_available()
39
+ DEVICE = "cuda" if has_cuda else "cpu"
40
+ # half precision chỉ an toàn khi có CUDA
41
+ USE_HALF = bool(has_cuda)
42
+
43
+ print(f"[startup] CUDA available: {has_cuda}, device={DEVICE}, half={USE_HALF}")
44
+
45
+ # --------------------
46
+ # Utils
47
+ # --------------------
48
+ def rnd_string(x):
49
+ chars = "abcdefghijklmnopqrstuvwxyz_0123456789"
50
+ return "".join(random.choice(chars) for _ in range(x))
51
+
52
+ def has_transparency(img):
53
+ if img.info.get("transparency", None) is not None:
54
+ return True
55
+ if img.mode == "P":
56
+ transparent = img.info.get("transparency", -1)
57
+ for _, index in img.getcolors():
58
+ if index == transparent:
59
+ return True
60
+ elif img.mode == "RGBA":
61
+ extrema = img.getextrema()
62
+ if extrema[3][0] < 255:
63
+ return True
64
+ return False
65
+
66
+ def image_properties(img):
67
+ global img_mode
68
+ if img:
69
+ if has_transparency(img):
70
+ img_mode = "RGBA"
71
+ else:
72
+ img_mode = "RGB"
73
+ return f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
74
+
75
+ def reset():
76
+ global last_file
77
+ if last_file:
78
+ try:
79
+ print(f"Deleting {last_file} ...")
80
+ os.remove(last_file)
81
+ except Exception as e:
82
+ print("Delete error:", e)
83
+ finally:
84
+ last_file = None
85
+ return gr.update(value=None), gr.update(value=None)
86
+
87
+ # --------------------
88
+ # Model builder (không gọi CUDA ở ngoài startup; mọi thứ phụ thuộc DEVICE/USE_HALF)
89
+ # --------------------
90
+ def get_model_and_paths(model_name, denoise_strength):
91
+ """Chuẩn bị kiến trúc model + đường dẫn trọng số + dni_weight (nếu cần)."""
92
+ if model_name in ('RealESRGAN_x4plus', 'RealESRNet_x4plus'):
93
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
94
  netscale = 4
95
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'] \
96
+ if model_name == 'RealESRGAN_x4plus' else \
97
+ ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
98
+ elif model_name == 'RealESRGAN_x4plus_anime_6B':
99
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
100
  netscale = 4
101
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
102
+ elif model_name == 'RealESRGAN_x2plus':
103
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
104
  netscale = 2
105
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
106
+ elif model_name == 'realesr-general-x4v3':
107
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
108
  netscale = 4
109
  file_url = [
110
  'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
111
  'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
112
  ]
113
+ else:
114
+ raise ValueError(f"Unsupported model: {model_name}")
115
 
116
+ # tải trọng số (nếu chưa có)
117
  model_path = os.path.join('weights', model_name + '.pth')
118
  if not os.path.isfile(model_path):
119
  ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
120
  for url in file_url:
121
+ model_path = load_file_from_url(url=url, model_dir=os.path.join(ROOT_DIR, 'weights'),
122
+ progress=True, file_name=None)
 
123
 
124
+ # dni (chỉ riêng general-x4v3)
125
  dni_weight = None
126
  if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
127
  wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
128
  model_path = [model_path, wdn_model_path]
129
  dni_weight = [denoise_strength, 1 - denoise_strength]
130
 
131
+ return model, netscale, model_path, dni_weight
132
+
133
+ def get_upsampler(model_name, denoise_strength):
134
+ """Khởi tạo/cached RealESRGANer theo device & half hiện hành."""
135
+ key = (model_name, float(denoise_strength), DEVICE, USE_HALF)
136
+ if key in UPSAMPLER_CACHE:
137
+ return UPSAMPLER_CACHE[key]
138
+
139
+ model, netscale, model_path, dni_weight = get_model_and_paths(model_name, denoise_strength)
140
+
141
+ # Cấu hình theo thiết bị
142
+ # - half=True khi GPU; False khi CPU
143
+ # - gpu_id=0 khi GPU; None khi CPU
144
+ half_flag = bool(USE_HALF)
145
+ gpu_id = 0 if DEVICE == "cuda" else None
146
+
147
  upsampler = RealESRGANer(
148
  scale=netscale,
149
  model_path=model_path,
 
152
  tile=0,
153
  tile_pad=10,
154
  pre_pad=10,
155
+ half=half_flag,
156
+ gpu_id=gpu_id
157
+ )
158
+ UPSAMPLER_CACHE[key] = upsampler
159
+ return upsampler
160
+
161
+ def get_face_enhancer(upsampler, outscale):
162
+ key = (int(outscale), DEVICE, USE_HALF)
163
+ if key in GFPGAN_FACE_ENHANCER:
164
+ return GFPGAN_FACE_ENHANCER[key]
165
+ from gfpgan import GFPGANer
166
+ face_enhancer = GFPGANer(
167
+ model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
168
+ upscale=int(outscale),
169
+ arch='clean',
170
+ channel_multiplier=2,
171
+ bg_upsampler=upsampler
172
  )
173
+ GFPGAN_FACE_ENHANCER[key] = face_enhancer
174
+ return face_enhancer
175
+
176
+ # --------------------
177
+ # Inference (đánh dấu @spaces.GPU vì có thể chạy trên GPU)
178
+ # --------------------
179
+ @spaces.GPU
180
+ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
181
+ """Real-ESRGAN restore/upscale."""
182
+ if not img:
183
+ return
184
+
185
+ upsampler = get_upsampler(model_name, denoise_strength)
186
 
187
+ # PIL -> cv2 BGRA
 
 
 
 
 
 
 
 
 
 
188
  cv_img = numpy.array(img)
189
+ img_bgra = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
190
 
 
191
  try:
192
  if face_enhance:
193
+ face_enhancer = get_face_enhancer(upsampler, outscale)
194
+ _, _, output = face_enhancer.enhance(
195
+ img_bgra, has_aligned=False, only_center_face=False, paste_back=True
196
+ )
197
  else:
198
+ output, _ = upsampler.enhance(img_bgra, outscale=int(outscale))
199
  except RuntimeError as error:
200
+ # Gợi ý tự động giảm tile nếu OOM
201
  print('Error', error)
202
+ return None
203
  else:
204
+ extension = 'png' if img_mode == 'RGBA' else 'jpg'
 
 
 
 
 
205
  out_filename = f"output_{rnd_string(8)}.{extension}"
206
  cv2.imwrite(out_filename, output)
207
  global last_file
208
  last_file = out_filename
209
  return out_filename
210
 
211
+ # --------------------
212
+ # UI
213
+ # --------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  def main():
 
215
  with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
216
+ gr.Markdown("## Image Upscaler")
 
 
 
 
217
 
218
  with gr.Accordion("Upscaling option"):
219
  with gr.Row():
220
+ model_name = gr.Dropdown(
221
+ label="Upscaler model",
222
+ choices=[
223
+ "RealESRGAN_x4plus",
224
+ "RealESRNet_x4plus",
225
+ "RealESRGAN_x4plus_anime_6B",
226
+ "RealESRGAN_x2plus",
227
+ "realesr-general-x4v3",
228
+ ],
229
+ value="RealESRGAN_x4plus_anime_6B",
230
+ show_label=True
231
  )
232
+ denoise_strength = gr.Slider(label="Denoise Strength", minimum=0, maximum=1, step=0.1, value=0.5)
233
+ outscale = gr.Slider(label="Resolution upscale", minimum=1, maximum=6, step=1, value=4, show_label=True)
234
+ face_enhance = gr.Checkbox(label="Face Enhancement (GFPGAN)")
235
+
236
  with gr.Row():
237
  with gr.Group():
238
  input_image = gr.Image(label="Input Image", type="pil", image_mode="RGBA")
 
242
  reset_btn = gr.Button("Remove images")
243
  restore_btn = gr.Button("Upscale")
244
 
 
245
  input_image.change(fn=image_properties, inputs=input_image, outputs=input_image_properties)
246
  restore_btn.click(fn=realesrgan,
247
  inputs=[input_image, model_name, denoise_strength, face_enhance, outscale],
248
  outputs=output_image)
249
  reset_btn.click(fn=reset, inputs=[], outputs=[output_image, input_image])
 
 
 
 
 
 
 
 
 
250
 
251
+ demo.launch(server_name="0.0.0.0", server_port=7860)
252
 
253
  if __name__ == "__main__":
254
+ # Gọi hàm startup để ZeroGPU cấp GPU ngay khi Space boot
255
+ gpu_startup()
256
  main()