comrender commited on
Commit
da3febd
Β·
verified Β·
1 Parent(s): a1ef78c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -51
app.py CHANGED
@@ -13,6 +13,7 @@ from PIL import Image
13
  from huggingface_hub import snapshot_download
14
  import requests
15
  import io
 
16
 
17
  # For ESRGAN (requires pip install basicsr gfpgan)
18
  try:
@@ -61,7 +62,7 @@ florence_model = AutoModelForCausalLM.from_pretrained(
61
  "microsoft/Florence-2-large",
62
  torch_dtype=torch.float16,
63
  trust_remote_code=True,
64
- attn_implementation="eager" # Fix for SDPA compatibility issue
65
  ).to(device)
66
  florence_processor = AutoProcessor.from_pretrained(
67
  "microsoft/Florence-2-large",
@@ -94,17 +95,15 @@ if USE_ESRGAN:
94
  esrgan_model.to(device)
95
 
96
  MAX_SEED = 1000000
97
- MAX_PIXEL_BUDGET = 8192 * 8192 # Increased for tiling support
98
-
99
 
100
  def generate_caption(image):
101
  """Generate detailed caption using Florence-2"""
102
  try:
103
  task_prompt = "<MORE_DETAILED_CAPTION>"
104
  prompt = task_prompt
105
-
106
  inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(device)
107
- inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16) # Match model dtype
108
 
109
  generated_ids = florence_model.generate(
110
  input_ids=inputs["input_ids"],
@@ -123,13 +122,10 @@ def generate_caption(image):
123
  print(f"Caption generation failed: {e}")
124
  return "a high quality detailed image"
125
 
126
-
127
  def process_input(input_image, upscale_factor):
128
  """Process input image and handle size constraints"""
129
  w, h = input_image.size
130
  w_original, h_original = w, h
131
- aspect_ratio = w / h
132
-
133
  was_resized = False
134
 
135
  if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
@@ -148,17 +144,19 @@ def process_input(input_image, upscale_factor):
148
 
149
  return input_image, w_original, h_original, was_resized
150
 
151
-
152
  def load_image_from_url(url):
153
- """Load image from URL"""
154
  try:
155
  response = requests.get(url, stream=True)
156
  response.raise_for_status()
157
- return Image.open(response.raw)
 
 
 
 
158
  except Exception as e:
159
  raise gr.Error(f"Failed to load image from URL: {e}")
160
 
161
-
162
  def esrgan_upscale(image, scale=4):
163
  if not USE_ESRGAN:
164
  return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)
@@ -168,14 +166,12 @@ def esrgan_upscale(image, scale=4):
168
  output_img = tensor2img(output, rgb2bgr=False, min_max=(0, 1))
169
  return Image.fromarray(output_img)
170
 
171
-
172
  def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=32):
173
  """Tiled Img2Img to mimic Ultimate SD Upscaler tiling"""
174
  w, h = image.size
175
- output = image.copy() # Start with the control image
176
 
177
- # For handling long prompts: truncate for CLIP, full for T5
178
- max_clip_tokens = pipe.tokenizer.model_max_length # Typically 77
179
  input_ids = pipe.tokenizer.encode(prompt, return_tensors="pt")
180
  if input_ids.shape[1] > max_clip_tokens:
181
  input_ids = input_ids[:, :max_clip_tokens]
@@ -189,7 +185,6 @@ def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator
189
  tile_h = min(tile_size, h - y)
190
  tile = image.crop((x, y, x + tile_w, y + tile_h))
191
 
192
- # Run Flux on tile
193
  gen_tile = pipe(
194
  prompt=prompt_clip,
195
  prompt_2=prompt,
@@ -202,14 +197,11 @@ def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator
202
  generator=generator,
203
  ).images[0]
204
 
205
- # Resize back to exact tile size if pipeline adjusted it
206
  gen_tile = gen_tile.resize((tile_w, tile_h), resample=Image.LANCZOS)
207
 
208
- # Paste with blending if overlap
209
  if overlap > 0:
210
  paste_box = (x, y, x + tile_w, y + tile_h)
211
  if x > 0 or y > 0:
