comrender committed on
Commit b0a0a29 · verified · 1 parent: a6eb212

Update app.py

Files changed (1)
  1. app.py +230 -357
app.py CHANGED
@@ -12,15 +12,10 @@ from gradio_imageslider import ImageSlider
12
  from PIL import Image
13
  from huggingface_hub import snapshot_download
14
  import requests
 
15
 
16
- # For ESRGAN (optional - will work without it)
17
- try:
18
- from basicsr.archs.rrdbnet_arch import RRDBNet
19
- from basicsr.utils import img2tensor, tensor2img
20
- USE_ESRGAN = True
21
- except ImportError:
22
- USE_ESRGAN = False
23
- warnings.warn("basicsr not installed; falling back to LANCZOS interpolation.")
24
 
25
  css = """
26
  #col-container {
@@ -35,7 +30,7 @@ css = """
35
 
36
  # Device setup
37
  power_device = "ZeroGPU"
38
- device = "cpu" # Start on CPU, will move to GPU when needed
39
 
40
  # Get HuggingFace token
41
  huggingface_token = os.getenv("HF_TOKEN")
@@ -50,85 +45,88 @@ model_path = snapshot_download(
50
  token=huggingface_token,
51
  )
52
 
53
- # Load Florence-2 model for image captioning on CPU
54
  print("📥 Loading Florence-2 model...")
55
  florence_model = AutoModelForCausalLM.from_pretrained(
56
  "microsoft/Florence-2-large",
57
- torch_dtype=torch.float32, # Use float32 on CPU to avoid dtype issues
58
  trust_remote_code=True,
59
  attn_implementation="eager"
60
- ).to(device)
 
61
  florence_processor = AutoProcessor.from_pretrained(
62
  "microsoft/Florence-2-large",
63
  trust_remote_code=True
64
  )
65
 
66
- # Load FLUX Img2Img pipeline on CPU
67
  print("📥 Loading FLUX Img2Img...")
68
  pipe = FluxImg2ImgPipeline.from_pretrained(
69
  model_path,
70
- torch_dtype=torch.float32 # Start with float32 on CPU
71
  )
 
 
 
72
  pipe.enable_vae_tiling()
73
  pipe.enable_vae_slicing()
 
 
74
 
75
  print("✅ All models loaded successfully!")
76
 
77
- # Download ESRGAN model if using
78
- if USE_ESRGAN:
79
- try:
80
- esrgan_path = "4x-UltraSharp.pth"
81
- if not os.path.exists(esrgan_path):
82
- url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
83
- print("📥 Downloading ESRGAN model...")
84
- with open(esrgan_path, "wb") as f:
85
- f.write(requests.get(url).content)
86
- esrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
87
- state_dict = torch.load(esrgan_path, map_location='cpu')['params_ema']
88
- esrgan_model.load_state_dict(state_dict)
89
- esrgan_model.eval()
90
- print("✅ ESRGAN model loaded!")
91
- except Exception as e:
92
- print(f"Failed to load ESRGAN: {e}")
93
- USE_ESRGAN = False
94
-
95
  MAX_SEED = 1000000
96
- MAX_PIXEL_BUDGET = 8192 * 8192
 
 
97
 
98
 
99
  def make_multiple_16(n):
100
- """Round up to nearest multiple of 16"""
101
  return ((n + 15) // 16) * 16
102
 
103
 
104
  def generate_caption(image):
105
- """Generate detailed caption using Florence-2"""
106
  try:
107
- # Ensure model is on the correct device with correct dtype
108
- if florence_model.device.type == "cuda":
109
- florence_model.to(torch.float16)
110
-
111
  task_prompt = "<MORE_DETAILED_CAPTION>"
112
- prompt = task_prompt
113
-
 
 
 
114
  inputs = florence_processor(
115
- text=prompt,
116
  images=image,
117
  return_tensors="pt"
118
- ).to(florence_model.device)
119
-
120
- # Ensure dtype consistency
121
- if florence_model.device.type == "cuda":
122
- if hasattr(inputs, "pixel_values"):
123
- inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)
124
-
125
- generated_ids = florence_model.generate(
126
- input_ids=inputs["input_ids"],
127
- pixel_values=inputs["pixel_values"],
128
- max_new_tokens=1024,
129
- num_beams=3,
130
- do_sample=True,
131
- )
132
 
133
  generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
134
  parsed_answer = florence_processor.post_process_generation(
@@ -138,213 +136,57 @@ def generate_caption(image):
138
  )
139
 
140
  caption = parsed_answer[task_prompt]
 
 
141
  return caption
 
142
  except Exception as e:
143
  print(f"Caption generation failed: {e}")
144
- return "a high quality detailed image"
145
 
146
 
147
  def process_input(input_image, upscale_factor):
148
- """Process input image and handle size constraints"""
149
  w, h = input_image.size
150
  w_original, h_original = w, h
151
 
152
  was_resized = False
153
 
 
154
  if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
155
- warnings.warn(
156
- f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing to fit budget."
157
- )
158
- gr.Info(
159
- f"Requested output image is too large. Resizing input to fit within pixel budget."
160
- )
161
- target_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
162
- scale = (target_input_pixels / (w * h)) ** 0.5
163
  new_w = make_multiple_16(int(w * scale))
164
  new_h = make_multiple_16(int(h * scale))
165
- input_image = input_image.resize((new_w, new_h), resample=Image.LANCZOS)
 
166
  was_resized = True
167
 
168
- return input_image, w_original, h_original, was_resized
169
-
170
-
171
- def load_image_from_url(url):
172
- """Load image from URL"""
173
- try:
174
- response = requests.get(url, stream=True)
175
- response.raise_for_status()
176
- return Image.open(response.raw)
177
- except Exception as e:
178
- raise gr.Error(f"Failed to load image from URL: {e}")
179
-
180
-
181
- def esrgan_upscale(image, scale=4):
182
- """Upscale image using ESRGAN or fallback to LANCZOS"""
183
- if not USE_ESRGAN:
184
- return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)
185
-
186
- try:
187
- img = img2tensor(np.array(image) / 255., bgr2rgb=False, float32=True)
188
- with torch.no_grad():
189
- # Move model to same device as image tensor
190
- if torch.cuda.is_available():
191
- esrgan_model.to("cuda")
192
- img = img.to("cuda")
193
- output = esrgan_model(img.unsqueeze(0)).squeeze()
194
- output_img = tensor2img(output, rgb2bgr=False, min_max=(0, 1))
195
- return Image.fromarray(output_img)
196
- except Exception as e:
197
- print(f"ESRGAN upscale failed: {e}, falling back to LANCZOS")
198
- return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)
199
-
200
-
201
- def create_blend_mask(width, height, overlap, edge_x, edge_y):
202
- """Create a gradient blend mask for smooth tile transitions"""
203
- mask = Image.new('L', (width, height), 255)
204
- pixels = mask.load()
205
-
206
- # Horizontal blend (left edge)
207
- if edge_x and overlap > 0:
208
- for x in range(min(overlap, width)):
209
- alpha = x / overlap
210
- for y in range(height):
211
- pixels[x, y] = int(255 * alpha)
212
 
213
- # Vertical blend (top edge)
214
- if edge_y and overlap > 0:
215
- for y in range(min(overlap, height)):
216
- alpha = y / overlap
217
- for x in range(width):
218
- # Combine with existing alpha if both edges
219
- existing = pixels[x, y] / 255.0
220
- combined = min(existing, alpha)
221
- pixels[x, y] = int(255 * combined)
222
 
223
- return mask
224
 
225
 
226
- def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=64):
227
- """Tiled Img2Img to handle large images"""
228
- w, h = image.size
229
-
230
- # Ensure tile_size is divisible by 16
231
- tile_size = make_multiple_16(tile_size)
232
- overlap = make_multiple_16(overlap)
233
-
234
- # If image is small enough, process without tiling
235
- if w <= tile_size and h <= tile_size:
236
- # Ensure dimensions are divisible by 16
237
- new_w = make_multiple_16(w)
238
- new_h = make_multiple_16(h)
239
-
240
- if new_w != w or new_h != h:
241
- padded_image = Image.new('RGB', (new_w, new_h))
242
- padded_image.paste(image, (0, 0))
243
- else:
244
- padded_image = image
245
-
246
- result = pipe(
247
- prompt=prompt,
248
- image=padded_image,
249
- strength=strength,
250
- num_inference_steps=steps,
251
- guidance_scale=guidance,
252
- height=new_h,
253
- width=new_w,
254
- generator=generator,
255
- ).images[0]
256
-
257
- # Crop back to original size if padded
258
- if new_w != w or new_h != h:
259
- result = result.crop((0, 0, w, h))
260
-
261
- return result
262
-
263
- # Process with tiling for large images
264
- output = Image.new('RGB', (w, h))
265
-
266
- # Calculate tile positions
267
- tiles = []
268
- for y in range(0, h, tile_size - overlap):
269
- for x in range(0, w, tile_size - overlap):
270
- tile_w = min(tile_size, w - x)
271
- tile_h = min(tile_size, h - y)
272
-
273
- # Ensure tile dimensions are divisible by 16
274
- tile_w_padded = make_multiple_16(tile_w)
275
- tile_h_padded = make_multiple_16(tile_h)
276
-
277
- tiles.append({
278
- 'x': x,
279
- 'y': y,
280
- 'w': tile_w,
281
- 'h': tile_h,
282
- 'w_padded': tile_w_padded,
283
- 'h_padded': tile_h_padded,
284
- 'edge_x': x > 0,
285
- 'edge_y': y > 0
286
- })
287
-
288
- # Process each tile
289
- for i, tile_info in enumerate(tiles):
290
- print(f"Processing tile {i+1}/{len(tiles)}...")
291
-
292
- # Extract tile from image
293
- tile = image.crop((
294
- tile_info['x'],
295
- tile_info['y'],
296
- tile_info['x'] + tile_info['w'],
297
- tile_info['y'] + tile_info['h']
298
- ))
299
-
300
- # Pad if necessary
301
- if tile_info['w_padded'] != tile_info['w'] or tile_info['h_padded'] != tile_info['h']:
302
- padded_tile = Image.new('RGB', (tile_info['w_padded'], tile_info['h_padded']))
303
- padded_tile.paste(tile, (0, 0))
304
- tile = padded_tile
305
-
306
- # Process tile with FLUX
307
- try:
308
- gen_tile = pipe(
309
- prompt=prompt,
310
- image=tile,
311
- strength=strength,
312
- num_inference_steps=steps,
313
- guidance_scale=guidance,
314
- height=tile_info['h_padded'],
315
- width=tile_info['w_padded'],
316
- generator=generator,
317
- ).images[0]
318
-
319
- # Crop back to original tile size if padded
320
- if tile_info['w_padded'] != tile_info['w'] or tile_info['h_padded'] != tile_info['h']:
321
- gen_tile = gen_tile.crop((0, 0, tile_info['w'], tile_info['h']))
322
-
323
- # Create blend mask if needed
324
- if overlap > 0 and (tile_info['edge_x'] or tile_info['edge_y']):
325
- mask = create_blend_mask(
326
- tile_info['w'],
327
- tile_info['h'],
328
- overlap,
329
- tile_info['edge_x'],
330
- tile_info['edge_y']
331
- )
332
-
333
- # Composite with blending
334
- output.paste(gen_tile, (tile_info['x'], tile_info['y']), mask)
335
- else:
336
- # Direct paste for first tile or no overlap
337
- output.paste(gen_tile, (tile_info['x'], tile_info['y']))
338
-
339
- except Exception as e:
340
- print(f"Error processing tile: {e}")
341
- # Fallback: paste original tile
342
- output.paste(tile, (tile_info['x'], tile_info['y']))
343
-
344
- return output
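The tiling helper removed above steps the grid by tile_size - overlap and feathers each tile's left/top edge before pasting it back. As a minimal standalone sketch of that blending step (illustrative only, not part of this commit), the same ramp mask can be built and applied with plain PIL:

    from PIL import Image

    def left_edge_mask(width, height, overlap):
        # Alpha ramps from 0 to 255 across the first `overlap` columns and is
        # fully opaque elsewhere, like create_blend_mask(..., edge_x=True, edge_y=False).
        mask = Image.new('L', (width, height), 255)
        px = mask.load()
        for x in range(min(overlap, width)):
            alpha = int(255 * x / overlap)
            for y in range(height):
                px[x, y] = alpha
        return mask

    # Toy usage: blend a 64x64 tile onto a canvas with a 16 px feathered seam.
    canvas = Image.new('RGB', (96, 64), 'blue')
    tile = Image.new('RGB', (64, 64), 'red')
    canvas.paste(tile, (32, 0), left_edge_mask(64, 64, 16))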
345
 
346
 
347
- @spaces.GPU(duration=120)
348
  def enhance_image(
349
  image_input,
350
  image_url,
@@ -357,223 +199,254 @@ def enhance_image(
357
  custom_prompt,
358
  progress=gr.Progress(track_tqdm=True),
359
  ):
360
- """Main enhancement function"""
361
  try:
362
- # Move models to GPU and convert to appropriate dtype
363
- pipe.to("cuda")
364
- pipe.to(torch.bfloat16)
365
-
366
- florence_model.to("cuda")
367
- florence_model.to(torch.float16)
368
 
369
  # Handle image input
370
  if image_input is not None:
371
  input_image = image_input
372
  elif image_url:
373
- input_image = load_image_from_url(image_url)
 
 
374
  else:
375
- raise gr.Error("Please provide an image (upload or URL)")
376
 
377
  if randomize_seed:
378
  seed = random.randint(0, MAX_SEED)
379
 
380
- true_input_image = input_image
381
 
382
- # Process input image
383
- input_image, w_original, h_original, was_resized = process_input(
384
  input_image, upscale_factor
385
  )
386
 
387
- # Generate caption if requested
388
  if use_generated_caption:
389
- gr.Info("🔍 Generating image caption...")
390
- generated_caption = generate_caption(input_image)
391
- prompt = generated_caption
392
- print(f"Generated caption: {prompt}")
393
  else:
394
- prompt = custom_prompt if custom_prompt.strip() else ""
 
395
 
396
- generator = torch.Generator(device="cuda").manual_seed(seed)
 
 
397
 
398
- gr.Info("🚀 Upscaling image...")
 
399
 
400
- # Initial upscale
401
- if USE_ESRGAN and upscale_factor == 4:
402
- if torch.cuda.is_available():
403
- esrgan_model.to("cuda")
404
- control_image = esrgan_upscale(input_image, upscale_factor)
405
- if torch.cuda.is_available():
406
- esrgan_model.to("cpu")
407
- else:
408
- w, h = input_image.size
409
- control_image = input_image.resize(
410
- (w * upscale_factor, h * upscale_factor),
411
- resample=Image.LANCZOS
412
- )
413
 
414
- # Tiled Flux Img2Img for refinement
415
- image = tiled_flux_img2img(
416
- pipe,
417
- prompt,
418
- control_image,
419
- denoising_strength,
420
- num_inference_steps,
421
- 1.0, # guidance_scale fixed to 1.0
422
- generator,
423
- tile_size=1024,
424
- overlap=64
425
- )
426
 
 
 
 
427
  if was_resized:
428
- gr.Info(f"📏 Resizing output to target size: {w_original * upscale_factor}x{h_original * upscale_factor}")
429
- image = image.resize(
430
- (w_original * upscale_factor, h_original * upscale_factor),
431
- resample=Image.LANCZOS
432
  )
433
 
434
- # Resize input image to match output size for slider alignment
435
- resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
436
 
437
- # Move models back to CPU to release GPU
438
  pipe.to("cpu")
439
- florence_model.to("cpu")
440
  torch.cuda.empty_cache()
 
441
 
442
- return [resized_input, image]
443
 
444
  except Exception as e:
445
- # Ensure models are moved back to CPU even on error
446
  pipe.to("cpu")
447
- florence_model.to("cpu")
448
  torch.cuda.empty_cache()
449
- raise gr.Error(f"Enhancement failed: {str(e)}")
 
450
 
451
 
452
- # Create Gradio interface
453
- with gr.Blocks(css=css, title="🎨 AI Image Upscaler - Florence-2 + FLUX") as demo:
454
  gr.HTML(f"""
455
  <div class="main-header">
456
  <h1>🎨 AI Image Upscaler</h1>
457
- <p>Upload an image or provide a URL to upscale it using Florence-2 captioning and FLUX upscaling</p>
458
- <p>Currently running on <strong>{power_device}</strong></p>
459
  </div>
460
  """)
461
-
462
  with gr.Row():
463
  with gr.Column(scale=1):
464
  gr.HTML("<h3>📤 Input</h3>")
465
 
466
  with gr.Tabs():
467
- with gr.TabItem("📁 Upload Image"):
468
  input_image = gr.Image(
469
  label="Upload Image",
470
  type="pil",
471
  height=200
472
  )
473
 
474
- with gr.TabItem("🔗 Image URL"):
475
  image_url = gr.Textbox(
476
  label="Image URL",
477
- placeholder="https://example.com/image.jpg",
478
- value=""
479
  )
480
 
481
- gr.HTML("<h3>🎛️ Caption Settings</h3>")
482
-
483
  use_generated_caption = gr.Checkbox(
484
- label="Use AI-generated caption (Florence-2)",
485
- value=True,
486
- info="Generate detailed caption automatically"
487
  )
488
 
489
  custom_prompt = gr.Textbox(
490
  label="Custom Prompt (optional)",
491
- placeholder="Enter custom prompt or leave empty for generated caption",
492
  lines=2
493
  )
494
 
495
- gr.HTML("<h3>⚙️ Upscaling Settings</h3>")
496
-
497
  upscale_factor = gr.Slider(
498
  label="Upscale Factor",
499
- minimum=1,
500
  maximum=4,
501
  step=1,
502
- value=2,
503
- info="How much to upscale the image"
504
  )
505
 
506
  num_inference_steps = gr.Slider(
507
- label="Number of Inference Steps",
508
- minimum=8,
509
- maximum=50,
510
  step=1,
511
- value=25,
512
- info="More steps = better quality but slower"
513
  )
514
 
515
  denoising_strength = gr.Slider(
516
- label="Denoising Strength",
517
- minimum=0.0,
518
- maximum=1.0,
519
  step=0.05,
520
  value=0.3,
521
- info="Controls how much the image is transformed"
522
  )
523
 
524
  with gr.Row():
525
- randomize_seed = gr.Checkbox(
526
- label="Randomize seed",
527
- value=True
528
- )
529
  seed = gr.Slider(
530
  label="Seed",
531
  minimum=0,
532
  maximum=MAX_SEED,
533
  step=1,
534
- value=42,
535
- interactive=True
536
  )
537
 
538
- enhance_btn = gr.Button(
539
- "🚀 Upscale Image",
540
- variant="primary",
541
- size="lg"
542
- )
543
-
544
  with gr.Column(scale=2):
545
- gr.HTML("<h3>📊 Results</h3>")
546
-
547
  result_slider = ImageSlider(
548
  type="pil",
549
  interactive=False,
550
- height=600,
551
- elem_id="result_slider",
552
  label=None
553
  )
554
-
555
- # Event handler
556
  enhance_btn.click(
557
  fn=enhance_image,
558
  inputs=[
559
- input_image,
560
- image_url,
561
- seed,
562
- randomize_seed,
563
- num_inference_steps,
564
- upscale_factor,
565
- denoising_strength,
566
- use_generated_caption,
567
- custom_prompt,
568
  ],
569
  outputs=[result_slider]
570
  )
571
 
572
  gr.HTML("""
573
- <div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
574
- <p><strong>Note:</strong> This upscaler uses the Flux dev model. Users are responsible for obtaining commercial rights if used commercially under their license.</p>
575
  </div>
576
  """)
577
 
578
  if __name__ == "__main__":
579
- demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
12
  from PIL import Image
13
  from huggingface_hub import snapshot_download
14
  import requests
15
+ import gc
16
 
17
+ # Disable ESRGAN for ZeroGPU (saves memory and complexity)
18
+ USE_ESRGAN = False
 
 
19
 
20
  css = """
21
  #col-container {
 
30
 
31
  # Device setup
32
  power_device = "ZeroGPU"
33
+ device = "cpu" # Start on CPU
34
 
35
  # Get HuggingFace token
36
  huggingface_token = os.getenv("HF_TOKEN")
 
45
  token=huggingface_token,
46
  )
47
 
48
+ # Load Florence-2 model
49
  print("📥 Loading Florence-2 model...")
50
  florence_model = AutoModelForCausalLM.from_pretrained(
51
  "microsoft/Florence-2-large",
52
+ torch_dtype=torch.float32,
53
  trust_remote_code=True,
54
  attn_implementation="eager"
55
+ ).to(device).eval()
56
+
57
  florence_processor = AutoProcessor.from_pretrained(
58
  "microsoft/Florence-2-large",
59
  trust_remote_code=True
60
  )
61
 
62
+ # Load FLUX pipeline
63
  print("📥 Loading FLUX Img2Img...")
64
  pipe = FluxImg2ImgPipeline.from_pretrained(
65
  model_path,
66
+ torch_dtype=torch.float32
67
  )
68
+
69
+ # Enable memory optimizations
70
+ pipe.enable_model_cpu_offload()
71
  pipe.enable_vae_tiling()
72
  pipe.enable_vae_slicing()
73
+ pipe.vae.enable_tiling()
74
+ pipe.vae.enable_slicing()
75
 
76
  print("✅ All models loaded successfully!")
77
 
 
 
78
  MAX_SEED = 1000000
79
+ MAX_PIXEL_BUDGET = 2048 * 2048 # Reduced for ZeroGPU stability
80
+
81
+
82
+ def truncate_caption(caption, max_tokens=70):
83
+ """Truncate caption to avoid CLIP token limit"""
84
+ words = caption.split()
85
+ truncated = []
86
+ current_length = 0
87
+
88
+ for word in words:
89
+ # Rough estimate: 1 word ≈ 1.3 tokens
90
+ if current_length + len(word) * 1.3 > max_tokens:
91
+ break
92
+ truncated.append(word)
93
+ current_length += len(word) * 1.3
94
+
95
+ result = ' '.join(truncated)
96
+ if len(truncated) < len(words):
97
+ result += "..."
98
+ return result
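A quick way to sanity-check the word-budget heuristic above (a sketch that assumes truncate_caption is importable exactly as defined in this commit):

    long_caption = " ".join(["photo"] * 200)               # deliberately over-long caption
    short = truncate_caption(long_caption, max_tokens=70)
    print(len(short.split()), short.endswith("..."))       # far fewer words, and True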
99
 
100
 
101
  def make_multiple_16(n):
102
+ """Round to nearest multiple of 16"""
103
  return ((n + 15) // 16) * 16
104
 
105
 
106
  def generate_caption(image):
107
+ """Generate caption using Florence-2"""
108
  try:
109
+ # Keep on CPU for caption generation
 
110
  task_prompt = "<MORE_DETAILED_CAPTION>"
111
+
112
+ # Resize image if too large for captioning
113
+ if image.width > 1024 or image.height > 1024:
114
+ image.thumbnail((1024, 1024), Image.LANCZOS)
115
+
116
  inputs = florence_processor(
117
+ text=task_prompt,
118
  images=image,
119
  return_tensors="pt"
120
+ ).to(device)
121
+
122
+ with torch.no_grad():
123
+ generated_ids = florence_model.generate(
124
+ input_ids=inputs["input_ids"],
125
+ pixel_values=inputs["pixel_values"],
126
+ max_new_tokens=256, # Reduced from 1024
127
+ num_beams=1, # Reduced from 3
128
+ do_sample=False, # Faster without sampling
129
+ )
 
130
 
131
  generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
132
  parsed_answer = florence_processor.post_process_generation(
 
136
  )
137
 
138
  caption = parsed_answer[task_prompt]
139
+ # Truncate to avoid CLIP token limit
140
+ caption = truncate_caption(caption, max_tokens=70)
141
  return caption
142
+
143
  except Exception as e:
144
  print(f"Caption generation failed: {e}")
145
+ return "high quality detailed image"
146
 
147
 
148
  def process_input(input_image, upscale_factor):
149
+ """Process input image with size constraints"""
150
  w, h = input_image.size
151
  w_original, h_original = w, h
152
 
153
  was_resized = False
154
 
155
+ # Check pixel budget
156
  if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
157
+ gr.Info("Resizing input to fit within processing limits...")
158
+
159
+ target_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
160
+ scale = (target_pixels / (w * h)) ** 0.5
161
+
 
 
 
162
  new_w = make_multiple_16(int(w * scale))
163
  new_h = make_multiple_16(int(h * scale))
164
+
165
+ input_image = input_image.resize((new_w, new_h), Image.LANCZOS)
166
  was_resized = True
167
 
168
+ # Ensure dimensions are multiples of 16
169
+ w, h = input_image.size
170
+ new_w = make_multiple_16(w)
171
+ new_h = make_multiple_16(h)
 
 
172
 
173
+ if new_w != w or new_h != h:
174
+ padded = Image.new('RGB', (new_w, new_h))
175
+ padded.paste(input_image, (0, 0))
176
+ input_image = padded
 
 
177
 
178
+ return input_image, w_original, h_original, was_resized
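To make the pixel-budget arithmetic concrete, here is a small self-contained example with illustrative numbers (not taken from the commit): with a 2048x2048 budget and a 4x factor, the input is scaled down to roughly 512x512 worth of pixels.

    import math

    MAX_PIXEL_BUDGET = 2048 * 2048

    def make_multiple_16(n):
        # Round up to the next multiple of 16, matching make_multiple_16() above.
        return ((n + 15) // 16) * 16

    w, h, factor = 1600, 1200, 4
    target_pixels = MAX_PIXEL_BUDGET / factor ** 2        # 262144 input pixels allowed
    scale = math.sqrt(target_pixels / (w * h))            # about 0.37
    print(make_multiple_16(int(w * scale)),               # 592
          make_multiple_16(int(h * scale)))               # 448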
179
 
180
 
181
+ def simple_upscale(image, factor):
182
+ """Simple LANCZOS upscaling"""
183
+ return image.resize(
184
+ (image.width * factor, image.height * factor),
185
+ Image.LANCZOS
186
+ )
 
 
 
187
 
188
 
189
+ @spaces.GPU(duration=90) # Reduced from 120
190
  def enhance_image(
191
  image_input,
192
  image_url,
 
199
  custom_prompt,
200
  progress=gr.Progress(track_tqdm=True),
201
  ):
202
+ """Main enhancement function optimized for ZeroGPU"""
203
  try:
204
+ # Clear cache at start
205
+ torch.cuda.empty_cache()
206
+ gc.collect()
 
 
 
207
 
208
  # Handle image input
209
  if image_input is not None:
210
  input_image = image_input
211
  elif image_url:
212
+ response = requests.get(image_url, stream=True)
213
+ response.raise_for_status()
214
+ input_image = Image.open(response.raw)
215
  else:
216
+ raise gr.Error("Please provide an image")
217
 
218
  if randomize_seed:
219
  seed = random.randint(0, MAX_SEED)
220
 
221
+ original_image = input_image.copy()
222
 
223
+ # Process and validate input
224
+ input_image, w_orig, h_orig, was_resized = process_input(
225
  input_image, upscale_factor
226
  )
227
 
228
+ # Generate or use caption (keep on CPU)
229
  if use_generated_caption:
230
+ gr.Info("Generating caption...")
231
+ prompt = generate_caption(input_image)
232
+ print(f"Caption: {prompt}")
 
233
  else:
234
+ prompt = custom_prompt.strip() if custom_prompt else "high quality image"
235
+ prompt = truncate_caption(prompt, max_tokens=70)
236
 
237
+ # Initial upscale with LANCZOS
238
+ gr.Info("Upscaling image...")
239
+ upscaled = simple_upscale(input_image, upscale_factor)
240
 
241
+ # Move pipeline to GPU only when needed
242
+ pipe.to("cuda")
243
 
244
+ # For large images, process in smaller chunks
245
+ w, h = upscaled.size
 
 
246
 
247
+ # Determine if we need tiling based on size
248
+ need_tiling = (w > 1536 or h > 1536)
 
 
249
 
250
+ if need_tiling:
251
+ gr.Info("Processing large image in tiles...")
252
+ # Process center crop for now (to avoid timeout)
253
+ crop_size = min(1024, w, h)
254
+ left = (w - crop_size) // 2
255
+ top = (h - crop_size) // 2
256
+
257
+ cropped = upscaled.crop((left, top, left + crop_size, top + crop_size))
258
+
259
+ # Ensure dimensions are multiples of 16
260
+ crop_w = make_multiple_16(cropped.width)
261
+ crop_h = make_multiple_16(cropped.height)
262
+
263
+ if crop_w != cropped.width or crop_h != cropped.height:
264
+ padded_crop = Image.new('RGB', (crop_w, crop_h))
265
+ padded_crop.paste(cropped, (0, 0))
266
+ cropped = padded_crop
267
+
268
+ # Process with FLUX
269
+ with torch.inference_mode():
270
+ generator = torch.Generator(device="cuda").manual_seed(seed)
271
+
272
+ result_crop = pipe(
273
+ prompt=prompt,
274
+ image=cropped,
275
+ strength=denoising_strength,
276
+ num_inference_steps=num_inference_steps,
277
+ guidance_scale=1.0,
278
+ height=crop_h,
279
+ width=crop_w,
280
+ generator=generator,
281
+ ).images[0]
282
+
283
+ # Crop back if padded
284
+ if crop_w != cropped.width or crop_h != cropped.height:
285
+ result_crop = result_crop.crop((0, 0, cropped.width, cropped.height))
286
+
287
+ # Paste enhanced crop back
288
+ result = upscaled.copy()
289
+ result.paste(result_crop, (left, top))
290
+
291
+ else:
292
+ # Process entire image if small enough
293
+ # Ensure dimensions are multiples of 16
294
+ proc_w = make_multiple_16(w)
295
+ proc_h = make_multiple_16(h)
296
+
297
+ if proc_w != w or proc_h != h:
298
+ padded = Image.new('RGB', (proc_w, proc_h))
299
+ padded.paste(upscaled, (0, 0))
300
+ upscaled = padded
301
+
302
+ with torch.inference_mode():
303
+ generator = torch.Generator(device="cuda").manual_seed(seed)
304
+
305
+ result = pipe(
306
+ prompt=prompt,
307
+ image=upscaled,
308
+ strength=denoising_strength,
309
+ num_inference_steps=num_inference_steps,
310
+ guidance_scale=1.0,
311
+ height=proc_h,
312
+ width=proc_w,
313
+ generator=generator,
314
+ ).images[0]
315
+
316
+ # Crop back if padded
317
+ if proc_w != w or proc_h != h:
318
+ result = result.crop((0, 0, w, h))
319
+
320
+ # Final resize if needed
321
  if was_resized:
322
+ result = result.resize(
323
+ (w_orig * upscale_factor, h_orig * upscale_factor),
324
+ Image.LANCZOS
 
325
  )
326
 
327
+ # Prepare for slider
328
+ input_resized = original_image.resize(result.size, Image.LANCZOS)
329
 
330
+ # Clean up
331
  pipe.to("cpu")
 
332
  torch.cuda.empty_cache()
333
+ gc.collect()
334
 
335
+ return [input_resized, result]
336
 
337
  except Exception as e:
338
+ # Ensure cleanup on error
339
  pipe.to("cpu")
 
340
  torch.cuda.empty_cache()
341
+ gc.collect()
342
+ raise gr.Error(f"Processing failed: {str(e)}")
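The function above repeats the CPU-offload and cache-clearing in both the success path and the except handler; one conventional way to express that once is a try/finally wrapper. A hypothetical sketch (run_with_gpu is not part of the commit):

    import gc
    import torch

    def run_with_gpu(pipe, fn, *args, **kwargs):
        # Keep the pipeline on the GPU only for the duration of the call,
        # then always move it back and release cached memory, even on error.
        try:
            pipe.to("cuda")
            return fn(*args, **kwargs)
        finally:
            pipe.to("cpu")
            torch.cuda.empty_cache()
            gc.collect()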
343
 
344
 
345
+ # Gradio Interface
346
+ with gr.Blocks(css=css) as demo:
347
  gr.HTML(f"""
348
  <div class="main-header">
349
  <h1>🎨 AI Image Upscaler</h1>
350
+ <p>Upscale images using Florence-2 + FLUX (Optimized for ZeroGPU)</p>
351
+ <p>Running on <strong>{power_device}</strong></p>
352
  </div>
353
  """)
354
+
355
  with gr.Row():
356
  with gr.Column(scale=1):
357
  gr.HTML("<h3>📤 Input</h3>")
358
 
359
  with gr.Tabs():
360
+ with gr.TabItem("Upload"):
361
  input_image = gr.Image(
362
  label="Upload Image",
363
  type="pil",
364
  height=200
365
  )
366
 
367
+ with gr.TabItem("URL"):
368
  image_url = gr.Textbox(
369
  label="Image URL",
370
+ placeholder="https://example.com/image.jpg"
 
371
  )
372
 
 
 
373
  use_generated_caption = gr.Checkbox(
374
+ label="Auto-generate caption",
375
+ value=True
 
376
  )
377
 
378
  custom_prompt = gr.Textbox(
379
  label="Custom Prompt (optional)",
380
+ placeholder="Override auto-caption if desired",
381
  lines=2
382
  )
383
 
 
 
384
  upscale_factor = gr.Slider(
385
  label="Upscale Factor",
386
+ minimum=2,
387
  maximum=4,
388
  step=1,
389
+ value=2
 
390
  )
391
 
392
  num_inference_steps = gr.Slider(
393
+ label="Quality (Steps)",
394
+ minimum=15,
395
+ maximum=30,
396
  step=1,
397
+ value=20,
398
+ info="Higher = better but slower"
399
  )
400
 
401
  denoising_strength = gr.Slider(
402
+ label="Enhancement Strength",
403
+ minimum=0.1,
404
+ maximum=0.5,
405
  step=0.05,
406
  value=0.3,
407
+ info="Higher = more changes"
408
  )
409
 
410
  with gr.Row():
411
+ randomize_seed = gr.Checkbox(label="Random seed", value=True)
 
 
 
412
  seed = gr.Slider(
413
  label="Seed",
414
  minimum=0,
415
  maximum=MAX_SEED,
416
  step=1,
417
+ value=42
 
418
  )
419
 
420
+ enhance_btn = gr.Button("🚀 Upscale", variant="primary", size="lg")
421
+
 
 
 
 
422
  with gr.Column(scale=2):
423
+ gr.HTML("<h3>📊 Result</h3>")
 
424
  result_slider = ImageSlider(
425
  type="pil",
426
  interactive=False,
427
+ height=500,
 
428
  label=None
429
  )
430
+
 
431
  enhance_btn.click(
432
  fn=enhance_image,
433
  inputs=[
434
+ input_image, image_url, seed, randomize_seed,
435
+ num_inference_steps, upscale_factor, denoising_strength,
436
+ use_generated_caption, custom_prompt
 
 
437
  ],
438
  outputs=[result_slider]
439
  )
440
 
441
  gr.HTML("""
442
+ <div style="margin-top: 1rem; padding: 0.5rem; background: #f0f0f0; border-radius: 8px;">
443
+ <small>⚡ Optimized for ZeroGPU: Max 2048x2048 output, simplified processing for stability</small>
444
  </div>
445
  """)
446
 
447
  if __name__ == "__main__":
448
+ demo.queue(max_size=3).launch(
449
+ share=False, # Don't use share on Spaces
450
+ server_name="0.0.0.0",
451
+ server_port=7860
452
+ )