comrender commited on
Commit
d7426bc
Β·
verified Β·
1 Parent(s): 23dd7dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -99
app.py CHANGED
@@ -12,8 +12,6 @@ from gradio_imageslider import ImageSlider
12
  from PIL import Image
13
  from huggingface_hub import snapshot_download
14
  import requests
15
- import io
16
- import base64
17
 
18
  # For ESRGAN (requires pip install basicsr gfpgan)
19
  try:
@@ -62,7 +60,7 @@ florence_model = AutoModelForCausalLM.from_pretrained(
62
  "microsoft/Florence-2-large",
63
  torch_dtype=torch.float16,
64
  trust_remote_code=True,
65
- attn_implementation="eager"
66
  ).to(device)
67
  florence_processor = AutoProcessor.from_pretrained(
68
  "microsoft/Florence-2-large",
@@ -95,15 +93,16 @@ if USE_ESRGAN:
95
  esrgan_model.to(device)
96
 
97
  MAX_SEED = 1000000
98
- MAX_PIXEL_BUDGET = 8192 * 8192
 
99
 
100
  def generate_caption(image):
101
  """Generate detailed caption using Florence-2"""
102
  try:
103
  task_prompt = "<MORE_DETAILED_CAPTION>"
104
  prompt = task_prompt
 
105
  inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(device)
106
- inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)
107
 
108
  generated_ids = florence_model.generate(
109
  input_ids=inputs["input_ids"],
@@ -122,10 +121,13 @@ def generate_caption(image):
122
  print(f"Caption generation failed: {e}")
123
  return "a high quality detailed image"
124
 
 
125
  def process_input(input_image, upscale_factor):
126
  """Process input image and handle size constraints"""
127
  w, h = input_image.size
128
  w_original, h_original = w, h
 
 
129
  was_resized = False
130
 
131
  if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
@@ -144,19 +146,17 @@ def process_input(input_image, upscale_factor):
144
 
145
  return input_image, w_original, h_original, was_resized
146
 
 
147
  def load_image_from_url(url):
148
- """Load image from URL and convert to PNG"""
149
  try:
150
  response = requests.get(url, stream=True)
151
  response.raise_for_status()
152
- img = Image.open(response.raw)
153
- buffer = io.BytesIO()
154
- img.save(buffer, format="PNG")
155
- buffer.seek(0)
156
- return Image.open(buffer)
157
  except Exception as e:
158
  raise gr.Error(f"Failed to load image from URL: {e}")
159
 
 
160
  def esrgan_upscale(image, scale=4):
161
  if not USE_ESRGAN:
162
  return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)
@@ -166,18 +166,11 @@ def esrgan_upscale(image, scale=4):
166
  output_img = tensor2img(output, rgb2bgr=False, min_max=(0, 1))
167
  return Image.fromarray(output_img)
168
 
 
169
  def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=32):
170
  """Tiled Img2Img to mimic Ultimate SD Upscaler tiling"""
171
  w, h = image.size
172
- output = image.copy()
173
-
174
- max_clip_tokens = pipe.tokenizer.model_max_length
175
- input_ids = pipe.tokenizer.encode(prompt, return_tensors="pt")
176
- if input_ids.shape[1] > max_clip_tokens:
177
- input_ids = input_ids[:, :max_clip_tokens]
178
- prompt_clip = pipe.tokenizer.decode(input_ids[0], skip_special_tokens=True)
179
- else:
180
- prompt_clip = prompt
181
 
182
  for x in range(0, w, tile_size - overlap):
183
  for y in range(0, h, tile_size - overlap):
@@ -185,9 +178,9 @@ def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator
185
  tile_h = min(tile_size, h - y)
186
  tile = image.crop((x, y, x + tile_w, y + tile_h))
187
 
 
188
  gen_tile = pipe(
189
- prompt=prompt_clip,
190
- prompt_2=prompt,
191
  image=tile,
192
  strength=strength,
193
  num_inference_steps=steps,
@@ -197,21 +190,19 @@ def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator
197
  generator=generator,
198
  ).images[0]
199
 
200
- gen_tile = gen_tile.resize((tile_w, tile_h), resample=Image.LANCZOS)
201
-
202
  if overlap > 0:
