comrender committed · verified
Commit a49f337 · Parent: b0a0a29

Update app.py

Files changed (1): app.py (+225 -320)
app.py CHANGED
@@ -1,21 +1,20 @@
- import logging
  import random
  import warnings
- import os
  import gradio as gr
  import numpy as np
  import spaces
  import torch
  from diffusers import FluxImg2ImgPipeline
- from transformers import AutoProcessor, AutoModelForCausalLM
  from gradio_imageslider import ImageSlider
  from PIL import Image
  from huggingface_hub import snapshot_download
  import requests
- import gc

- # Disable ESRGAN for ZeroGPU (saves memory and complexity)
- USE_ESRGAN = False

  css = """
  #col-container {
@@ -28,14 +27,10 @@ css = """
  }
  """

- # Device setup
- power_device = "ZeroGPU"
- device = "cpu"  # Start on CPU
-
  # Get HuggingFace token
  huggingface_token = os.getenv("HF_TOKEN")

- # Download FLUX model
  print("📥 Downloading FLUX model...")
  model_path = snapshot_download(
      repo_id="black-forest-labs/FLUX.1-dev",
@@ -45,408 +40,318 @@ model_path = snapshot_download(
      token=huggingface_token,
  )

- # Load Florence-2 model
- print("📥 Loading Florence-2 model...")
- florence_model = AutoModelForCausalLM.from_pretrained(
-     "microsoft/Florence-2-large",
-     torch_dtype=torch.float32,
-     trust_remote_code=True,
-     attn_implementation="eager"
- ).to(device).eval()
-
- florence_processor = AutoProcessor.from_pretrained(
-     "microsoft/Florence-2-large",
-     trust_remote_code=True
- )
-
- # Load FLUX pipeline
- print("📥 Loading FLUX Img2Img...")
  pipe = FluxImg2ImgPipeline.from_pretrained(
      model_path,
-     torch_dtype=torch.float32
  )

  # Enable memory optimizations
- pipe.enable_model_cpu_offload()
  pipe.enable_vae_tiling()
  pipe.enable_vae_slicing()
  pipe.vae.enable_tiling()
  pipe.vae.enable_slicing()

- print("✅ All models loaded successfully!")

- MAX_SEED = 1000000
- MAX_PIXEL_BUDGET = 2048 * 2048  # Reduced for ZeroGPU stability


- def truncate_caption(caption, max_tokens=70):
-     """Truncate caption to avoid CLIP token limit"""
-     words = caption.split()
-     truncated = []
-     current_length = 0
-
-     for word in words:
-         # Rough estimate: 1 word ≈ 1.3 tokens
-         if current_length + len(word) * 1.3 > max_tokens:
-             break
-         truncated.append(word)
-         current_length += len(word) * 1.3
-
-     result = ' '.join(truncated)
-     if len(truncated) < len(words):
-         result += "..."
-     return result


  def make_multiple_16(n):
-     """Round to nearest multiple of 16"""
      return ((n + 15) // 16) * 16


- def generate_caption(image):
-     """Generate caption using Florence-2"""
-     try:
-         # Keep on CPU for caption generation
-         task_prompt = "<MORE_DETAILED_CAPTION>"
-
-         # Resize image if too large for captioning
-         if image.width > 1024 or image.height > 1024:
-             image.thumbnail((1024, 1024), Image.LANCZOS)
-
-         inputs = florence_processor(
-             text=task_prompt,
-             images=image,
-             return_tensors="pt"
-         ).to(device)
-
-         with torch.no_grad():
-             generated_ids = florence_model.generate(
-                 input_ids=inputs["input_ids"],
-                 pixel_values=inputs["pixel_values"],
-                 max_new_tokens=256,  # Reduced from 1024
-                 num_beams=1,  # Reduced from 3
-                 do_sample=False,  # Faster without sampling
-             )
-
-         generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-         parsed_answer = florence_processor.post_process_generation(
-             generated_text,
-             task=task_prompt,
-             image_size=(image.width, image.height)
-         )
-
-         caption = parsed_answer[task_prompt]
-         # Truncate to avoid CLIP token limit
-         caption = truncate_caption(caption, max_tokens=70)
-         return caption
-
-     except Exception as e:
-         print(f"Caption generation failed: {e}")
-         return "high quality detailed image"


- def process_input(input_image, upscale_factor):
-     """Process input image with size constraints"""
-     w, h = input_image.size
-     w_original, h_original = w, h
-     was_resized = False
-
-     # Check pixel budget
-     if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
-         gr.Info("Resizing input to fit within processing limits...")
-
-         target_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
-         scale = (target_pixels / (w * h)) ** 0.5
-
-         new_w = make_multiple_16(int(w * scale))
-         new_h = make_multiple_16(int(h * scale))
-
-         input_image = input_image.resize((new_w, new_h), Image.LANCZOS)
-         was_resized = True
-
-     # Ensure dimensions are multiples of 16
-     w, h = input_image.size
-     new_w = make_multiple_16(w)
-     new_h = make_multiple_16(h)
-
-     if new_w != w or new_h != h:
-         padded = Image.new('RGB', (new_w, new_h))
-         padded.paste(input_image, (0, 0))
-         input_image = padded
-
-     return input_image, w_original, h_original, was_resized


- def simple_upscale(image, factor):
-     """Simple LANCZOS upscaling"""
-     return image.resize(
-         (image.width * factor, image.height * factor),
-         Image.LANCZOS
-     )


- @spaces.GPU(duration=90)  # Reduced from 120
  def enhance_image(
-     image_input,
-     image_url,
      seed,
      randomize_seed,
      num_inference_steps,
-     upscale_factor,
      denoising_strength,
-     use_generated_caption,
-     custom_prompt,
      progress=gr.Progress(track_tqdm=True),
  ):
-     """Main enhancement function optimized for ZeroGPU"""
      try:
-         # Clear cache at start
-         torch.cuda.empty_cache()
-         gc.collect()
-
-         # Handle image input
-         if image_input is not None:
-             input_image = image_input
-         elif image_url:
-             response = requests.get(image_url, stream=True)
-             response.raise_for_status()
-             input_image = Image.open(response.raw)
-         else:
-             raise gr.Error("Please provide an image")
-
          if randomize_seed:
              seed = random.randint(0, MAX_SEED)

-         original_image = input_image.copy()

-         # Process and validate input
-         input_image, w_orig, h_orig, was_resized = process_input(
-             input_image, upscale_factor
-         )

-         # Generate or use caption (keep on CPU)
-         if use_generated_caption:
-             gr.Info("Generating caption...")
-             prompt = generate_caption(input_image)
-             print(f"Caption: {prompt}")
-         else:
-             prompt = custom_prompt.strip() if custom_prompt else "high quality image"
-             prompt = truncate_caption(prompt, max_tokens=70)

-         # Initial upscale with LANCZOS
-         gr.Info("Upscaling image...")
-         upscaled = simple_upscale(input_image, upscale_factor)

-         # Move pipeline to GPU only when needed
-         pipe.to("cuda")

-         # For large images, process in smaller chunks
-         w, h = upscaled.size

-         # Determine if we need tiling based on size
-         need_tiling = (w > 1536 or h > 1536)

-         if need_tiling:
-             gr.Info("Processing large image in tiles...")
-             # Process center crop for now (to avoid timeout)
-             crop_size = min(1024, w, h)
-             left = (w - crop_size) // 2
-             top = (h - crop_size) // 2
-
-             cropped = upscaled.crop((left, top, left + crop_size, top + crop_size))
-
-             # Ensure dimensions are multiples of 16
-             crop_w = make_multiple_16(cropped.width)
-             crop_h = make_multiple_16(cropped.height)
-
-             if crop_w != cropped.width or crop_h != cropped.height:
-                 padded_crop = Image.new('RGB', (crop_w, crop_h))
-                 padded_crop.paste(cropped, (0, 0))
-                 cropped = padded_crop
-
-             # Process with FLUX
-             with torch.inference_mode():
-                 generator = torch.Generator(device="cuda").manual_seed(seed)
-
-                 result_crop = pipe(
-                     prompt=prompt,
-                     image=cropped,
-                     strength=denoising_strength,
-                     num_inference_steps=num_inference_steps,
-                     guidance_scale=1.0,
-                     height=crop_h,
-                     width=crop_w,
-                     generator=generator,
-                 ).images[0]
-
-             # Crop back if padded
-             if crop_w != cropped.width or crop_h != cropped.height:
-                 result_crop = result_crop.crop((0, 0, cropped.width, cropped.height))
-
-             # Paste enhanced crop back
-             result = upscaled.copy()
-             result.paste(result_crop, (left, top))
-
-         else:
-             # Process entire image if small enough
-             # Ensure dimensions are multiples of 16
-             proc_w = make_multiple_16(w)
-             proc_h = make_multiple_16(h)
-
-             if proc_w != w or proc_h != h:
-                 padded = Image.new('RGB', (proc_w, proc_h))
-                 padded.paste(upscaled, (0, 0))
-                 upscaled = padded
-
-             with torch.inference_mode():
-                 generator = torch.Generator(device="cuda").manual_seed(seed)
-
-                 result = pipe(
-                     prompt=prompt,
-                     image=upscaled,
-                     strength=denoising_strength,
-                     num_inference_steps=num_inference_steps,
-                     guidance_scale=1.0,
-                     height=proc_h,
-                     width=proc_w,
-                     generator=generator,
-                 ).images[0]
-
-             # Crop back if padded
-             if proc_w != w or proc_h != h:
-                 result = result.crop((0, 0, w, h))

-         # Final resize if needed
-         if was_resized:
-             result = result.resize(
-                 (w_orig * upscale_factor, h_orig * upscale_factor),
-                 Image.LANCZOS
-             )

-         # Prepare for slider
-         input_resized = original_image.resize(result.size, Image.LANCZOS)

-         # Clean up
          pipe.to("cpu")
          torch.cuda.empty_cache()
          gc.collect()

-         return [input_resized, result]

      except Exception as e:
-         # Ensure cleanup on error
          pipe.to("cpu")
          torch.cuda.empty_cache()
          gc.collect()
-         raise gr.Error(f"Processing failed: {str(e)}")


- # Gradio Interface
  with gr.Blocks(css=css) as demo:
-     gr.HTML(f"""
      <div class="main-header">
-         <h1>🎨 AI Image Upscaler</h1>
-         <p>Upscale images using Florence-2 + FLUX (Optimized for ZeroGPU)</p>
-         <p>Running on <strong>{power_device}</strong></p>
      </div>
      """)

      with gr.Row():
          with gr.Column(scale=1):
-             gr.HTML("<h3>📤 Input</h3>")
-
-             with gr.Tabs():
-                 with gr.TabItem("Upload"):
-                     input_image = gr.Image(
-                         label="Upload Image",
-                         type="pil",
-                         height=200
-                     )
-
-                 with gr.TabItem("URL"):
-                     image_url = gr.Textbox(
-                         label="Image URL",
-                         placeholder="https://example.com/image.jpg"
-                     )
-
-             use_generated_caption = gr.Checkbox(
-                 label="Auto-generate caption",
-                 value=True
              )

-             custom_prompt = gr.Textbox(
-                 label="Custom Prompt (optional)",
-                 placeholder="Override auto-caption if desired",
                  lines=2
              )

-             upscale_factor = gr.Slider(
-                 label="Upscale Factor",
-                 minimum=2,
-                 maximum=4,
-                 step=1,
-                 value=2
-             )
-
-             num_inference_steps = gr.Slider(
-                 label="Quality (Steps)",
-                 minimum=15,
-                 maximum=30,
-                 step=1,
-                 value=20,
-                 info="Higher = better but slower"
-             )
-
-             denoising_strength = gr.Slider(
-                 label="Enhancement Strength",
-                 minimum=0.1,
-                 maximum=0.5,
-                 step=0.05,
-                 value=0.3,
-                 info="Higher = more changes"
-             )
-
-             with gr.Row():
-                 randomize_seed = gr.Checkbox(label="Random seed", value=True)
-                 seed = gr.Slider(
-                     label="Seed",
-                     minimum=0,
-                     maximum=MAX_SEED,
                      step=1,
-                     value=42
                  )

-             enhance_btn = gr.Button("🚀 Upscale", variant="primary", size="lg")

          with gr.Column(scale=2):
-             gr.HTML("<h3>📊 Result</h3>")
              result_slider = ImageSlider(
                  type="pil",
                  interactive=False,
-                 height=500,
-                 label=None
              )

      enhance_btn.click(
          fn=enhance_image,
          inputs=[
-             input_image, image_url, seed, randomize_seed,
-             num_inference_steps, upscale_factor, denoising_strength,
-             use_generated_caption, custom_prompt
          ],
-         outputs=[result_slider]
      )

      gr.HTML("""
-         <div style="margin-top: 1rem; padding: 0.5rem; background: #f0f0f0; border-radius: 8px;">
-             <small>⚡ Optimized for ZeroGPU: Max 2048x2048 output, simplified processing for stability</small>
          </div>
      """)

  if __name__ == "__main__":
      demo.queue(max_size=3).launch(
-         share=False,  # Don't use share on Spaces
          server_name="0.0.0.0",
          server_port=7860
      )
 
@@ -1,21 +1,20 @@
+ import os
  import random
  import warnings
+ import gc
  import gradio as gr
  import numpy as np
  import spaces
  import torch
  from diffusers import FluxImg2ImgPipeline
  from gradio_imageslider import ImageSlider
  from PIL import Image
  from huggingface_hub import snapshot_download
  import requests

+ # ESRGAN imports
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ from basicsr.utils import img2tensor, tensor2img

  css = """
  #col-container {
@@ -28,14 +27,10 @@ css = """
  }
  """

  # Get HuggingFace token
  huggingface_token = os.getenv("HF_TOKEN")

+ # Download FLUX model if not already cached
  print("📥 Downloading FLUX model...")
  model_path = snapshot_download(
      repo_id="black-forest-labs/FLUX.1-dev",
@@ -45,408 +40,318 @@ model_path = snapshot_download(
      token=huggingface_token,
  )

+ # Load FLUX pipeline on CPU initially
+ print("📥 Loading FLUX Img2Img pipeline...")
  pipe = FluxImg2ImgPipeline.from_pretrained(
      model_path,
+     torch_dtype=torch.bfloat16,
+     use_safetensors=True
  )

  # Enable memory optimizations
  pipe.enable_vae_tiling()
  pipe.enable_vae_slicing()
  pipe.vae.enable_tiling()
  pipe.vae.enable_slicing()
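
The dtype switch above (float32 to bfloat16) is the main reason this revision fits on ZeroGPU: parameter memory halves. A rough sketch of the arithmetic, assuming the standard diffusers attribute `pipe.transformer` for the FLUX DiT (roughly 12B parameters):

    # Sketch: estimate parameter memory of the loaded transformer.
    n_params = sum(p.numel() for p in pipe.transformer.parameters())
    print(f"{n_params / 1e9:.1f}B params, "
          f"~{n_params * 2 / 1e9:.0f} GB in bf16 vs ~{n_params * 4 / 1e9:.0f} GB in fp32")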

+ # Download and load ESRGAN 4x-UltraSharp model
+ print("📥 Loading ESRGAN 4x-UltraSharp...")
+ esrgan_path = "4x-UltraSharp.pth"
+ if not os.path.exists(esrgan_path):
+     print("Downloading ESRGAN model...")
+     url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
+     response = requests.get(url)
+     with open(esrgan_path, "wb") as f:
+         f.write(response.content)

+ # Initialize ESRGAN model
+ esrgan_model = RRDBNet(
+     num_in_ch=3,
+     num_out_ch=3,
+     num_feat=64,
+     num_block=23,
+     num_grow_ch=32,
+     scale=4
+ )
+ state_dict = torch.load(esrgan_path, map_location='cpu')
+ if 'params_ema' in state_dict:
+     state_dict = state_dict['params_ema']
+ elif 'params' in state_dict:
+     state_dict = state_dict['params']
+ esrgan_model.load_state_dict(state_dict)
+ esrgan_model.eval()
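
The checkpoint URL and the RRDBNet hyper-parameters have to agree for `load_state_dict` to succeed, so a one-off shape probe is a cheap sanity check that the weights really describe a 4x model. A minimal sketch, CPU only:

    # Sketch: a (1, 3, 64, 64) input through a scale=4 RRDBNet should
    # come back as (1, 3, 256, 256).
    with torch.no_grad():
        probe = torch.rand(1, 3, 64, 64)
        print(esrgan_model(probe).shape)  # expected: torch.Size([1, 3, 256, 256])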

+ print("✅ All models loaded successfully!")

+ MAX_SEED = 1000000
+ MAX_INPUT_SIZE = 512  # Max input size before upscaling


  def make_multiple_16(n):
+     """Round to nearest multiple of 16 for FLUX requirements"""
      return ((n + 15) // 16) * 16
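
FLUX wants both spatial dimensions on a 16-pixel grid (a factor of 8 from the VAE, times 2 from latent packing), and the helper only ever rounds up. A few worked values:

    assert make_multiple_16(2048) == 2048  # already aligned, unchanged
    assert make_multiple_16(2050) == 2064  # rounds up to the next multiple
    assert make_multiple_16(1) == 16       # never rounds down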

+ def truncate_prompt(prompt, max_tokens=75):
+     """Truncate prompt to avoid CLIP token limit (77 tokens)"""
+     if not prompt:
+         return ""
+
+     # Simple truncation by character count (rough approximation)
+     if len(prompt) > 250:  # ~75 tokens
+         return prompt[:250] + "..."
+     return prompt
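
The 250-character cut-off is only a rough proxy for 75 CLIP tokens (and the 77-token limit includes BOS/EOS). If exact behavior ever matters, the pipeline's own CLIP tokenizer can do the counting; a sketch, assuming the standard `pipe.tokenizer` attribute and a hypothetical `truncate_prompt_exact` name:

    # Sketch: token-accurate truncation via the pipeline's CLIP tokenizer.
    def truncate_prompt_exact(prompt, max_tokens=75):
        ids = pipe.tokenizer(prompt, truncation=True, max_length=max_tokens).input_ids
        return pipe.tokenizer.decode(ids, skip_special_tokens=True)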

+ def prepare_image(image, max_size=MAX_INPUT_SIZE):
+     """Prepare image for processing"""
+     w, h = image.size
+
+     # Limit input size
+     if w > max_size or h > max_size:
+         image.thumbnail((max_size, max_size), Image.LANCZOS)
+
+     return image


+ def esrgan_upscale(image):
+     """Upscale image 4x using ESRGAN"""
+     # Convert PIL to tensor
+     img_np = np.array(image).astype(np.float32) / 255.
+     img_tensor = img2tensor(img_np, bgr2rgb=False, float32=True)
+
+     # Upscale
+     with torch.no_grad():
+         output = esrgan_model(img_tensor.unsqueeze(0).cpu())
+
+     # Convert back to PIL
+     output_np = tensor2img(output.squeeze(0), rgb2bgr=False, min_max=(0, 1))
+     return Image.fromarray(output_np)
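
Note that `enhance_image` below repeats this conversion inline so it can run on the GPU; the helper above is the CPU path. Standalone usage would look like this (sketch; `input.png` is a stand-in filename):

    img = Image.open("input.png").convert("RGB")  # hypothetical local file
    big = esrgan_upscale(img)                     # 4x on CPU, slow but GPU-free
    print(img.size, "->", big.size)               # e.g. (512, 512) -> (2048, 2048)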

+ @spaces.GPU(duration=60)  # 60 seconds should be enough
  def enhance_image(
+     input_image,
+     prompt,
      seed,
      randomize_seed,
      num_inference_steps,
      denoising_strength,
      progress=gr.Progress(track_tqdm=True),
  ):
+     """Main enhancement function"""
+     if input_image is None:
+         raise gr.Error("Please upload an image")
+
+     # Clear memory
+     torch.cuda.empty_cache()
+     gc.collect()
+
      try:
+         # Randomize seed if needed
          if randomize_seed:
              seed = random.randint(0, MAX_SEED)

+         # Prepare and validate prompt
+         prompt = truncate_prompt(prompt.strip() if prompt else "high quality, detailed")

+         # Prepare input image
+         input_image = prepare_image(input_image)
+         original_size = input_image.size

+         # Step 1: ESRGAN upscale (4x) on the GPU
+         gr.Info("🔍 Upscaling with ESRGAN 4x...")
+         with torch.no_grad():
+             # Move ESRGAN to GPU for faster processing
+             esrgan_model.to("cuda")
+
+             # Convert image for ESRGAN
+             img_np = np.array(input_image).astype(np.float32) / 255.
+             img_tensor = img2tensor(img_np, bgr2rgb=False, float32=True)
+             img_tensor = img_tensor.unsqueeze(0).to("cuda")
+
+             # Upscale
+             output_tensor = esrgan_model(img_tensor)
+
+             # Convert back to PIL
+             output_np = tensor2img(output_tensor.squeeze(0).cpu(), rgb2bgr=False, min_max=(0, 1))
+             upscaled_image = Image.fromarray(output_np)
+
+             # Move ESRGAN back to CPU to free memory
+             esrgan_model.to("cpu")
+             torch.cuda.empty_cache()

+         # Ensure dimensions are multiples of 16 for FLUX
+         w, h = upscaled_image.size
+         new_w = make_multiple_16(w)
+         new_h = make_multiple_16(h)

+         if new_w != w or new_h != h:
+             # Pad image to meet requirements
+             padded = Image.new('RGB', (new_w, new_h))
+             padded.paste(upscaled_image, (0, 0))
+             upscaled_image = padded

+         # Step 2: FLUX enhancement
+         gr.Info("🎨 Enhancing with FLUX...")

+         # Move pipeline to GPU
+         pipe.to("cuda")

+         # Generate with FLUX
+         generator = torch.Generator(device="cuda").manual_seed(seed)

+         with torch.inference_mode():
+             result = pipe(
+                 prompt=prompt,
+                 image=upscaled_image,
+                 strength=denoising_strength,
+                 num_inference_steps=num_inference_steps,
+                 guidance_scale=1.0,  # Fixed at 1.0 for FLUX dev
+                 height=new_h,
+                 width=new_w,
+                 generator=generator,
+             ).images[0]

+         # Crop back if we padded
+         if new_w != w or new_h != h:
+             result = result.crop((0, 0, w, h))

+         # Move pipeline back to CPU
          pipe.to("cpu")
          torch.cuda.empty_cache()
          gc.collect()

+         # Prepare images for slider (before/after)
+         input_resized = input_image.resize(result.size, Image.LANCZOS)
+
+         gr.Info("✅ Enhancement complete!")
+         return [input_resized, result], seed

      except Exception as e:
+         # Cleanup on error
          pipe.to("cpu")
+         esrgan_model.to("cpu")
          torch.cuda.empty_cache()
          gc.collect()
+         raise gr.Error(f"Enhancement failed: {str(e)}")
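
One side effect of the zero-filled `Image.new` padding inside `enhance_image`: FLUX sees a hard black border and may "enhance" it before the final crop removes it. A replicate-pad variant is a possible refinement; sketch only, `pad_replicate` is not part of the app:

    # Sketch: extend edge pixels instead of filling with black, so the
    # denoiser has no artificial border to latch onto.
    def pad_replicate(img, new_w, new_h):
        arr = np.array(img)  # (h, w, 3)
        pad_h, pad_w = new_h - arr.shape[0], new_w - arr.shape[1]
        arr = np.pad(arr, ((0, pad_h), (0, pad_w), (0, 0)), mode="edge")
        return Image.fromarray(arr)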

+ # Create Gradio interface
  with gr.Blocks(css=css) as demo:
+     gr.HTML("""
      <div class="main-header">
+         <h1>🚀 ESRGAN 4x + FLUX Enhancement</h1>
+         <p>Upload an image to upscale 4x with ESRGAN and enhance with FLUX</p>
+         <p>Optimized for <strong>ZeroGPU</strong> | Max input: 512x512 → Output: 2048x2048</p>
      </div>
      """)

      with gr.Row():
          with gr.Column(scale=1):
+             # Input section
+             input_image = gr.Image(
+                 label="Input Image",
+                 type="pil",
+                 height=256
              )

+             prompt = gr.Textbox(
+                 label="Enhancement Prompt",
+                 placeholder="Describe the desired enhancement (e.g., 'high quality, sharp details, vibrant colors')",
+                 value="high quality, ultra detailed, sharp",
                  lines=2
              )

+             with gr.Accordion("Advanced Settings", open=False):
+                 num_inference_steps = gr.Slider(
+                     label="Enhancement Steps",
+                     minimum=10,
+                     maximum=25,
                      step=1,
+                     value=18,
+                     info="More steps = better quality but slower"
+                 )
+
+                 denoising_strength = gr.Slider(
+                     label="Enhancement Strength",
+                     minimum=0.1,
+                     maximum=0.6,
+                     step=0.05,
+                     value=0.35,
+                     info="Higher = more changes to the image"
+                 )
+
+                 with gr.Row():
+                     randomize_seed = gr.Checkbox(
+                         label="Randomize seed",
+                         value=True
+                     )
+                     seed = gr.Slider(
+                         label="Seed",
+                         minimum=0,
+                         maximum=MAX_SEED,
+                         step=1,
+                         value=42
+                     )

+             enhance_btn = gr.Button(
+                 "🎨 Enhance Image (4x Upscale)",
+                 variant="primary",
+                 size="lg"
+             )

          with gr.Column(scale=2):
+             # Output section
              result_slider = ImageSlider(
                  type="pil",
+                 label="Before / After",
+                 interactive=False,
+                 height=512
+             )
+
+             used_seed = gr.Number(
+                 label="Seed Used",
                  interactive=False,
+                 visible=False
              )

+     # Examples
+     gr.Examples(
+         examples=[
+             ["high quality, ultra detailed, sharp"],
+             ["cinematic, professional photography, enhanced details"],
+             ["vibrant colors, high contrast, sharp focus"],
+             ["photorealistic, 8k quality, fine details"],
+         ],
+         inputs=[prompt],
+         label="Example Prompts"
+     )
+
+     # Event handler
      enhance_btn.click(
          fn=enhance_image,
          inputs=[
+             input_image,
+             prompt,
+             seed,
+             randomize_seed,
+             num_inference_steps,
+             denoising_strength,
          ],
+         outputs=[result_slider, used_seed]
      )

      gr.HTML("""
+         <div style="margin-top: 2rem; text-align: center; color: #666;">
+             <p>📌 Pipeline: ESRGAN 4x-UltraSharp → FLUX Dev Enhancement</p>
+             <p>⚡ Optimized for ZeroGPU with automatic memory management</p>
          </div>
      """)

  if __name__ == "__main__":
      demo.queue(max_size=3).launch(
+         share=False,
          server_name="0.0.0.0",
          server_port=7860
      )
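
For a quick check outside Spaces, the handler can also be called directly; the `spaces` decorator is a no-op off-platform. A sketch, assuming a local CUDA machine and a stand-in `test.jpg`:

    # Sketch: headless smoke test of the full ESRGAN -> FLUX path.
    img = Image.open("test.jpg").convert("RGB")
    (before, after), used = enhance_image(
        img, "high quality, ultra detailed, sharp",
        seed=42, randomize_seed=False,
        num_inference_steps=18, denoising_strength=0.35,
    )
    after.save("test_4x_enhanced.png")
    print("seed used:", used)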