comrender committed
Commit a1ef78c · verified · 1 Parent(s): 58d1893

Update app.py

Files changed (1): app.py (+370, −318)
app.py CHANGED
@@ -40,17 +40,23 @@ if torch.cuda.is_available():
     device = "cuda"
 else:
     power_device = "CPU"
-    device = "cpu"# Get HuggingFace token
-    huggingface_token = os.getenv("HF_TOKEN")# Download FLUX model
-    print(" Downloading FLUX model...")
+    device = "cpu"
+
+# Get HuggingFace token
+huggingface_token = os.getenv("HF_TOKEN")
+
+# Download FLUX model
+print("📥 Downloading FLUX model...")
 model_path = snapshot_download(
     repo_id="black-forest-labs/FLUX.1-dev",
     repo_type="model",
     ignore_patterns=["*.md", "*.gitattributes"],
     local_dir="FLUX.1-dev",
     token=huggingface_token,
-)# Load Florence-2 model for image captioning
-print(" Loading Florence-2 model...")
+)
+
+# Load Florence-2 model for image captioning
+print("📥 Loading Florence-2 model...")
 florence_model = AutoModelForCausalLM.from_pretrained(
     "microsoft/Florence-2-large",
     torch_dtype=torch.float16,
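Note on this hunk: `black-forest-labs/FLUX.1-dev` is a gated repository, so the `HF_TOKEN` read above must be set and authorized for `snapshot_download` to succeed. A minimal fail-fast guard one could add (not part of this commit; the message text is illustrative):

```python
import os

# Hypothetical guard: without it, snapshot_download fails mid-download
# with an HTTP 401/403 when HF_TOKEN is missing or not authorized.
if not os.getenv("HF_TOKEN"):
    raise RuntimeError("Set HF_TOKEN to download the gated FLUX.1-dev weights")
```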
@@ -60,15 +66,21 @@ florence_model = AutoModelForCausalLM.from_pretrained(
 florence_processor = AutoProcessor.from_pretrained(
     "microsoft/Florence-2-large",
     trust_remote_code=True
-)# Load FLUX Img2Img pipeline
-print(" Loading FLUX Img2Img...")
+)
+
+# Load FLUX Img2Img pipeline
+print("📥 Loading FLUX Img2Img...")
 pipe = FluxImg2ImgPipeline.from_pretrained(
     model_path,
     torch_dtype=torch.bfloat16
 )
 pipe.to(device)
 pipe.enable_vae_tiling()
-pipe.enable_vae_slicing()print(" All models loaded successfully!")# Download ESRGAN model if using
+pipe.enable_vae_slicing()
+
+print("✅ All models loaded successfully!")
+
+# Download ESRGAN model if using
 if USE_ESRGAN:
     esrgan_path = "4x-UltraSharp.pth"
     if not os.path.exists(esrgan_path):
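Note on this hunk: `enable_vae_tiling()` and `enable_vae_slicing()` are standard diffusers memory savers that decode latents tile-by-tile and image-by-image rather than all at once. If VRAM is still tight, diffusers also offers CPU offload; a hedged sketch (not in this commit, and it replaces the explicit `pipe.to(device)` above rather than adding to it):

```python
# Alternative memory strategy (assumes the accelerate package is installed):
# submodules are moved to the GPU only while they run.
pipe.enable_model_cpu_offload()
```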
@@ -79,124 +91,149 @@ if USE_ESRGAN:
     state_dict = torch.load(esrgan_path)['params_ema']
     esrgan_model.load_state_dict(state_dict)
     esrgan_model.eval()
-    esrgan_model.to(device)MAX_SEED = 1000000
-MAX_PIXEL_BUDGET = 8192 * 8192 # Increased for tiling supportdef generate_caption(image):
+    esrgan_model.to(device)
+
+MAX_SEED = 1000000
+MAX_PIXEL_BUDGET = 8192 * 8192  # Increased for tiling support
+
+
+def generate_caption(image):
     """Generate detailed caption using Florence-2"""
     try:
         task_prompt = "<MORE_DETAILED_CAPTION>"
-        prompt = task_prompt inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(device)
+        prompt = task_prompt
+
+        inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(device)
         inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)  # Match model dtype
 
         generated_ids = florence_model.generate(
             input_ids=inputs["input_ids"],
             pixel_values=inputs["pixel_values"],
             max_new_tokens=1024,
             num_beams=3,
             do_sample=True,
         )
 
         generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
         parsed_answer = florence_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
 
         caption = parsed_answer[task_prompt]
         return caption
     except Exception as e:
         print(f"Caption generation failed: {e}")
-        return "a high quality detailed image"def process_input(input_image, upscale_factor):
+        return "a high quality detailed image"
+
+
+def process_input(input_image, upscale_factor):
     """Process input image and handle size constraints"""
     w, h = input_image.size
     w_original, h_original = w, h
-    aspect_ratio = w / hwas_resized = False
+    aspect_ratio = w / h
+
+    was_resized = False
 
     if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
         warnings.warn(
             f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing to fit budget."
         )
         gr.Info(
             f"Requested output image is too large. Resizing input to fit within pixel budget."
         )
         target_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
         scale = (target_input_pixels / (w * h)) ** 0.5
         new_w = int(w * scale) - int(w * scale) % 8
         new_h = int(h * scale) - int(h * scale) % 8
         input_image = input_image.resize((new_w, new_h), resample=Image.LANCZOS)
         was_resized = True
 
-    return input_image, w_original, h_original, was_resizeddef load_image_from_url(url):
+    return input_image, w_original, h_original, was_resized
+
+
+def load_image_from_url(url):
     """Load image from URL"""
     try:
         response = requests.get(url, stream=True)
         response.raise_for_status()
         return Image.open(response.raw)
     except Exception as e:
-        raise gr.Error(f"Failed to load image from URL: {e}")def esrgan_upscale(image, scale=4):
+        raise gr.Error(f"Failed to load image from URL: {e}")
+
+
+def esrgan_upscale(image, scale=4):
     if not USE_ESRGAN:
         return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)
     img = img2tensor(np.array(image) / 255., bgr2rgb=False, float32=True)
     with torch.no_grad():
         output = esrgan_model(img.unsqueeze(0)).squeeze()
     output_img = tensor2img(output, rgb2bgr=False, min_max=(0, 1))
-    return Image.fromarray(output_img)def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=32):
+    return Image.fromarray(output_img)
+
+
+def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=32):
     """Tiled Img2Img to mimic Ultimate SD Upscaler tiling"""
     w, h = image.size
-    output = image.copy() # Start with the control image# For handling long prompts: truncate for CLIP, full for T5
+    output = image.copy()  # Start with the control image
+
+    # For handling long prompts: truncate for CLIP, full for T5
     max_clip_tokens = pipe.tokenizer.model_max_length  # Typically 77
     input_ids = pipe.tokenizer.encode(prompt, return_tensors="pt")
     if input_ids.shape[1] > max_clip_tokens:
         input_ids = input_ids[:, :max_clip_tokens]
         prompt_clip = pipe.tokenizer.decode(input_ids[0], skip_special_tokens=True)
     else:
         prompt_clip = prompt
 
     for x in range(0, w, tile_size - overlap):
         for y in range(0, h, tile_size - overlap):
             tile_w = min(tile_size, w - x)
             tile_h = min(tile_size, h - y)
             tile = image.crop((x, y, x + tile_w, y + tile_h))
 
             # Run Flux on tile
             gen_tile = pipe(
                 prompt=prompt_clip,
                 prompt_2=prompt,
                 image=tile,
                 strength=strength,
                 num_inference_steps=steps,
                 guidance_scale=guidance,
                 height=tile_h,
                 width=tile_w,
                 generator=generator,
             ).images[0]
 
             # Resize back to exact tile size if pipeline adjusted it
             gen_tile = gen_tile.resize((tile_w, tile_h), resample=Image.LANCZOS)
 
             # Paste with blending if overlap
             if overlap > 0:
                 paste_box = (x, y, x + tile_w, y + tile_h)
                 if x > 0 or y > 0:
                     # Simple linear blend on overlaps
                     mask = Image.new('L', (tile_w, tile_h), 255)
                     if x > 0:
                         blend_width = min(overlap, tile_w)
                         for i in range(blend_width):
                             for j in range(tile_h):
                                 mask.putpixel((i, j), int(255 * (i / overlap)))
                     if y > 0:
                         blend_height = min(overlap, tile_h)
                         for i in range(tile_w):
                             for j in range(blend_height):
                                 mask.putpixel((i, j), int(255 * (j / overlap)))
                     output.paste(gen_tile, paste_box, mask)
                 else:
                     output.paste(gen_tile, paste_box)
             else:
                 output.paste(gen_tile, (x, y))
 
-    return output@spaces.GPU(duration=120)
+    return output
+
+
+@spaces.GPU(duration=120)
 def enhance_image(
     image_input,
     image_url,
-    seed,
     randomize_seed,
     num_inference_steps,
     upscale_factor,
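Note on this hunk: the `putpixel` loops in `tiled_flux_img2img` build each blend mask with O(tile_w × tile_h) Python calls, which is slow at the default 1024-pixel tile size. A numpy sketch of the same mask (same 0→255 ramp across the overlap band and same overwrite order as the loops above; a drop-in assumption, not part of the commit):

```python
import numpy as np
from PIL import Image

def make_blend_mask(tile_w, tile_h, overlap, blend_left, blend_top):
    """Vectorized equivalent of the putpixel loops in tiled_flux_img2img."""
    mask = np.full((tile_h, tile_w), 255.0, dtype=np.float32)
    if blend_left:  # corresponds to x > 0
        bw = min(overlap, tile_w)
        mask[:, :bw] = 255.0 * (np.arange(bw, dtype=np.float32) / overlap)
    if blend_top:  # corresponds to y > 0; overwrites the left ramp in the
        bh = min(overlap, tile_h)  # top band, matching the original loop order
        mask[:bh, :] = 255.0 * (np.arange(bh, dtype=np.float32)[:, None] / overlap)
    return Image.fromarray(mask.astype(np.uint8), mode="L")
```

Inside the loop this would replace the mask-building block with `mask = make_blend_mask(tile_w, tile_h, overlap, x > 0, y > 0)`.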
@@ -212,244 +249,259 @@ def enhance_image(
     elif image_url:
         input_image = load_image_from_url(image_url)
     else:
-        raise gr.Error("Please provide an image (upload or URL)")if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
+        raise gr.Error("Please provide an image (upload or URL)")
+
+    # Convert input image to PNG in backend
+    buffer = io.BytesIO()
+    input_image.save(buffer, format="PNG")
+    buffer.seek(0)
+    input_image = Image.open(buffer)
+
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    else:
+        seed = 42
 
     true_input_image = input_image
 
     # Process input image
     input_image, w_original, h_original, was_resized = process_input(
         input_image, upscale_factor
     )
 
     # Generate caption if requested
     if use_generated_caption:
-        gr.Info(" Generating image caption...")
+        gr.Info("🔍 Generating image caption...")
         generated_caption = generate_caption(input_image)
         prompt = generated_caption
     else:
         prompt = custom_prompt if custom_prompt.strip() else ""
 
     generator = torch.Generator().manual_seed(seed)
 
-    gr.Info(" Upscaling image...")
+    gr.Info("🚀 Upscaling image...")
 
     # Initial upscale
     if USE_ESRGAN and upscale_factor == 4:
         control_image = esrgan_upscale(input_image, upscale_factor)
     else:
         w, h = input_image.size
         control_image = input_image.resize((w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS)
 
     # Tiled Flux Img2Img for refinement
     image = tiled_flux_img2img(
         pipe,
         prompt,
         control_image,
         denoising_strength,
         num_inference_steps,
         1.0,  # Hardcoded guidance_scale to 1
         generator,
         tile_size=1024,
         overlap=32
     )
 
     if was_resized:
-        gr.Info(f" Resizing output to target size: {w_original * upscale_factor}x{h_original * upscale_factor}")
+        gr.Info(f"📏 Resizing output to target size: {w_original * upscale_factor}x{h_original * upscale_factor}")
         image = image.resize((w_original * upscale_factor, h_original * upscale_factor), resample=Image.LANCZOS)
 
     # Resize input image to match output size for slider alignment
     resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
 
-    return [resized_input, image]# Create Gradio interface
-with gr.Blocks(css=css, title=" AI Image Upscaler - Florence-2 + FLUX") as demo:
+    return [resized_input, image], image
+
+
+# Create Gradio interface
+with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FLUX") as demo:
     gr.HTML("""
         <div class="main-header">
-            <h1>Flux Dev Ultimate HD Upscaler</h1>
-            <p>Upload an image or provide a URL to upscale it using Florence-2 captioning and FLUX upscaling</p>
+            <h1>🎨 Flux dev Creative Upscaler</h1>
+            <p>Upload an image or provide a URL to upscale it using Florence-2 captioning and FLUX dev with Ultimate SD Upscaler</p>
             <p>Currently running on <strong>{}</strong></p>
         </div>
-    """.format(power_device))with gr.Row():
+    """.format(power_device))
+
+    with gr.Row():
         with gr.Column(scale=1):
-            gr.HTML("<h3> Input</h3>")
+            gr.HTML("<h3>📤 Input</h3>")
 
             with gr.Tabs():
-                with gr.TabItem(" Upload Image"):
+                with gr.TabItem("📁 Upload Image"):
                     input_image = gr.Image(
                         label="Upload Image",
                         type="pil",
                         height=200  # Made smaller
                     )
 
-                with gr.TabItem(" Image URL"):
+                with gr.TabItem("🔗 Image URL"):
                     image_url = gr.Textbox(
                         label="Image URL",
                         placeholder="https://example.com/image.jpg",
                         value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
                     )
 
-            gr.HTML("<h3> Caption Settings</h3>")
+            gr.HTML("<h3>🎛️ Caption Settings</h3>")
 
             use_generated_caption = gr.Checkbox(
                 label="Use AI-generated caption (Florence-2)",
                 value=True,
                 info="Generate detailed caption automatically"
            )
 
             custom_prompt = gr.Textbox(
                 label="Custom Prompt (optional)",
                 placeholder="Enter custom prompt or leave empty for generated caption",
                 lines=2
             )
 
-            gr.HTML("<h3> Upscaling Settings</h3>")
+            gr.HTML("<h3>⚙️ Upscaling Settings</h3>")
 
             upscale_factor = gr.Slider(
                 label="Upscale Factor",
                 minimum=1,
                 maximum=4,
                 step=1,
                 value=2,
                 info="How much to upscale the image"
             )
 
             num_inference_steps = gr.Slider(
                 label="Steps (25 Recommended)",
                 minimum=8,
                 maximum=50,
                 step=1,
                 value=25,
                 info="More steps = better quality but slower"
             )
 
             denoising_strength = gr.Slider(
-                label="Creativity (Denoising value)",
+                label="Creativity (Denoising)",
                 minimum=0.0,
                 maximum=1.0,
                 step=0.05,
                 value=0.3,
-                info="More>0.3 = Very Creative, Less<0.1 = More consistent, 0.15-0.3 recommended"
+                info="Controls how much the image is transformed"
             )
 
             with gr.Row():
                 randomize_seed = gr.Checkbox(
                     label="Randomize seed",
                     value=True
                 )
-                seed = gr.Slider(
-                    label="Seed",
-                    minimum=0,
-                    maximum=MAX_SEED,
-                    step=1,
-                    value=42,
-                    interactive=True
-                )
 
             enhance_btn = gr.Button(
-                " Upscale Image",
+                "🚀 Upscale Image",
                 variant="primary",
                 size="lg"
             )
 
         with gr.Column(scale=2):  # Larger scale for results
-            gr.HTML("<h3> Results</h3>")
+            gr.HTML("<h3>📊 Results</h3>")
 
             result_slider = ImageSlider(
                 type="pil",
                 interactive=False,  # Disable interactivity to prevent uploads
                 height=600,  # Made larger
                 elem_id="result_slider",
                 label=None  # Remove default label
             )
 
+            upscaled_output = gr.Image(
+                label="Upscaled Image (Download as PNG)",
+                type="pil",
+                interactive=False,
+                show_download_button=True,
+                height=600,
+            )
+
     # Event handler
     enhance_btn.click(
         fn=enhance_image,
         inputs=[
             input_image,
             image_url,
-            seed,
             randomize_seed,
             num_inference_steps,
             upscale_factor,
             denoising_strength,
             use_generated_caption,
             custom_prompt,
         ],
-        outputs=[result_slider]
+        outputs=[result_slider, upscaled_output]
     )
 
     gr.HTML("""
         <div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
             <p><strong>Note:</strong> This upscaler uses the Flux dev model. Users are responsible for obtaining commercial rights if used commercially under their license.</p>
         </div>
     """)
 
     # Custom CSS for slider
     gr.HTML("""
     <style>
     #result_slider .slider {
         width: 100% !important;
         max-width: inherit !important;
     }
     #result_slider img {
         object-fit: contain !important;
         width: 100% !important;
         height: auto !important;
     }
     #result_slider .gr-button-tool {
         display: none !important;
     }
     #result_slider .gr-button-undo {
         display: none !important;
     }
     #result_slider .gr-button-clear {
         display: none !important;
     }
     #result_slider .badge-container .badge {
         display: none !important;
     }
     #result_slider .badge-container::before {
         content: "Before";
         position: absolute;
         top: 10px;
         left: 10px;
         background: rgba(0,0,0,0.5);
         color: white;
         padding: 5px;
         border-radius: 5px;
         z-index: 10;
     }
     #result_slider .badge-container::after {
         content: "After";
         position: absolute;
         top: 10px;
         right: 10px;
         background: rgba(0,0,0,0.5);
         color: white;
         padding: 5px;
         border-radius: 5px;
         z-index: 10;
     }
     #result_slider .fullscreen img {
         object-fit: contain !important;
         width: 100vw !important;
         height: 100vh !important;
     }
     </style>
     """)
 
     # JS to set slider default position to middle
     gr.HTML("""
     <script>
     document.addEventListener('DOMContentLoaded', function() {
         const sliderInput = document.querySelector('#result_slider input[type="range"]');
         if (sliderInput) {
             sliderInput.value = 50;
             sliderInput.dispatchEvent(new Event('input'));
         }
     });
     </script>
-    """)if __name__ == "__main__":
-    demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)
+    """)
+
+if __name__ == "__main__":
+    demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)
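Note on this hunk: `enhance_image` still routes every input through `process_input`, whose `MAX_PIXEL_BUDGET` path caps output pixels at 8192 × 8192 ≈ 67.1 MP and snaps the downscaled input to multiples of 8. A worked example of that arithmetic (the input size here is hypothetical):

```python
# A 4000x3000 input at 4x would be 192 MP, over the 67.1 MP budget.
w, h, upscale_factor = 4000, 3000, 4
MAX_PIXEL_BUDGET = 8192 * 8192                              # 67,108,864
target_input_pixels = MAX_PIXEL_BUDGET / upscale_factor**2  # ~4.19 MP allowed pre-upscale
scale = (target_input_pixels / (w * h)) ** 0.5              # ~0.591
new_w = int(w * scale) - int(w * scale) % 8                 # 2364 -> 2360
new_h = int(h * scale) - int(h * scale) % 8                 # 1773 -> 1768
# Output becomes 9440x7072 ~= 66.8 MP, just under the budget.
```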
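End-to-end note: with the `seed` slider dropped, `enhance_btn.click` now passes eight inputs. For anyone scripting the Space, an untested sketch with `gradio_client` (argument order mirrors the `inputs=[...]` list above; the URL and file name are placeholders, and `handle_file` assumes a recent gradio_client release):

```python
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860/")  # placeholder URL
before_after, upscaled = client.predict(
    handle_file("input.jpg"),  # input_image
    "",                        # image_url (empty: use the uploaded file)
    True,                      # randomize_seed
    25,                        # num_inference_steps
    2,                         # upscale_factor
    0.3,                       # denoising_strength
    True,                      # use_generated_caption
    "",                        # custom_prompt
)
```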