isat commited on
Commit
cf46bb8
·
verified ·
1 Parent(s): a9b6162

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -25
app.py CHANGED
@@ -281,16 +281,13 @@ image_mask_list = sorted([os.path.join(image_mask_dir, f) for f in os.listdir(im
281
 
282
 
283
  @spaces.GPU
284
- def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option, base_text_prompt, ref_text_prompt):
285
  if base_mask_option == "Draw Mask":
286
  tar_image = base_image["background"]
287
  tar_mask = base_image["layers"][0]
288
- elif base_mask_option == "Upload with Mask":
289
  tar_image = base_image["background"]
290
  tar_mask = base_mask["background"]
291
- else: # Label to Mask
292
- tar_image = base_image["background"]
293
- tar_mask = get_mask(tar_image, base_text_prompt)
294
 
295
  if ref_mask_option == "Draw Mask":
296
  ref_image = reference_image["background"]
@@ -298,9 +295,9 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
298
  elif ref_mask_option == "Upload with Mask":
299
  ref_image = reference_image["background"]
300
  ref_mask = ref_mask["background"]
301
- else: # Label to Mask
302
  ref_image = reference_image["background"]
303
- ref_mask = get_mask(ref_image, ref_text_prompt)
304
 
305
  tar_image = tar_image.convert("RGB")
306
  tar_mask = tar_mask.convert("L")
@@ -396,16 +393,10 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
396
  edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
397
  edited_image = Image.fromarray(edited_image)
398
 
399
- # Determine which masks to show as "generated" in output
400
- masks_to_return = []
401
- if base_mask_option == "Label to Mask":
402
- masks_to_return.append(received_tar_mask) # Show generated background mask
403
- if ref_mask_option == "Label to Mask":
404
- masks_to_return.append(received_ref_mask) # Show generated reference mask
405
-
406
- # Build return list: generated_masks + diptych + final_image + received_masks
407
- return_list = masks_to_return + [show_diptych_ref_tar, edited_image, received_tar_mask, received_ref_mask]
408
- return return_list
409
 
410
 
411
  def update_ui(option):
@@ -429,11 +420,8 @@ with gr.Blocks() as demo:
429
  base_mask = gr.ImageEditor(label="Background Mask", sources="upload", type="pil",
430
  layers=False, brush=False, eraser=False)
431
  with gr.Row():
432
- base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"], label="Background Mask Input Option",
433
  value="Upload with Mask")
434
- with gr.Row():
435
- base_text_prompt = gr.Textbox(label="Background Label",
436
- placeholder="Enter the category to mask in background, e.g., sofa, table, person, etc.")
437
 
438
  with gr.Row():
439
  ref_image = gr.ImageEditor(label="Reference Image", sources="upload", type="pil",
@@ -446,8 +434,8 @@ with gr.Blocks() as demo:
446
  ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"],
447
  label="Reference Mask Input Option", value="Upload with Mask")
448
  with gr.Row():
449
- ref_text_prompt = gr.Textbox(label="Reference Label",
450
- placeholder="Enter the category of the reference object, e.g., car, dress, toy, etc.")
451
 
452
  with gr.Column(scale=1):
453
  baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=695, columns=1)
@@ -458,7 +446,6 @@ with gr.Blocks() as demo:
458
  gr.Markdown(" Draw Mask means manually drawing a mask on the original image.")
459
  gr.Markdown(" Upload with Mask means uploading a mask file.")
460
  gr.Markdown(" Label to Mask means simply inputting a label to automatically extract the mask and obtain the result.")
461
- gr.Markdown(" Both background and reference images now support all three masking options including automatic mask generation from labels.")
462
 
463
  run_local_button = gr.Button(value="Run")
464
 
@@ -481,7 +468,7 @@ with gr.Blocks() as demo:
481
 
482
  run_local_button.click(
483
  fn=run_local,
484
- inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option, base_text_prompt, ref_text_prompt],
485
  outputs=[baseline_gallery]
486
  )
487
  demo.launch()
 
281
 
282
 
283
  @spaces.GPU
284
+ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt):
285
  if base_mask_option == "Draw Mask":
286
  tar_image = base_image["background"]
287
  tar_mask = base_image["layers"][0]
288
+ else:
289
  tar_image = base_image["background"]
290
  tar_mask = base_mask["background"]
 
 
 
291
 
292
  if ref_mask_option == "Draw Mask":
293
  ref_image = reference_image["background"]
 
295
  elif ref_mask_option == "Upload with Mask":
296
  ref_image = reference_image["background"]
297
  ref_mask = ref_mask["background"]
298
+ else:
299
  ref_image = reference_image["background"]
300
+ ref_mask = get_mask(ref_image, text_prompt)
301
 
302
  tar_image = tar_image.convert("RGB")
303
  tar_mask = tar_mask.convert("L")
 
393
  edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
394
  edited_image = Image.fromarray(edited_image)
395
 
396
+ if ref_mask_option != "Label to Mask":
397
+ return [show_diptych_ref_tar, edited_image, received_tar_mask, received_ref_mask]
398
+ else:
399
+ return [return_ref_mask, show_diptych_ref_tar, edited_image, received_tar_mask, received_ref_mask]
 
 
 
 
 
 
400
 
401
 
402
  def update_ui(option):
 
420
  base_mask = gr.ImageEditor(label="Background Mask", sources="upload", type="pil",
421
  layers=False, brush=False, eraser=False)
422
  with gr.Row():
423
+ base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Background Mask Input Option",
424
  value="Upload with Mask")
 
 
 
425
 
426
  with gr.Row():
427
  ref_image = gr.ImageEditor(label="Reference Image", sources="upload", type="pil",
 
434
  ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"],
435
  label="Reference Mask Input Option", value="Upload with Mask")
436
  with gr.Row():
437
+ text_prompt = gr.Textbox(label="Label",
438
+ placeholder="Enter the category of the reference object, e.g., car, dress, toy, etc.")
439
 
440
  with gr.Column(scale=1):
441
  baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=695, columns=1)
 
446
  gr.Markdown(" Draw Mask means manually drawing a mask on the original image.")
447
  gr.Markdown(" Upload with Mask means uploading a mask file.")
448
  gr.Markdown(" Label to Mask means simply inputting a label to automatically extract the mask and obtain the result.")
 
449
 
450
  run_local_button = gr.Button(value="Run")
451
 
 
468
 
469
  run_local_button.click(
470
  fn=run_local,
471
+ inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt],
472
  outputs=[baseline_gallery]
473
  )
474
  demo.launch()