comrender committed on
Commit 1a431a3 · verified · 1 Parent(s): 3bb8a2e

Update app.py

Files changed (1)
  1. app.py +372 -303
app.py CHANGED
@@ -1,75 +1,26 @@
- import os
  import random
  import warnings
- import gc
  import gradio as gr
  import numpy as np
  import spaces
  import torch
- import torch.nn as nn
  from diffusers import FluxImg2ImgPipeline
  from gradio_imageslider import ImageSlider
  from PIL import Image
  from huggingface_hub import snapshot_download
  import requests

- # Minimal ESRGAN implementation (without basicsr dependency)
- class ResidualDenseBlock(nn.Module):
-     def __init__(self, num_feat=64, num_grow_ch=32):
-         super(ResidualDenseBlock, self).__init__()
-         self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
-         self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
-         self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
-         self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
-         self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
-         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
-     def forward(self, x):
-         x1 = self.lrelu(self.conv1(x))
-         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
-         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
-         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
-         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
-         return x5 * 0.2 + x
-
- class RRDB(nn.Module):
-     def __init__(self, num_feat, num_grow_ch=32):
-         super(RRDB, self).__init__()
-         self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
-         self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
-         self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
-
-     def forward(self, x):
-         out = self.rdb1(x)
-         out = self.rdb2(out)
-         out = self.rdb3(out)
-         return out * 0.2 + x
-
- class RRDBNet(nn.Module):
-     def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4):
-         super(RRDBNet, self).__init__()
-         self.scale = scale
-
-         self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
-         self.body = nn.Sequential(*[RRDB(num_feat, num_grow_ch) for _ in range(num_block)])
-         self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
-
-         # Upsampling
-         self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
-         self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
-         self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
-         self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
-     def forward(self, x):
-         fea = self.conv_first(x)
-         trunk = self.conv_body(self.body(fea))
-         fea = fea + trunk
-
-         fea = self.lrelu(self.conv_up1(nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
-         fea = self.lrelu(self.conv_up2(nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
-         out = self.conv_last(self.lrelu(self.conv_hr(fea)))
-         return out

  css = """
  #col-container {
@@ -82,10 +33,14 @@ css = """
  }
  """

  # Get HuggingFace token
  huggingface_token = os.getenv("HF_TOKEN")

- # Download FLUX model if not already cached
  print("📥 Downloading FLUX model...")
  model_path = snapshot_download(
      repo_id="black-forest-labs/FLUX.1-dev",
@@ -95,276 +50,324 @@ model_path = snapshot_download(
      token=huggingface_token,
  )

- # Load FLUX pipeline on CPU initially
- print("📥 Loading FLUX Img2Img pipeline...")
  pipe = FluxImg2ImgPipeline.from_pretrained(
      model_path,
-     torch_dtype=torch.bfloat16,
-     use_safetensors=True
  )
-
- # Enable memory optimizations
  pipe.enable_vae_tiling()
  pipe.enable_vae_slicing()
- pipe.vae.enable_tiling()
- pipe.vae.enable_slicing()
-
- # Download and load ESRGAN 4x-UltraSharp model
- print("📥 Loading ESRGAN 4x-UltraSharp...")
- esrgan_path = "4x-UltraSharp.pth"
- if not os.path.exists(esrgan_path):
-     print("Downloading ESRGAN model...")
-     url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
-     response = requests.get(url)
-     with open(esrgan_path, "wb") as f:
-         f.write(response.content)
-
- # Initialize ESRGAN model
- esrgan_model = RRDBNet(
-     num_in_ch=3,
-     num_out_ch=3,
-     num_feat=64,
-     num_block=23,
-     num_grow_ch=32,
-     scale=4
- )
-
- # Load state dict
- state_dict = torch.load(esrgan_path, map_location='cpu')
- if 'params_ema' in state_dict:
-     state_dict = state_dict['params_ema']
- elif 'params' in state_dict:
-     state_dict = state_dict['params']
-
- # Clean state dict keys if needed
- cleaned_state_dict = {}
- for k, v in state_dict.items():
-     if k.startswith('module.'):
-         cleaned_state_dict[k[7:]] = v
-     else:
-         cleaned_state_dict[k] = v
-
- esrgan_model.load_state_dict(cleaned_state_dict, strict=False)
- esrgan_model.eval()

  print("✅ All models loaded successfully!")

  MAX_SEED = 1000000
- MAX_INPUT_SIZE = 512  # Max input size before upscaling


- def make_multiple_16(n):
-     """Round to nearest multiple of 16 for FLUX requirements"""
-     return ((n + 15) // 16) * 16


- def truncate_prompt(prompt, max_tokens=75):
-     """Truncate prompt to avoid CLIP token limit (77 tokens)"""
-     if not prompt:
-         return ""
-
-     # Simple truncation by character count (rough approximation)
-     if len(prompt) > 250:  # ~75 tokens
-         return prompt[:250] + "..."
-     return prompt


- def prepare_image(image, max_size=MAX_INPUT_SIZE):
-     """Prepare image for processing"""
-     w, h = image.size
-
-     # Limit input size
-     if w > max_size or h > max_size:
-         image.thumbnail((max_size, max_size), Image.LANCZOS)
-
-     return image


- def esrgan_upscale(image, model, device='cuda', upscale_factor=4):
-     """Upscale image using ESRGAN with variable factor support"""
-     orig_w, orig_h = image.size
-     pre_resize_factor = upscale_factor / 4.0
-     low_res_w = int(orig_w * pre_resize_factor)
-     low_res_h = int(orig_h * pre_resize_factor)
-     if low_res_w < 1 or low_res_h < 1:
-         raise ValueError("Upscale factor too small for image size")
-
-     low_res_image = image.resize((low_res_w, low_res_h), Image.BICUBIC)  # Changed to BICUBIC for better match to training degradation
-
-     # Prepare image
-     img_np = np.array(low_res_image).astype(np.float32) / 255.
-     img_np = np.transpose(img_np, (2, 0, 1))  # HWC to CHW
-     img_tensor = torch.from_numpy(img_np).unsqueeze(0).to(device)
-
-     # Upscale
      with torch.no_grad():
-         output = model(img_tensor)
-         output = output.squeeze(0).cpu().clamp(0, 1)
-     output_np = output.numpy()
-     output_np = np.transpose(output_np, (1, 2, 0))  # CHW to HWC
-     output_np = (output_np * 255).astype(np.uint8)
-
-     upscaled = Image.fromarray(output_np)
-
-     # Resize back to exact target size if needed (due to rounding)
-     target_w = int(orig_w * upscale_factor)
-     target_h = int(orig_h * upscale_factor)
-     if upscaled.size != (target_w, target_h):
-         upscaled = upscaled.resize((target_w, target_h), Image.BICUBIC)  # Changed to BICUBIC
-
-     return upscaled


- @spaces.GPU(duration=120)  # Increased to 120 seconds
  def enhance_image(
-     input_image,
-     prompt,
      seed,
      randomize_seed,
      num_inference_steps,
-     denoising_strength,
      upscale_factor,
      progress=gr.Progress(track_tqdm=True),
  ):
      """Main enhancement function"""
-     if input_image is None:
-         raise gr.Error("Please upload an image")
-
-     # Clear memory
-     torch.cuda.empty_cache()
-     gc.collect()
-
      try:
-         # Randomize seed if needed
-         if randomize_seed:
-             seed = random.randint(0, MAX_SEED)
-
-         # Prepare and validate prompt
-         prompt = truncate_prompt(prompt.strip() if prompt else "high quality, detailed")
-
-         # Prepare input image
-         input_image = prepare_image(input_image)
-         original_size = input_image.size
-
-         # Step 1: ESRGAN upscale on GPU
-         gr.Info(f"🔍 Upscaling with ESRGAN x{upscale_factor}...")
-
-         # Move ESRGAN to GPU for faster processing
-         esrgan_model.to("cuda")
-         upscaled_image = esrgan_upscale(input_image, esrgan_model, device="cuda", upscale_factor=upscale_factor)
-
-         # Move ESRGAN back to CPU to free memory
-         esrgan_model.to("cpu")
-         torch.cuda.empty_cache()
-
-         # Ensure dimensions are multiples of 16 for FLUX
-         w, h = upscaled_image.size
-         new_w = make_multiple_16(w)
-         new_h = make_multiple_16(h)
-
-         if new_w != w or new_h != h:
-             # Pad image to meet requirements
-             padded = Image.new('RGB', (new_w, new_h))
-             padded.paste(upscaled_image, (0, 0))
-             upscaled_image = padded
-
-         # Step 2: FLUX enhancement
-         gr.Info("🎨 Enhancing with FLUX...")
-
-         # Move pipeline to GPU
-         pipe.to("cuda")
-
-         # Generate with FLUX
-         generator = torch.Generator(device="cuda").manual_seed(seed)
-
-         with torch.inference_mode():
-             result = pipe(
-                 prompt=prompt,
-                 image=upscaled_image,
-                 strength=denoising_strength,
-                 num_inference_steps=num_inference_steps,
-                 guidance_scale=3.5,  # Recommended for FLUX.1-dev to reduce artifacts
-                 height=new_h,
-                 width=new_w,
-                 generator=generator,
-             ).images[0]
-
-         # Crop back if we padded
-         if new_w != w or new_h != h:
-             result = result.crop((0, 0, w, h))
-
-         # Move pipeline back to CPU
-         pipe.to("cpu")
-         torch.cuda.empty_cache()
-         gc.collect()
-
-         # Prepare images for slider (before/after)
-         input_resized = input_image.resize(result.size, Image.LANCZOS)
-
-         gr.Info("✅ Enhancement complete!")
-         return [input_resized, result], seed
-
      except Exception as e:
-         # Cleanup on error
-         pipe.to("cpu")
          esrgan_model.to("cpu")
-         torch.cuda.empty_cache()
-         gc.collect()
-         raise gr.Error(f"Enhancement failed: {str(e)}")


  # Create Gradio interface
- with gr.Blocks(css=css) as demo:
      gr.HTML("""
          <div class="main-header">
-             <h1>🚀 Flux Dev Ultimate Upscaler</h1>
-             <p>Upload an image to upscale 2-4x with ESRGAN and enhance with FLUX</p>
-             <p>Optimized for <strong>ZeroGPU</strong> | Max input: 512x512 → Output: up to 2048x2048</p>
          </div>
-     """)
-
      with gr.Row():
          with gr.Column(scale=1):
-             # Input section
-             input_image = gr.Image(
-                 label="Input Image",
-                 type="pil",
-                 height=256
              )

-             prompt = gr.Textbox(
-                 label="Describe image with prompt",
-                 placeholder="Describe the desired enhancement (e.g., 'high quality, sharp details, vibrant colors')",
-                 value="high quality, ultra detailed, sharp",
                  lines=2
              )

-             # Advanced Settings (always open)
              upscale_factor = gr.Slider(
-                 label="Upscale Ratio",
-                 minimum=2,
                  maximum=4,
                  step=1,
-                 value=4,
-                 info="Choose upscale factor (2x, 3x, 4x). Use 4x for best results; lower may cause color artifacts."
              )

              num_inference_steps = gr.Slider(
-                 label="Enhancement Steps",
-                 minimum=10,
-                 maximum=25,
                  step=1,
-                 value=20,  # Increased default for better denoising
                  info="More steps = better quality but slower"
              )

              denoising_strength = gr.Slider(
-                 label="Creativity (Denoising)",
-                 minimum=0.1,
-                 maximum=0.6,
                  step=0.05,
-                 value=0.35,
-                 info="Higher = more changes to the image"
              )

              with gr.Row():
@@ -372,58 +375,124 @@ with gr.Blocks(css=css) as demo:
                  label="Randomize seed",
                  value=True
              )
-             seed = gr.Number(
                  label="Seed",
-                 value=42
              )

              enhance_btn = gr.Button(
-                 "Upscale",
                  variant="primary",
                  size="lg"
              )
-
-         with gr.Column(scale=2):
-             # Output section
              result_slider = ImageSlider(
                  type="pil",
-                 label="Before / After",
-                 interactive=False,
-                 height=512
-             )
-
-             used_seed = gr.Number(
-                 label="Seed Used",
-                 interactive=False,
-                 visible=False
              )
-
      # Event handler
      enhance_btn.click(
          fn=enhance_image,
          inputs=[
              input_image,
-             prompt,
              seed,
              randomize_seed,
              num_inference_steps,
-             denoising_strength,
              upscale_factor,
          ],
-         outputs=[result_slider, used_seed]
      )

      gr.HTML("""
-         <div style="margin-top: 2rem; text-align: center; color: #666;">
-             <p>📌 Pipeline: ESRGAN 2-4x-UltraSharp → FLUX Dev Enhancement</p>
-             <p>⚡ Optimized for ZeroGPU with automatic memory management</p>
-             <p>📌 Note: User is responsible for obtaining commercial license from Flux Dev if using image commercially under their license.</p>
          </div>
      """)

  if __name__ == "__main__":
-     demo.queue(max_size=3).launch(
-         share=False,
-         server_name="0.0.0.0",
-         server_port=7860
-     )
 
+ import logging
  import random
  import warnings
+ import os
  import gradio as gr
  import numpy as np
  import spaces
  import torch
  from diffusers import FluxImg2ImgPipeline
+ from transformers import AutoProcessor, AutoModelForCausalLM
  from gradio_imageslider import ImageSlider
  from PIL import Image
  from huggingface_hub import snapshot_download
  import requests

+ # For ESRGAN (requires pip install basicsr gfpgan)
+ try:
+     from basicsr.archs.rrdbnet_arch import RRDBNet
+     from basicsr.utils import img2tensor, tensor2img
+     USE_ESRGAN = True
+ except ImportError:
+     USE_ESRGAN = False
+     warnings.warn("basicsr not installed; falling back to LANCZOS interpolation.")

  css = """
  #col-container {
  }
  """

+ # Device setup - Default to CPU, let runtime handle GPU
+ power_device = "ZeroGPU"
+ device = "cpu"
+
  # Get HuggingFace token
  huggingface_token = os.getenv("HF_TOKEN")

+ # Download FLUX model
  print("📥 Downloading FLUX model...")
  model_path = snapshot_download(
      repo_id="black-forest-labs/FLUX.1-dev",
      token=huggingface_token,
  )

+ # Load Florence-2 model for image captioning on CPU
+ print("📥 Loading Florence-2 model...")
+ florence_model = AutoModelForCausalLM.from_pretrained(
+     "microsoft/Florence-2-large",
+     torch_dtype=torch.float32,  # Force CPU dtype
+     trust_remote_code=True,
+     attn_implementation="eager"
+ ).to(device)
+ florence_processor = AutoProcessor.from_pretrained(
+     "microsoft/Florence-2-large",
+     trust_remote_code=True
+ )
+
+ # Load FLUX Img2Img pipeline on CPU
+ print("📥 Loading FLUX Img2Img...")
  pipe = FluxImg2ImgPipeline.from_pretrained(
      model_path,
+     torch_dtype=torch.float32  # Force CPU dtype
  )
  pipe.enable_vae_tiling()
  pipe.enable_vae_slicing()

  print("✅ All models loaded successfully!")

+ # Download ESRGAN model if using
+ if USE_ESRGAN:
+     esrgan_path = "4x-UltraSharp.pth"
+     if not os.path.exists(esrgan_path):
+         url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
+         with open(esrgan_path, "wb") as f:
+             f.write(requests.get(url).content)
+     esrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+     state_dict = torch.load(esrgan_path, map_location="cpu")
+     state_dict = state_dict.get('params_ema', state_dict)  # some checkpoints store a bare state dict
+     esrgan_model.load_state_dict(state_dict)
+     esrgan_model.eval()
+
  MAX_SEED = 1000000
+ MAX_PIXEL_BUDGET = 8192 * 8192  # Increased for tiling support


+ def generate_caption(image):
+     """Generate detailed caption using Florence-2"""
+     try:
+         task_prompt = "<MORE_DETAILED_CAPTION>"
+         prompt = task_prompt
+
+         inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(florence_model.device)  # follow the model's current device
+
+         generated_ids = florence_model.generate(
+             input_ids=inputs["input_ids"],
+             pixel_values=inputs["pixel_values"],
+             max_new_tokens=1024,
+             num_beams=3,
+             do_sample=True,
+         )
+
+         generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+         parsed_answer = florence_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
+
+         caption = parsed_answer[task_prompt]
+         return caption
+     except Exception as e:
+         print(f"Caption generation failed: {e}")
+         return "a high quality detailed image"

+ def process_input(input_image, upscale_factor):
+     """Process input image and handle size constraints"""
+     w, h = input_image.size
+     w_original, h_original = w, h
+     aspect_ratio = w / h
+
+     was_resized = False
+
+     if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
+         warnings.warn(
+             f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing to fit budget."
+         )
+         gr.Info(
+             "Requested output image is too large. Resizing input to fit within pixel budget."
+         )
+         target_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
+         scale = (target_input_pixels / (w * h)) ** 0.5
+         new_w = int(w * scale) - int(w * scale) % 8
+         new_h = int(h * scale) - int(h * scale) % 8
+         input_image = input_image.resize((new_w, new_h), resample=Image.LANCZOS)
+         was_resized = True
+
+     return input_image, w_original, h_original, was_resized
+
+
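+ # Worked example for the budget check above: a 4000x3000 input at 4x would
+ # yield 192 MP, over the 8192*8192 ~= 67.1 MP budget, so the input is first
+ # scaled by sqrt((67.1e6 / 16) / 12e6) ~= 0.59 to roughly 2360x1768 (snapped
+ # down to multiples of 8).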
+ def load_image_from_url(url):
+     """Load image from URL"""
+     try:
+         response = requests.get(url, stream=True)
+         response.raise_for_status()
+         return Image.open(response.raw).convert("RGB")  # force 3-channel RGB for the models downstream
+     except Exception as e:
+         raise gr.Error(f"Failed to load image from URL: {e}")
+
+
+ def esrgan_upscale(image, scale=4):
+     if not USE_ESRGAN:
+         return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)
+     img = img2tensor(np.array(image) / 255., bgr2rgb=False, float32=True)
+     model_device = next(esrgan_model.parameters()).device  # follow the model wherever enhance_image moved it
      with torch.no_grad():
+         output = esrgan_model(img.unsqueeze(0).to(model_device)).squeeze()
+     output_img = tensor2img(output, rgb2bgr=False, min_max=(0, 1))
+     return Image.fromarray(output_img)
+
+
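+ # Tiling sketch for the function below: tiles are sampled on a
+ # (tile_size - overlap) stride, so each tile re-covers `overlap` pixels of its
+ # left/top neighbours; those bands are cross-faded with a linear 'L' mask when
+ # the refined tile is pasted back.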
+ def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=32):
+     """Tiled Img2Img to mimic Ultimate SD Upscaler tiling"""
+     w, h = image.size
+     output = image.copy()  # Start with the control image
+
+     for x in range(0, w, tile_size - overlap):
+         for y in range(0, h, tile_size - overlap):
+             tile_w = min(tile_size, w - x)
+             tile_h = min(tile_size, h - y)
+             tile = image.crop((x, y, x + tile_w, y + tile_h))
+
+             # Run Flux on tile
+             gen_tile = pipe(
+                 prompt=prompt,
+                 image=tile,
+                 strength=strength,
+                 num_inference_steps=steps,
+                 guidance_scale=guidance,
+                 height=tile_h,
+                 width=tile_w,
+                 generator=generator,
+             ).images[0]
+             # The pipeline may round tile dims to multiples of 16; match sizes before pasting
+             if gen_tile.size != (tile_w, tile_h):
+                 gen_tile = gen_tile.resize((tile_w, tile_h), resample=Image.LANCZOS)
+
+             # Paste with blending if overlap
+             if overlap > 0:
+                 paste_box = (x, y, x + tile_w, y + tile_h)
+                 if x > 0 or y > 0:
+                     # Simple linear blend on overlaps
+                     mask = Image.new('L', (tile_w, tile_h), 255)
+                     if x > 0:
+                         for i in range(min(overlap, tile_w)):
+                             for j in range(tile_h):
+                                 mask.putpixel((i, j), int(255 * (i / overlap)))
+                     if y > 0:
+                         for i in range(tile_w):
+                             for j in range(min(overlap, tile_h)):
+                                 mask.putpixel((i, j), int(255 * (j / overlap)))
+                     output.paste(gen_tile, paste_box, mask)
+                 else:
+                     output.paste(gen_tile, paste_box)
+             else:
+                 output.paste(gen_tile, (x, y))
+
+     return output
+
+
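+ # ZeroGPU note: @spaces.GPU below asks the Spaces runtime to attach a GPU to
+ # each call of the decorated function, for at most `duration` (120) seconds.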
+ @spaces.GPU(duration=120)
  def enhance_image(
+     image_input,
+     image_url,
      seed,
      randomize_seed,
      num_inference_steps,
      upscale_factor,
+     denoising_strength,
+     use_generated_caption,
+     custom_prompt,
      progress=gr.Progress(track_tqdm=True),
  ):
      """Main enhancement function"""
+     # Move models to GPU with fallback to CPU
      try:
+         device = "cuda"
+         pipe.to(device)
+         florence_model.to(device)
+         if USE_ESRGAN:
+             esrgan_model.to(device)
      except Exception as e:
+         print(f"GPU error: {e}, falling back to CPU")
+         device = "cpu"
+
+     # Handle image input
+     if image_input is not None:
+         input_image = image_input
+     elif image_url:
+         input_image = load_image_from_url(image_url)
+     else:
+         raise gr.Error("Please provide an image (upload or URL)")
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     true_input_image = input_image
+
+     # Process input image
+     input_image, w_original, h_original, was_resized = process_input(
+         input_image, upscale_factor
+     )
+
+     # Generate caption if requested
+     if use_generated_caption:
+         gr.Info("🔍 Generating image caption...")
+         generated_caption = generate_caption(input_image)
+         prompt = generated_caption
+     else:
+         prompt = custom_prompt if custom_prompt.strip() else ""
+
+     generator = torch.Generator(device=device).manual_seed(seed)
+
+     gr.Info("🚀 Upscaling image...")
+
+     # Initial upscale
+     if USE_ESRGAN and upscale_factor == 4:
+         control_image = esrgan_upscale(input_image, upscale_factor)
+     else:
+         w, h = input_image.size
+         control_image = input_image.resize((w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS)
+
+     # Tiled Flux Img2Img for refinement
+     image = tiled_flux_img2img(
+         pipe,
+         prompt,
+         control_image,
+         denoising_strength,
+         num_inference_steps,
+         1.0,  # Hardcoded guidance_scale to 1
+         generator,
+         tile_size=1024,
+         overlap=32
+     )
+
+     if was_resized:
+         gr.Info(f"📏 Resizing output to target size: {w_original * upscale_factor}x{h_original * upscale_factor}")
+         image = image.resize((w_original * upscale_factor, h_original * upscale_factor), resample=Image.LANCZOS)
+
+     # Resize input image to match output size for slider alignment
+     resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
+
+     # Move back to CPU to release GPU
+     pipe.to("cpu")
+     florence_model.to("cpu")
+     if USE_ESRGAN:
          esrgan_model.to("cpu")
+
+     return [resized_input, image]


  # Create Gradio interface
+ with gr.Blocks(css=css, title="🎨 AI Image Upscaler - Florence-2 + FLUX") as demo:
      gr.HTML("""
          <div class="main-header">
+             <h1>🎨 AI Image Upscaler</h1>
+             <p>Upload an image or provide a URL to upscale it using Florence-2 captioning and FLUX upscaling</p>
+             <p>Currently running on <strong>{}</strong></p>
          </div>
+     """.format(power_device))
+
      with gr.Row():
          with gr.Column(scale=1):
+             gr.HTML("<h3>📤 Input</h3>")
+
+             with gr.Tabs():
+                 with gr.TabItem("📁 Upload Image"):
+                     input_image = gr.Image(
+                         label="Upload Image",
+                         type="pil",
+                         height=200  # Made smaller
+                     )
+
+                 with gr.TabItem("🔗 Image URL"):
+                     image_url = gr.Textbox(
+                         label="Image URL",
+                         placeholder="https://example.com/image.jpg",
+                         value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
+                     )
+
+             gr.HTML("<h3>🎛️ Caption Settings</h3>")
+
+             use_generated_caption = gr.Checkbox(
+                 label="Use AI-generated caption (Florence-2)",
+                 value=True,
+                 info="Generate detailed caption automatically"
              )

+             custom_prompt = gr.Textbox(
+                 label="Custom Prompt (optional)",
+                 placeholder="Enter custom prompt or leave empty for generated caption",
                  lines=2
              )

+             gr.HTML("<h3>⚙️ Upscaling Settings</h3>")
+
              upscale_factor = gr.Slider(
+                 label="Upscale Factor",
+                 minimum=1,
                  maximum=4,
                  step=1,
+                 value=2,
+                 info="How much to upscale the image"
              )

              num_inference_steps = gr.Slider(
+                 label="Number of Inference Steps",
+                 minimum=8,
+                 maximum=50,
                  step=1,
+                 value=25,
                  info="More steps = better quality but slower"
              )

              denoising_strength = gr.Slider(
+                 label="Denoising Strength",
+                 minimum=0.0,
+                 maximum=1.0,
                  step=0.05,
+                 value=0.3,
+                 info="Controls how much the image is transformed"
              )

              with gr.Row():
                  label="Randomize seed",
                  value=True
              )
+             seed = gr.Slider(
                  label="Seed",
+                 minimum=0,
+                 maximum=MAX_SEED,
+                 step=1,
+                 value=42,
+                 interactive=True
              )

              enhance_btn = gr.Button(
+                 "🚀 Upscale Image",
                  variant="primary",
                  size="lg"
              )
+
+         with gr.Column(scale=2):  # Larger scale for results
+             gr.HTML("<h3>📊 Results</h3>")
+
              result_slider = ImageSlider(
                  type="pil",
+                 interactive=False,  # Disable interactivity to prevent uploads
+                 height=600,  # Made larger
+                 elem_id="result_slider",
+                 label=None  # Remove default label
              )
+
      # Event handler
      enhance_btn.click(
          fn=enhance_image,
          inputs=[
              input_image,
+             image_url,
              seed,
              randomize_seed,
              num_inference_steps,
              upscale_factor,
+             denoising_strength,
+             use_generated_caption,
+             custom_prompt,
          ],
+         outputs=[result_slider]
      )

      gr.HTML("""
+         <div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
+             <p><strong>Note:</strong> This upscaler uses the FLUX.1-dev model. Users are responsible for obtaining commercial rights if the output is used commercially, per the model's license.</p>
          </div>
      """)
+
+     # Custom CSS for slider
+     gr.HTML("""
+     <style>
+     #result_slider .slider {
+         width: 100% !important;
+         max-width: inherit !important;
+     }
+     #result_slider img {
+         object-fit: contain !important;
+         width: 100% !important;
+         height: auto !important;
+     }
+     #result_slider .gr-button-tool {
+         display: none !important;
+     }
+     #result_slider .gr-button-undo {
+         display: none !important;
+     }
+     #result_slider .gr-button-clear {
+         display: none !important;
+     }
+     #result_slider .badge-container .badge {
+         display: none !important;
+     }
+     #result_slider .badge-container::before {
+         content: "Before";
+         position: absolute;
+         top: 10px;
+         left: 10px;
+         background: rgba(0,0,0,0.5);
+         color: white;
+         padding: 5px;
+         border-radius: 5px;
+         z-index: 10;
+     }
+     #result_slider .badge-container::after {
+         content: "After";
+         position: absolute;
+         top: 10px;
+         right: 10px;
+         background: rgba(0,0,0,0.5);
+         color: white;
+         padding: 5px;
+         border-radius: 5px;
+         z-index: 10;
+     }
+     #result_slider .fullscreen img {
+         object-fit: contain !important;
+         width: 100vw !important;
+         height: 100vh !important;
+         position: absolute;
+         top: 0;
+         left: 0;
+     }
+     </style>
+     """)
+
+     # JS to set slider default position to middle
+     gr.HTML("""
+     <script>
+     document.addEventListener('DOMContentLoaded', function() {
+         const sliderInput = document.querySelector('#result_slider input[type="range"]');
+         if (sliderInput) {
+             sliderInput.value = 50;
+             sliderInput.dispatchEvent(new Event('input'));
+         }
+     });
+     </script>
+     """)

  if __name__ == "__main__":
+     demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)