comrender committed on
Commit
74168bc
·
verified ·
1 Parent(s): 5d03bff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -31
app.py CHANGED
@@ -6,15 +6,70 @@ import gradio as gr
6
  import numpy as np
7
  import spaces
8
  import torch
 
9
  from diffusers import FluxImg2ImgPipeline
10
  from gradio_imageslider import ImageSlider
11
  from PIL import Image
12
  from huggingface_hub import snapshot_download
13
  import requests
14
 
15
- # ESRGAN imports
16
- from basicsr.archs.rrdbnet_arch import RRDBNet
17
- from basicsr.utils import img2tensor, tensor2img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  css = """
20
  #col-container {
@@ -73,12 +128,23 @@ esrgan_model = RRDBNet(
73
  num_grow_ch=32,
74
  scale=4
75
  )
 
 
76
  state_dict = torch.load(esrgan_path, map_location='cpu')
77
  if 'params_ema' in state_dict:
78
  state_dict = state_dict['params_ema']
79
  elif 'params' in state_dict:
80
  state_dict = state_dict['params']
81
- esrgan_model.load_state_dict(state_dict)
 
 
 
 
 
 
 
 
 
82
  esrgan_model.eval()
83
 
84
  print("βœ… All models loaded successfully!")
@@ -114,18 +180,21 @@ def prepare_image(image, max_size=MAX_INPUT_SIZE):
114
  return image
115
 
116
 
117
- def esrgan_upscale(image):
118
  """Upscale image 4x using ESRGAN"""
119
- # Convert PIL to tensor
120
  img_np = np.array(image).astype(np.float32) / 255.
121
- img_tensor = img2tensor(img_np, bgr2rgb=False, float32=True)
 
122
 
123
  # Upscale
124
  with torch.no_grad():
125
- output = esrgan_model(img_tensor.unsqueeze(0).cpu())
126
-
127
- # Convert back to PIL
128
- output_np = tensor2img(output.squeeze(0), rgb2bgr=False, min_max=(0, 1))
 
 
129
  return Image.fromarray(output_np)
130
 
131
 
@@ -159,27 +228,16 @@ def enhance_image(
159
  input_image = prepare_image(input_image)
160
  original_size = input_image.size
161
 
162
- # Step 1: ESRGAN upscale (4x) on CPU
163
  gr.Info("πŸ” Upscaling with ESRGAN 4x...")
164
- with torch.no_grad():
165
- # Move ESRGAN to GPU for faster processing
166
- esrgan_model.to("cuda")
167
-
168
- # Convert image for ESRGAN
169
- img_np = np.array(input_image).astype(np.float32) / 255.
170
- img_tensor = img2tensor(img_np, bgr2rgb=False, float32=True)
171
- img_tensor = img_tensor.unsqueeze(0).to("cuda")
172
-
173
- # Upscale
174
- output_tensor = esrgan_model(img_tensor)
175
-
176
- # Convert back to PIL
177
- output_np = tensor2img(output_tensor.squeeze(0).cpu(), rgb2bgr=False, min_max=(0, 1))
178
- upscaled_image = Image.fromarray(output_np)
179
-
180
- # Move ESRGAN back to CPU to free memory
181
- esrgan_model.to("cpu")
182
- torch.cuda.empty_cache()
183
 
184
  # Ensure dimensions are multiples of 16 for FLUX
185
  w, h = upscaled_image.size
 
6
  import numpy as np
7
  import spaces
8
  import torch
9
+ import torch.nn as nn
10
  from diffusers import FluxImg2ImgPipeline
11
  from gradio_imageslider import ImageSlider
12
  from PIL import Image
13
  from huggingface_hub import snapshot_download
14
  import requests
15
 
16
+ # Minimal ESRGAN implementation (without basicsr dependency)
17
class ResidualDenseBlock(nn.Module):
    """Densely connected five-convolution block with a scaled residual.

    Each of the first four 3x3 convolutions sees the block input concatenated
    along the channel axis with every earlier intermediate output, and emits
    ``num_grow_ch`` channels. The fifth convolution fuses everything back to
    ``num_feat`` channels; that result is scaled by 0.2 and added to the input.

    Attribute names (``conv1`` .. ``conv5``) are part of the checkpoint key
    layout and must not be renamed.

    Args:
        num_feat: Channel count of the block's input and output.
        num_grow_ch: Channels produced by each intermediate convolution.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        # Input width grows by num_grow_ch per stage because of the dense
        # concatenation performed in forward().
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return ``x`` plus 0.2 times the fused dense features.

        Args:
            x: Tensor of shape (N, num_feat, H, W).

        Returns:
            Tensor with the same shape as ``x``.
        """
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        # No activation on the fusion conv: its output is a residual update.
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x
34
+
35
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block: three dense blocks plus a skip.

    The input passes through three ``ResidualDenseBlock`` instances in
    sequence; the result is scaled by 0.2 and added back to the input.

    Attribute names (``rdb1`` .. ``rdb3``) are part of the checkpoint key
    layout and must not be renamed.

    Args:
        num_feat: Channel count of the block's input and output.
        num_grow_ch: Growth channels forwarded to each dense block.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        """Return ``x`` plus 0.2 times the output of the three chained blocks.

        Args:
            x: Feature tensor accepted by the inner dense blocks.

        Returns:
            Tensor with the same shape as ``x``.
        """
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return out * 0.2 + x
47
+
48
class RRDBNet(nn.Module):
    """Minimal ESRGAN generator (RRDB trunk followed by two 2x upsamplers).

    The trunk always upsamples by 4x. To obtain an overall factor of 2 or 1,
    the input is first rearranged with pixel-unshuffle (by 2 or 4), so
    ``conv_first`` receives ``num_in_ch * 4`` or ``num_in_ch * 16`` channels.
    This mirrors the basicsr RRDBNet that this class replaces, so checkpoints
    trained for those scales keep matching parameter shapes. Any other
    ``scale`` value takes the plain 4x path.

    Attribute names are part of the checkpoint key layout and must not be
    renamed.

    Args:
        num_in_ch: Channels of the input image.
        num_out_ch: Channels of the output image.
        num_feat: Width of the feature maps inside the network.
        num_block: Number of RRDB blocks in the trunk.
        num_grow_ch: Growth channels inside each dense block.
        scale: Overall upscaling factor (4, 2 or 1).
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4):
        super().__init__()
        self.scale = scale

        # Pixel-unshuffle trades spatial size for channels, so the first
        # convolution must accept the widened input for scale 2 and 1.
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = nn.Sequential(*[RRDB(num_feat, num_grow_ch) for _ in range(num_block)])
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # Upsampling
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Upscale a batch of images.

        Args:
            x: Tensor of shape (N, C, H, W). For ``scale`` 2 (or 1), H and W
                must be divisible by 2 (or 4) for the pixel-unshuffle step.

        Returns:
            Tensor of shape (N, num_out_ch, H * scale, W * scale) for
            ``scale`` in (1, 2, 4). The output is not clamped.
        """
        if self.scale == 2:
            x = nn.functional.pixel_unshuffle(x, 2)
        elif self.scale == 1:
            x = nn.functional.pixel_unshuffle(x, 4)

        fea = self.conv_first(x)
        trunk = self.conv_body(self.body(fea))
        # Long skip connection around the whole RRDB trunk.
        fea = fea + trunk

        # Two nearest-neighbour 2x upsamples, each refined by a convolution.
        fea = self.lrelu(self.conv_up1(nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.conv_up2(nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(fea)))
        return out
73
 
74
  css = """
