isat commited on
Commit
a9b6162
·
verified ·
1 Parent(s): 543dd2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -12
app.py CHANGED
@@ -281,13 +281,16 @@ image_mask_list = sorted([os.path.join(image_mask_dir, f) for f in os.listdir(im
281
 
282
 
283
  @spaces.GPU
284
- def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt):
285
  if base_mask_option == "Draw Mask":
286
  tar_image = base_image["background"]
287
  tar_mask = base_image["layers"][0]
288
- else:
289
  tar_image = base_image["background"]
290
  tar_mask = base_mask["background"]
 
 
 
291
 
292
  if ref_mask_option == "Draw Mask":
293
  ref_image = reference_image["background"]
@@ -295,9 +298,9 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
295
  elif ref_mask_option == "Upload with Mask":
296
  ref_image = reference_image["background"]
297
  ref_mask = ref_mask["background"]
298
- else:
299
  ref_image = reference_image["background"]
300
- ref_mask = get_mask(ref_image, text_prompt)
301
 
302
  tar_image = tar_image.convert("RGB")
303
  tar_mask = tar_mask.convert("L")
@@ -393,10 +396,16 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
393
  edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
394
  edited_image = Image.fromarray(edited_image)
395
 
396
- if ref_mask_option != "Label to Mask":
397
- return [show_diptych_ref_tar, edited_image, received_tar_mask, received_ref_mask]
398
- else:
399
- return [return_ref_mask, show_diptych_ref_tar, edited_image, received_tar_mask, received_ref_mask]
 
 
 
 
 
 
400
 
401
 
402
  def update_ui(option):
@@ -420,8 +429,11 @@ with gr.Blocks() as demo:
420
  base_mask = gr.ImageEditor(label="Background Mask", sources="upload", type="pil",
421
  layers=False, brush=False, eraser=False)
422
  with gr.Row():
423
- base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Background Mask Input Option",
424
  value="Upload with Mask")
 
 
 
425
 
426
  with gr.Row():
427
  ref_image = gr.ImageEditor(label="Reference Image", sources="upload", type="pil",
@@ -434,8 +446,8 @@ with gr.Blocks() as demo:
434
  ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"],
435
  label="Reference Mask Input Option", value="Upload with Mask")
436
  with gr.Row():
437
- text_prompt = gr.Textbox(label="Label",
438
- placeholder="Enter the category of the reference object, e.g., car, dress, toy, etc.")
439
 
440
  with gr.Column(scale=1):
441
  baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=695, columns=1)
@@ -446,6 +458,7 @@ with gr.Blocks() as demo:
446
  gr.Markdown(" Draw Mask means manually drawing a mask on the original image.")
447
  gr.Markdown(" Upload with Mask means uploading a mask file.")
448
  gr.Markdown(" Label to Mask means simply inputting a label to automatically extract the mask and obtain the result.")
 
449
 
450
  run_local_button = gr.Button(value="Run")
451
 
@@ -468,7 +481,7 @@ with gr.Blocks() as demo:
468
 
469
  run_local_button.click(
470
  fn=run_local,
471
- inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt],
472
  outputs=[baseline_gallery]
473
  )
474
  demo.launch()
 
281
 
282
 
283
  @spaces.GPU
284
+ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option, base_text_prompt, ref_text_prompt):
285
  if base_mask_option == "Draw Mask":
286
  tar_image = base_image["background"]
287
  tar_mask = base_image["layers"][0]
288
+ elif base_mask_option == "Upload with Mask":
289
  tar_image = base_image["background"]
290
  tar_mask = base_mask["background"]
291
+ else: # Label to Mask
292
+ tar_image = base_image["background"]
293
+ tar_mask = get_mask(tar_image, base_text_prompt)
294
 
295
  if ref_mask_option == "Draw Mask":
296
  ref_image = reference_image["background"]
 
298
  elif ref_mask_option == "Upload with Mask":
299
  ref_image = reference_image["background"]
300
  ref_mask = ref_mask["background"]
301
+ else: # Label to Mask
302
  ref_image = reference_image["background"]
303
+ ref_mask = get_mask(ref_image, ref_text_prompt)
304
 
305
  tar_image = tar_image.convert("RGB")
306
  tar_mask = tar_mask.convert("L")
 
396
  edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
397
  edited_image = Image.fromarray(edited_image)
398
 
399
+ # Determine which masks to show as "generated" in output
400
+ masks_to_return = []
401
+ if base_mask_option == "Label to Mask":
402
+ masks_to_return.append(received_tar_mask) # Show generated background mask
403
+ if ref_mask_option == "Label to Mask":
404
+ masks_to_return.append(received_ref_mask) # Show generated reference mask
405
+
406
+ # Build return list: generated_masks + diptych + final_image + received_masks
407
+ return_list = masks_to_return + [show_diptych_ref_tar, edited_image, received_tar_mask, received_ref_mask]
408
+ return return_list
409
 
410
 
411
  def update_ui(option):
 
429
  base_mask = gr.ImageEditor(label="Background Mask", sources="upload", type="pil",
430
  layers=False, brush=False, eraser=False)
431
  with gr.Row():
432
+ base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"], label="Background Mask Input Option",
433
  value="Upload with Mask")
434
+ with gr.Row():
435
+ base_text_prompt = gr.Textbox(label="Background Label",
436
+ placeholder="Enter the category to mask in background, e.g., sofa, table, person, etc.")
437
 
438
  with gr.Row():
439
  ref_image = gr.ImageEditor(label="Reference Image", sources="upload", type="pil",
 
446
  ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"],
447
  label="Reference Mask Input Option", value="Upload with Mask")
448
  with gr.Row():
449
+ ref_text_prompt = gr.Textbox(label="Reference Label",
450
+ placeholder="Enter the category of the reference object, e.g., car, dress, toy, etc.")
451
 
452
  with gr.Column(scale=1):
453
  baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=695, columns=1)
 
458
  gr.Markdown(" Draw Mask means manually drawing a mask on the original image.")
459
  gr.Markdown(" Upload with Mask means uploading a mask file.")
460
  gr.Markdown(" Label to Mask means simply inputting a label to automatically extract the mask and obtain the result.")
461
+ gr.Markdown(" Both background and reference images now support all three masking options including automatic mask generation from labels.")
462
 
463
  run_local_button = gr.Button(value="Run")
464
 
 
481
 
482
  run_local_button.click(
483
  fn=run_local,
484
+ inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option, base_text_prompt, ref_text_prompt],
485
  outputs=[baseline_gallery]
486
  )
487
  demo.launch()