212
- # Simple linear blend on overlaps
213
  mask = Image.new('L', (tile_w, tile_h), 255)
214
  if x > 0:
215
  blend_width = min(overlap, tile_w)
@@ -229,6 +221,14 @@ def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator
229
 
230
  return output
231
 
 
 
 
 
 
 
 
 
232
 
233
  @spaces.GPU(duration=120)
234
  def enhance_image(
@@ -243,20 +243,16 @@ def enhance_image(
243
  progress=gr.Progress(track_tqdm=True),
244
  ):
245
  """Main enhancement function"""
246
- # Handle image input
247
  if image_input is not None:
248
- input_image = image_input
 
 
 
249
  elif image_url:
250
  input_image = load_image_from_url(image_url)
251
  else:
252
  raise gr.Error("Please provide an image (upload or URL)")
253
 
254
- # Convert input image to PNG in backend
255
- buffer = io.BytesIO()
256
- input_image.save(buffer, format="PNG")
257
- buffer.seek(0)
258
- input_image = Image.open(buffer)
259
-
260
  if randomize_seed:
261
  seed = random.randint(0, MAX_SEED)
262
  else:
@@ -264,12 +260,10 @@ def enhance_image(
264
 
265
  true_input_image = input_image
266
 
267
- # Process input image
268
  input_image, w_original, h_original, was_resized = process_input(
269
  input_image, upscale_factor
270
  )
271
 
272
- # Generate caption if requested
273
  if use_generated_caption:
274
  gr.Info("πŸ” Generating image caption...")
275
  generated_caption = generate_caption(input_image)
@@ -281,21 +275,19 @@ def enhance_image(
281
 
282
  gr.Info("πŸš€ Upscaling image...")
283
 
284
- # Initial upscale
285
  if USE_ESRGAN and upscale_factor == 4:
286
  control_image = esrgan_upscale(input_image, upscale_factor)
287
  else:
288
  w, h = input_image.size
289
  control_image = input_image.resize((w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS)
290
 
291
- # Tiled Flux Img2Img for refinement
292
  image = tiled_flux_img2img(
293
  pipe,
294
  prompt,
295
  control_image,
296
  denoising_strength,
297
  num_inference_steps,
298
- 1.0, # Hardcoded guidance_scale to 1
299
  generator,
300
  tile_size=1024,
301
  overlap=32
@@ -305,12 +297,10 @@ def enhance_image(
305
  gr.Info(f"πŸ“ Resizing output to target size: {w_original * upscale_factor}x{h_original * upscale_factor}")
306
  image = image.resize((w_original * upscale_factor, h_original * upscale_factor), resample=Image.LANCZOS)
307
 
308
- # Resize input image to match output size for slider alignment
309
  resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
310
 
311
  return [resized_input, image], image
312
 
313
-
314
  # Create Gradio interface
315
  with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FLUX") as demo:
316
  gr.HTML("""
@@ -330,7 +320,7 @@ with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FL
330
  input_image = gr.Image(
331
  label="Upload Image",
332
  type="pil",
333
- height=200 # Made smaller
334
  )
335
 
336
  with gr.TabItem("πŸ”— Image URL"):
@@ -395,26 +385,27 @@ with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FL
395
  size="lg"
396
  )
397
 
398
- with gr.Column(scale=2): # Larger scale for results
399
  gr.HTML("<h3>πŸ“Š Results</h3>")
400
 
401
  result_slider = ImageSlider(
402
- type="pil",
403
- interactive=False, # Disable interactivity to prevent uploads
404
- height=600, # Made larger
405
- elem_id="result_slider",
406
- label=None # Remove default label
407
- )
408
-
409
- upscaled_output = gr.Image(
410
- label="Upscaled Image (Download as PNG)",
411
  type="pil",
412
  interactive=False,
413
- show_download_button=True,
414
  height=600,
 
 
 
 
 
 
 
 
415
  )
416
 
417
- # Event handler
 
 
 
418
  enhance_btn.click(
419
  fn=enhance_image,
420
  inputs=[
@@ -427,7 +418,13 @@ with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FL
427
  use_generated_caption,
428
  custom_prompt,
429
  ],
430
- outputs=[result_slider, upscaled_output]
 
 
 
 
 
 
431
  )
432
 
433
  gr.HTML("""
@@ -436,7 +433,6 @@ with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FL
436
  </div>
437
  """)
438
 
439
- # Custom CSS for slider
440
  gr.HTML("""
441
  <style>
442
  #result_slider .slider {
@@ -490,7 +486,6 @@ with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FL
490
  </style>
491
  """)
492
 
493
- # JS to set slider default position to middle
494
  gr.HTML("""
495
  <script>
496
  document.addEventListener('DOMContentLoaded', function() {
 
13
  from huggingface_hub import snapshot_download
14
  import requests
15
  import io
16
+ import base64
17
 
18
  # For ESRGAN (requires pip install basicsr gfpgan)
19
  try:
 
62
  "microsoft/Florence-2-large",
63
  torch_dtype=torch.float16,
64
  trust_remote_code=True,
65
+ attn_implementation="eager"
66
  ).to(device)
67
  florence_processor = AutoProcessor.from_pretrained(
68
  "microsoft/Florence-2-large",
 
95
  esrgan_model.to(device)
96
 
97
  MAX_SEED = 1000000
98
+ MAX_PIXEL_BUDGET = 8192 * 8192
 
99
 
100
  def generate_caption(image):
101
  """Generate detailed caption using Florence-2"""
102
  try:
103
  task_prompt = "<MORE_DETAILED_CAPTION>"
104
  prompt = task_prompt
 
105
  inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(device)
106
+ inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)
107
 
108
  generated_ids = florence_model.generate(
109
  input_ids=inputs["input_ids"],
 
122
  print(f"Caption generation failed: {e}")
123
  return "a high quality detailed image"
124
 
 
125
  def process_input(input_image, upscale_factor):
126
  """Process input image and handle size constraints"""
127
  w, h = input_image.size
128
  w_original, h_original = w, h
 
 
129
  was_resized = False
130
 
131
  if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
 
144
 
145
  return input_image, w_original, h_original, was_resized
146
 
 
147
  def load_image_from_url(url):
148
+ """Load image from URL and convert to PNG"""
149
  try:
150
  response = requests.get(url, stream=True)
151
  response.raise_for_status()
152
+ img = Image.open(response.raw)
153
+ buffer = io.BytesIO()
154
+ img.save(buffer, format="PNG")
155
+ buffer.seek(0)
156
+ return Image.open(buffer)
157
  except Exception as e:
158
  raise gr.Error(f"Failed to load image from URL: {e}")
159
 
 
160
  def esrgan_upscale(image, scale=4):
161
  if not USE_ESRGAN:
162
  return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)
 
166
  output_img = tensor2img(output, rgb2bgr=False, min_max=(0, 1))
167
  return Image.fromarray(output_img)
168
 
 
169
  def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=32):
170
  """Tiled Img2Img to mimic Ultimate SD Upscaler tiling"""
171
  w, h = image.size
172
+ output = image.copy()
173
 
174
+ max_clip_tokens = pipe.tokenizer.model_max_length
 
175
  input_ids = pipe.tokenizer.encode(prompt, return_tensors="pt")
176
  if input_ids.shape[1] > max_clip_tokens:
177
  input_ids = input_ids[:, :max_clip_tokens]
 
185
  tile_h = min(tile_size, h - y)
186
  tile = image.crop((x, y, x + tile_w, y + tile_h))
187
 
 
188
  gen_tile = pipe(
189
  prompt=prompt_clip,
190
  prompt_2=prompt,
 
197
  generator=generator,
198
  ).images[0]
199
 
 
200
  gen_tile = gen_tile.resize((tile_w, tile_h), resample=Image.LANCZOS)
201
 
 
202
  if overlap > 0:
203
  paste_box = (x, y, x + tile_w, y + tile_h)
204
  if x > 0 or y > 0:
 
205
  mask = Image.new('L', (tile_w, tile_h), 255)
206
  if x > 0:
207
  blend_width = min(overlap, tile_w)
 
221
 
222
  return output
223
 
224
+ def download_png(image):
225
+ """Convert image to PNG and return as downloadable file"""
226
+ if image is None:
227
+ raise gr.Error("No upscaled image available to download")
228
+ buffer = io.BytesIO()
229
+ image.save(buffer, format="PNG")
230
+ buffer.seek(0)
231
+ return buffer
232
 
233
  @spaces.GPU(duration=120)
234
  def enhance_image(
 
243
  progress=gr.Progress(track_tqdm=True),
244
  ):
245
  """Main enhancement function"""
 
246
  if image_input is not None:
247
+ buffer = io.BytesIO()
248
+ image_input.save(buffer, format="PNG")
249
+ buffer.seek(0)
250
+ input_image = Image.open(buffer)
251
  elif image_url:
252
  input_image = load_image_from_url(image_url)
253
  else:
254
  raise gr.Error("Please provide an image (upload or URL)")
255
 
 
 
 
 
 
 
256
  if randomize_seed:
257
  seed = random.randint(0, MAX_SEED)
258
  else:
 
260
 
261
  true_input_image = input_image
262
 
 
263
  input_image, w_original, h_original, was_resized = process_input(
264
  input_image, upscale_factor
265
  )
266
 
 
267
  if use_generated_caption:
268
  gr.Info("πŸ” Generating image caption...")
269
  generated_caption = generate_caption(input_image)
 
275
 
276
  gr.Info("πŸš€ Upscaling image...")
277
 
 
278
  if USE_ESRGAN and upscale_factor == 4:
279
  control_image = esrgan_upscale(input_image, upscale_factor)
280
  else:
281
  w, h = input_image.size
282
  control_image = input_image.resize((w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS)
283
 
 
284
  image = tiled_flux_img2img(
285
  pipe,
286
  prompt,
287
  control_image,
288
  denoising_strength,
289
  num_inference_steps,
290
+ 1.0,
291
  generator,
292
  tile_size=1024,
293
  overlap=32
 
297
  gr.Info(f"πŸ“ Resizing output to target size: {w_original * upscale_factor}x{h_original * upscale_factor}")
298
  image = image.resize((w_original * upscale_factor, h_original * upscale_factor), resample=Image.LANCZOS)
299
 
 
300
  resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
301
 
302
  return [resized_input, image], image
303
 
 
304
  # Create Gradio interface
305
  with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FLUX") as demo:
306
  gr.HTML("""
 
320
  input_image = gr.Image(
321
  label="Upload Image",
322
  type="pil",
323
+ height=200
324
  )
325
 
326
  with gr.TabItem("πŸ”— Image URL"):
 
385
  size="lg"
386
  )
387
 
388
+ with gr.Column(scale=2):
389
  gr.HTML("<h3>πŸ“Š Results</h3>")
390
 
391
  result_slider = ImageSlider(
 
 
 
 
 
 
 
 
 
392
  type="pil",
393
  interactive=False,
 
394
  height=600,
395
+ elem_id="result_slider",
396
+ label=None
397
+ )
398
+
399
+ download_btn = gr.Button(
400
+ "πŸ“₯ Download as PNG",
401
+ variant="secondary",
402
+ size="lg"
403
  )
404
 
405
+ # State to store the upscaled image
406
+ upscaled_image_state = gr.State()
407
+
408
+ # Event handlers
409
  enhance_btn.click(
410
  fn=enhance_image,
411
  inputs=[
 
418
  use_generated_caption,
419
  custom_prompt,
420
  ],
421
+ outputs=[result_slider, upscaled_image_state]
422
+ )
423
+
424
+ download_btn.click(
425
+ fn=download_png,
426
+ inputs=[upscaled_image_state],
427
+ outputs=gr.File(label="Download Upscaled Image as PNG")
428
  )
429
 
430
  gr.HTML("""
 
433
  </div>
434
  """)
435
 
 
436
  gr.HTML("""
437
  <style>
438
  #result_slider .slider {
 
486
  </style>
487
  """)
488
 
 
489
  gr.HTML("""
490
  <script>
491
  document.addEventListener('DOMContentLoaded', function() {