75
  #col-container {
 
128
  num_grow_ch=32,
129
  scale=4
130
  )
131
+
132
# Load the ESRGAN checkpoint onto the CPU.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source, or pass
# weights_only=True if the installed torch version supports it.
state_dict = torch.load(esrgan_path, map_location='cpu')

# The weights may be nested under 'params_ema' or 'params'; prefer
# 'params_ema' when both are present.
if 'params_ema' in state_dict:
    state_dict = state_dict['params_ema']
elif 'params' in state_dict:
    state_dict = state_dict['params']

# Drop a leading 'module.' from parameter names so they match the model's own
# keys (presumably left over from a DataParallel wrapper - TODO confirm).
_prefix = 'module.'
cleaned_state_dict = {
    (k[len(_prefix):] if k.startswith(_prefix) else k): v
    for k, v in state_dict.items()
}

# strict=False tolerates key mismatches, but a mismatch means some layers keep
# their random initialisation, so report it instead of discarding the result.
_load_result = esrgan_model.load_state_dict(cleaned_state_dict, strict=False)
if _load_result.missing_keys:
    print(
        f"WARNING: ESRGAN checkpoint is missing {len(_load_result.missing_keys)} "
        f"parameter(s), e.g. {_load_result.missing_keys[:5]}"
    )
if _load_result.unexpected_keys:
    print(
        f"WARNING: ESRGAN checkpoint has {len(_load_result.unexpected_keys)} "
        f"unexpected parameter(s), e.g. {_load_result.unexpected_keys[:5]}"
    )
esrgan_model.eval()

print("✅ All models loaded successfully!")
 
180
  return image
181
 
182
 
183
def esrgan_upscale(image, model, device='cuda'):
    """Upscale an image with an ESRGAN model and return a PIL image.

    The image is converted to a float tensor in [0, 1] with layout
    (1, C, H, W), run through ``model`` without gradient tracking, clamped
    back to [0, 1] and converted to 8-bit RGB. The upscaling factor is
    whatever ``model`` implements.

    Args:
        image: Input PIL image. Non-RGB modes are converted to RGB so the
            array always has three channels in HWC layout.
        model: ESRGAN network, already placed on ``device``.
        device: Device the input tensor is moved to before inference.

    Returns:
        The upscaled image as a ``PIL.Image.Image`` in RGB mode.
    """
    # Greyscale or RGBA input would break the HWC -> CHW transpose or the
    # model's 3-channel first convolution. Objects without a ``mode``
    # attribute (e.g. arrays) are passed through untouched.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # Prepare image
    img_np = np.array(image).astype(np.float32) / 255.
    img_np = np.transpose(img_np, (2, 0, 1))  # HWC to CHW
    img_tensor = torch.from_numpy(img_np).unsqueeze(0).to(device)

    # Upscale
    with torch.no_grad():
        output = model(img_tensor)

    output = output.squeeze(0).cpu().clamp(0, 1)
    output_np = output.numpy()
    output_np = np.transpose(output_np, (1, 2, 0))  # CHW to HWC
    # Round before the uint8 cast: a bare astype() truncates, which biases
    # every pixel downwards by up to one level.
    output_np = np.round(output_np * 255.0).astype(np.uint8)

    return Image.fromarray(output_np)
199
 
200
 
 
228
  input_image = prepare_image(input_image)
229
  original_size = input_image.size
230
 
231
+ # Step 1: ESRGAN upscale (4x) on GPU
232
  gr.Info("πŸ” Upscaling with ESRGAN 4x...")
233
+
234
+ # Move ESRGAN to GPU for faster processing
235
+ esrgan_model.to("cuda")
236
+ upscaled_image = esrgan_upscale(input_image, esrgan_model, device="cuda")
237
+
238
+ # Move ESRGAN back to CPU to free memory
239
+ esrgan_model.to("cpu")
240
+ torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  # Ensure dimensions are multiples of 16 for FLUX
243
  w, h = upscaled_image.size