203
  paste_box = (x, y, x + tile_w, y + tile_h)
204
  if x > 0 or y > 0:
 
205
  mask = Image.new('L', (tile_w, tile_h), 255)
206
  if x > 0:
207
- blend_width = min(overlap, tile_w)
208
- for i in range(blend_width):
209
  for j in range(tile_h):
210
  mask.putpixel((i, j), int(255 * (i / overlap)))
211
  if y > 0:
212
- blend_height = min(overlap, tile_h)
213
  for i in range(tile_w):
214
- for j in range(blend_height):
215
  mask.putpixel((i, j), int(255 * (j / overlap)))
216
  output.paste(gen_tile, paste_box, mask)
217
  else:
@@ -221,19 +212,12 @@ def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator
221
 
222
  return output
223
 
224
- def download_png(image):
225
- """Convert image to PNG and return base64 string for download"""
226
- if image is None:
227
- raise gr.Error("No upscaled image available to download")
228
- buffer = io.BytesIO()
229
- image.save(buffer, format="PNG")
230
- base64_data = base64.b64encode(buffer.getvalue()).decode('utf-8')
231
- return base64_data
232
 
233
  @spaces.GPU(duration=120)
234
  def enhance_image(
235
  image_input,
236
  image_url,
 
237
  randomize_seed,
238
  num_inference_steps,
239
  upscale_factor,
@@ -243,11 +227,9 @@ def enhance_image(
243
  progress=gr.Progress(track_tqdm=True),
244
  ):
245
  """Main enhancement function"""
 
246
  if image_input is not None:
247
- buffer = io.BytesIO()
248
- image_input.save(buffer, format="PNG")
249
- buffer.seek(0)
250
- input_image = Image.open(buffer)
251
  elif image_url:
252
  input_image = load_image_from_url(image_url)
253
  else:
@@ -255,15 +237,15 @@ def enhance_image(
255
 
256
  if randomize_seed:
257
  seed = random.randint(0, MAX_SEED)
258
- else:
259
- seed = 42
260
 
261
  true_input_image = input_image
262
 
 
263
  input_image, w_original, h_original, was_resized = process_input(
264
  input_image, upscale_factor
265
  )
266
 
 
267
  if use_generated_caption:
268
  gr.Info("πŸ” Generating image caption...")
269
  generated_caption = generate_caption(input_image)
@@ -275,19 +257,21 @@ def enhance_image(
275
 
276
  gr.Info("πŸš€ Upscaling image...")
277
 
 
278
  if USE_ESRGAN and upscale_factor == 4:
279
  control_image = esrgan_upscale(input_image, upscale_factor)
280
  else:
281
  w, h = input_image.size
282
  control_image = input_image.resize((w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS)
283
 
 
284
  image = tiled_flux_img2img(
285
  pipe,
286
  prompt,
287
  control_image,
288
  denoising_strength,
289
  num_inference_steps,
290
- 1.0,
291
  generator,
292
  tile_size=1024,
293
  overlap=32
@@ -297,16 +281,18 @@ def enhance_image(
297
  gr.Info(f"πŸ“ Resizing output to target size: {w_original * upscale_factor}x{h_original * upscale_factor}")
298
  image = image.resize((w_original * upscale_factor, h_original * upscale_factor), resample=Image.LANCZOS)
299
 
 
300
  resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
301
 
302
- return [resized_input, image], image
 
303
 
304
  # Create Gradio interface
305
- with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FLUX") as demo:
306
  gr.HTML("""
307
  <div class="main-header">
308
- <h1>🎨 Flux dev Creative Upscaler</h1>
309
- <p>Upload an image or provide a URL to upscale it using Florence-2 captioning and FLUX dev with Ultimate SD Upscaler</p>
310
  <p>Currently running on <strong>{}</strong></p>
311
  </div>
312
  """.format(power_device))
@@ -320,7 +306,7 @@ with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FL
320
  input_image = gr.Image(
321
  label="Upload Image",
322
  type="pil",
323
- height=200
324
  )
325
 
326
  with gr.TabItem("πŸ”— Image URL"):
@@ -356,7 +342,7 @@ with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FL
356
  )
357
 
358
  num_inference_steps = gr.Slider(
359
- label="Steps (25 Recommended)",
360
  minimum=8,
361
  maximum=50,
362
  step=1,
@@ -365,7 +351,7 @@ with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FL
365
  )
366
 
367
  denoising_strength = gr.Slider(
368
- label="Creativity (Denoising)",
369
  minimum=0.0,
370
  maximum=1.0,
371
  step=0.05,
@@ -378,6 +364,14 @@ with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FL
378
  label="Randomize seed",
379
  value=True
380
  )
 
 
 
 
 
 
 
 
381
 
382
  enhance_btn = gr.Button(
383
  "πŸš€ Upscale Image",
@@ -385,35 +379,24 @@ with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FL
385
  size="lg"
386
  )
387
 
388
- with gr.Column(scale=2):
389
  gr.HTML("<h3>πŸ“Š Results</h3>")
390
 
391
  result_slider = ImageSlider(
392
  type="pil",
393
- interactive=False,
394
- height=600,
395
  elem_id="result_slider",
396
- label=None
397
- )
398
-
399
- download_btn = gr.Button(
400
- "πŸ“₯ Download as PNG",
401
- variant="secondary",
402
- size="lg"
403
  )
404
 
405
- # State to store the upscaled image
406
- upscaled_image_state = gr.State()
407
-
408
- # Hidden textbox for base64 data
409
- download_data = gr.Textbox(visible=False, elem_id="download_data")
410
-
411
- # Event handlers
412
  enhance_btn.click(
413
  fn=enhance_image,
414
  inputs=[
415
  input_image,
416
  image_url,
 
417
  randomize_seed,
418
  num_inference_steps,
419
  upscale_factor,
@@ -421,13 +404,7 @@ with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FL
421
  use_generated_caption,
422
  custom_prompt,
423
  ],
424
- outputs=[result_slider, upscaled_image_state]
425
- )
426
-
427
- download_btn.click(
428
- fn=download_png,
429
- inputs=[upscaled_image_state],
430
- outputs=download_data
431
  )
432
 
433
  gr.HTML("""
@@ -436,6 +413,7 @@ with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FL
436
  </div>
437
  """)
438
 
 
439
  gr.HTML("""
440
  <style>
441
  #result_slider .slider {
@@ -489,6 +467,7 @@ with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FL
489
  </style>
490
  """)
491
 
 
492
  gr.HTML("""
493
  <script>
494
  document.addEventListener('DOMContentLoaded', function() {
@@ -497,31 +476,6 @@ with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FL
497
  sliderInput.value = 50;
498
  sliderInput.dispatchEvent(new Event('input'));
499
  }
500
-
501
- const downloadData = document.querySelector('#download_data textarea');
502
- if (downloadData) {
503
- const observer = new MutationObserver(() => {
504
- const base64 = downloadData.value;
505
- if (base64) {
506
- const byteCharacters = atob(base64);
507
- const byteNumbers = new Array(byteCharacters.length);
508
- for (let i = 0; i < byteCharacters.length; i++) {
509
- byteNumbers[i] = byteCharacters.charCodeAt(i);
510
- }
511
- const byteArray = new Uint8Array(byteNumbers);
512
- const blob = new Blob([byteArray], {type: 'image/png'});
513
- const url = URL.createObjectURL(blob);
514
- const a = document.createElement('a');
515
- a.href = url;
516
- a.download = 'upscaled_image.png';
517
- a.click();
518
- URL.revokeObjectURL(url);
519
- // Clear the textbox
520
- downloadData.value = '';
521
- }
522
- });
523
- observer.observe(downloadData, {childList: true, subtree: true, characterData: true});
524
- }
525
  });
526
  </script>
527
  """)
 
12
  from PIL import Image
13
  from huggingface_hub import snapshot_download
14
  import requests
 
 
15
 
16
  # For ESRGAN (requires pip install basicsr gfpgan)
17
  try:
 
60
  "microsoft/Florence-2-large",
61
  torch_dtype=torch.float16,
62
  trust_remote_code=True,
63
+ attn_implementation="eager" # Fix for SDPA compatibility issue
64
  ).to(device)
65
  florence_processor = AutoProcessor.from_pretrained(
66
  "microsoft/Florence-2-large",
 
93
  esrgan_model.to(device)
94
 
95
  MAX_SEED = 1000000
96
+ MAX_PIXEL_BUDGET = 8192 * 8192 # Increased for tiling support
97
+
98
 
99
  def generate_caption(image):
100
  """Generate detailed caption using Florence-2"""
101
  try:
102
  task_prompt = "<MORE_DETAILED_CAPTION>"
103
  prompt = task_prompt
104
+
105
  inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(device)
 
106
 
107
  generated_ids = florence_model.generate(
108
  input_ids=inputs["input_ids"],
 
121
  print(f"Caption generation failed: {e}")
122
  return "a high quality detailed image"
123
 
124
+
125
  def process_input(input_image, upscale_factor):
126
  """Process input image and handle size constraints"""
127
  w, h = input_image.size
128
  w_original, h_original = w, h
129
+ aspect_ratio = w / h
130
+
131
  was_resized = False
132
 
133
  if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
 
146
 
147
  return input_image, w_original, h_original, was_resized
148
 
149
+
150
  def load_image_from_url(url):
151
+ """Load image from URL"""
152
  try:
153
  response = requests.get(url, stream=True)
154
  response.raise_for_status()
155
+ return Image.open(response.raw)
 
 
 
 
156
  except Exception as e:
157
  raise gr.Error(f"Failed to load image from URL: {e}")
158
 
159
+
160
  def esrgan_upscale(image, scale=4):
161
  if not USE_ESRGAN:
162
  return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)
 
166
  output_img = tensor2img(output, rgb2bgr=False, min_max=(0, 1))
167
  return Image.fromarray(output_img)
168
 
169
+
170
  def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=32):
171
  """Tiled Img2Img to mimic Ultimate SD Upscaler tiling"""
172
  w, h = image.size
173
+ output = image.copy() # Start with the control image
 
 
 
 
 
 
 
 
174
 
175
  for x in range(0, w, tile_size - overlap):
176
  for y in range(0, h, tile_size - overlap):
 
178
  tile_h = min(tile_size, h - y)
179
  tile = image.crop((x, y, x + tile_w, y + tile_h))
180
 
181
+ # Run Flux on tile
182
  gen_tile = pipe(
183
+ prompt=prompt,
 
184
  image=tile,
185
  strength=strength,
186
  num_inference_steps=steps,
 
190
  generator=generator,
191
  ).images[0]
192
 
193
+ # Paste with blending if overlap
 
194
  if overlap > 0:
195
  paste_box = (x, y, x + tile_w, y + tile_h)
196
  if x > 0 or y > 0:
197
+ # Simple linear blend on overlaps
198
  mask = Image.new('L', (tile_w, tile_h), 255)
199
  if x > 0:
200
+ for i in range(overlap):
 
201
  for j in range(tile_h):
202
  mask.putpixel((i, j), int(255 * (i / overlap)))
203
  if y > 0:
 
204
  for i in range(tile_w):
205
+ for j in range(overlap):
206
  mask.putpixel((i, j), int(255 * (j / overlap)))
207
  output.paste(gen_tile, paste_box, mask)
208
  else:
 
212
 
213
  return output
214
 
 
 
 
 
 
 
 
 
215
 
216
  @spaces.GPU(duration=120)
217
  def enhance_image(
218
  image_input,
219
  image_url,
220
+ seed,
221
  randomize_seed,
222
  num_inference_steps,
223
  upscale_factor,
 
227
  progress=gr.Progress(track_tqdm=True),
228
  ):
229
  """Main enhancement function"""
230
+ # Handle image input
231
  if image_input is not None:
232
+ input_image = image_input
 
 
 
233
  elif image_url:
234
  input_image = load_image_from_url(image_url)
235
  else:
 
237
 
238
  if randomize_seed:
239
  seed = random.randint(0, MAX_SEED)
 
 
240
 
241
  true_input_image = input_image
242
 
243
+ # Process input image
244
  input_image, w_original, h_original, was_resized = process_input(
245
  input_image, upscale_factor
246
  )
247
 
248
+ # Generate caption if requested
249
  if use_generated_caption:
250
  gr.Info("πŸ” Generating image caption...")
251
  generated_caption = generate_caption(input_image)
 
257
 
258
  gr.Info("πŸš€ Upscaling image...")
259
 
260
+ # Initial upscale
261
  if USE_ESRGAN and upscale_factor == 4:
262
  control_image = esrgan_upscale(input_image, upscale_factor)
263
  else:
264
  w, h = input_image.size
265
  control_image = input_image.resize((w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS)
266
 
267
+ # Tiled Flux Img2Img for refinement
268
  image = tiled_flux_img2img(
269
  pipe,
270
  prompt,
271
  control_image,
272
  denoising_strength,
273
  num_inference_steps,
274
+ 1.0, # Hardcoded guidance_scale to 1
275
  generator,
276
  tile_size=1024,
277
  overlap=32
 
281
  gr.Info(f"πŸ“ Resizing output to target size: {w_original * upscale_factor}x{h_original * upscale_factor}")
282
  image = image.resize((w_original * upscale_factor, h_original * upscale_factor), resample=Image.LANCZOS)
283
 
284
+ # Resize input image to match output size for slider alignment
285
  resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
286
 
287
+ return [resized_input, image]
288
+
289
 
290
  # Create Gradio interface
291
+ with gr.Blocks(css=css, title="🎨 AI Image Upscaler - Florence-2 + FLUX") as demo:
292
  gr.HTML("""
293
  <div class="main-header">
294
+ <h1>🎨 AI Image Upscaler</h1>
295
+ <p>Upload an image or provide a URL to upscale it using Florence-2 captioning and FLUX upscaling</p>
296
  <p>Currently running on <strong>{}</strong></p>
297
  </div>
298
  """.format(power_device))
 
306
  input_image = gr.Image(
307
  label="Upload Image",
308
  type="pil",
309
+ height=200 # Made smaller
310
  )
311
 
312
  with gr.TabItem("πŸ”— Image URL"):
 
342
  )
343
 
344
  num_inference_steps = gr.Slider(
345
+ label="Number of Inference Steps",
346
  minimum=8,
347
  maximum=50,
348
  step=1,
 
351
  )
352
 
353
  denoising_strength = gr.Slider(
354
+ label="Denoising Strength",
355
  minimum=0.0,
356
  maximum=1.0,
357
  step=0.05,
 
364
  label="Randomize seed",
365
  value=True
366
  )
367
+ seed = gr.Slider(
368
+ label="Seed",
369
+ minimum=0,
370
+ maximum=MAX_SEED,
371
+ step=1,
372
+ value=42,
373
+ interactive=True
374
+ )
375
 
376
  enhance_btn = gr.Button(
377
  "πŸš€ Upscale Image",
 
379
  size="lg"
380
  )
381
 
382
+ with gr.Column(scale=2): # Larger scale for results
383
  gr.HTML("<h3>πŸ“Š Results</h3>")
384
 
385
  result_slider = ImageSlider(
386
  type="pil",
387
+ interactive=False, # Disable interactivity to prevent uploads
388
+ height=600, # Made larger
389
  elem_id="result_slider",
390
+ label=None # Remove default label
 
 
 
 
 
 
391
  )
392
 
393
+ # Event handler
 
 
 
 
 
 
394
  enhance_btn.click(
395
  fn=enhance_image,
396
  inputs=[
397
  input_image,
398
  image_url,
399
+ seed,
400
  randomize_seed,
401
  num_inference_steps,
402
  upscale_factor,
 
404
  use_generated_caption,
405
  custom_prompt,
406
  ],
407
+ outputs=[result_slider]
 
 
 
 
 
 
408
  )
409
 
410
  gr.HTML("""
 
413
  </div>
414
  """)
415
 
416
+ # Custom CSS for slider
417
  gr.HTML("""
418
  <style>
419
  #result_slider .slider {
 
467
  </style>
468
  """)
469
 
470
+ # JS to set slider default position to middle
471
  gr.HTML("""
472
  <script>
473
  document.addEventListener('DOMContentLoaded', function() {
 
476
  sliderInput.value = 50;
477
  sliderInput.dispatchEvent(new Event('input'));
478
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
  });
480
  </script>
481
  """)