Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -281,13 +281,16 @@ image_mask_list = sorted([os.path.join(image_mask_dir, f) for f in os.listdir(im
|
|
281 |
|
282 |
|
283 |
@spaces.GPU
|
284 |
-
def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option,
|
285 |
if base_mask_option == "Draw Mask":
|
286 |
tar_image = base_image["background"]
|
287 |
tar_mask = base_image["layers"][0]
|
288 |
-
|
289 |
tar_image = base_image["background"]
|
290 |
tar_mask = base_mask["background"]
|
|
|
|
|
|
|
291 |
|
292 |
if ref_mask_option == "Draw Mask":
|
293 |
ref_image = reference_image["background"]
|
@@ -295,9 +298,9 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
|
|
295 |
elif ref_mask_option == "Upload with Mask":
|
296 |
ref_image = reference_image["background"]
|
297 |
ref_mask = ref_mask["background"]
|
298 |
-
else:
|
299 |
ref_image = reference_image["background"]
|
300 |
-
ref_mask = get_mask(ref_image,
|
301 |
|
302 |
tar_image = tar_image.convert("RGB")
|
303 |
tar_mask = tar_mask.convert("L")
|
@@ -393,10 +396,16 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
|
|
393 |
edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
|
394 |
edited_image = Image.fromarray(edited_image)
|
395 |
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
|
401 |
|
402 |
def update_ui(option):
|
@@ -420,8 +429,11 @@ with gr.Blocks() as demo:
|
|
420 |
base_mask = gr.ImageEditor(label="Background Mask", sources="upload", type="pil",
|
421 |
layers=False, brush=False, eraser=False)
|
422 |
with gr.Row():
|
423 |
-
base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Background Mask Input Option",
|
424 |
value="Upload with Mask")
|
|
|
|
|
|
|
425 |
|
426 |
with gr.Row():
|
427 |
ref_image = gr.ImageEditor(label="Reference Image", sources="upload", type="pil",
|
@@ -434,8 +446,8 @@ with gr.Blocks() as demo:
|
|
434 |
ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"],
|
435 |
label="Reference Mask Input Option", value="Upload with Mask")
|
436 |
with gr.Row():
|
437 |
-
|
438 |
-
|
439 |
|
440 |
with gr.Column(scale=1):
|
441 |
baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=695, columns=1)
|
@@ -446,6 +458,7 @@ with gr.Blocks() as demo:
|
|
446 |
gr.Markdown(" Draw Mask means manually drawing a mask on the original image.")
|
447 |
gr.Markdown(" Upload with Mask means uploading a mask file.")
|
448 |
gr.Markdown(" Label to Mask means simply inputting a label to automatically extract the mask and obtain the result.")
|
|
|
449 |
|
450 |
run_local_button = gr.Button(value="Run")
|
451 |
|
@@ -468,7 +481,7 @@ with gr.Blocks() as demo:
|
|
468 |
|
469 |
run_local_button.click(
|
470 |
fn=run_local,
|
471 |
-
inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option,
|
472 |
outputs=[baseline_gallery]
|
473 |
)
|
474 |
demo.launch()
|
|
|
281 |
|
282 |
|
283 |
@spaces.GPU
|
284 |
+
def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option, base_text_prompt, ref_text_prompt):
|
285 |
if base_mask_option == "Draw Mask":
|
286 |
tar_image = base_image["background"]
|
287 |
tar_mask = base_image["layers"][0]
|
288 |
+
elif base_mask_option == "Upload with Mask":
|
289 |
tar_image = base_image["background"]
|
290 |
tar_mask = base_mask["background"]
|
291 |
+
else: # Label to Mask
|
292 |
+
tar_image = base_image["background"]
|
293 |
+
tar_mask = get_mask(tar_image, base_text_prompt)
|
294 |
|
295 |
if ref_mask_option == "Draw Mask":
|
296 |
ref_image = reference_image["background"]
|
|
|
298 |
elif ref_mask_option == "Upload with Mask":
|
299 |
ref_image = reference_image["background"]
|
300 |
ref_mask = ref_mask["background"]
|
301 |
+
else: # Label to Mask
|
302 |
ref_image = reference_image["background"]
|
303 |
+
ref_mask = get_mask(ref_image, ref_text_prompt)
|
304 |
|
305 |
tar_image = tar_image.convert("RGB")
|
306 |
tar_mask = tar_mask.convert("L")
|
|
|
396 |
edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
|
397 |
edited_image = Image.fromarray(edited_image)
|
398 |
|
399 |
+
# Determine which masks to show as "generated" in output
|
400 |
+
masks_to_return = []
|
401 |
+
if base_mask_option == "Label to Mask":
|
402 |
+
masks_to_return.append(received_tar_mask) # Show generated background mask
|
403 |
+
if ref_mask_option == "Label to Mask":
|
404 |
+
masks_to_return.append(received_ref_mask) # Show generated reference mask
|
405 |
+
|
406 |
+
# Build return list: generated_masks + diptych + final_image + received_masks
|
407 |
+
return_list = masks_to_return + [show_diptych_ref_tar, edited_image, received_tar_mask, received_ref_mask]
|
408 |
+
return return_list
|
409 |
|
410 |
|
411 |
def update_ui(option):
|
|
|
429 |
base_mask = gr.ImageEditor(label="Background Mask", sources="upload", type="pil",
|
430 |
layers=False, brush=False, eraser=False)
|
431 |
with gr.Row():
|
432 |
+
base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"], label="Background Mask Input Option",
|
433 |
value="Upload with Mask")
|
434 |
+
with gr.Row():
|
435 |
+
base_text_prompt = gr.Textbox(label="Background Label",
|
436 |
+
placeholder="Enter the category to mask in background, e.g., sofa, table, person, etc.")
|
437 |
|
438 |
with gr.Row():
|
439 |
ref_image = gr.ImageEditor(label="Reference Image", sources="upload", type="pil",
|
|
|
446 |
ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"],
|
447 |
label="Reference Mask Input Option", value="Upload with Mask")
|
448 |
with gr.Row():
|
449 |
+
ref_text_prompt = gr.Textbox(label="Reference Label",
|
450 |
+
placeholder="Enter the category of the reference object, e.g., car, dress, toy, etc.")
|
451 |
|
452 |
with gr.Column(scale=1):
|
453 |
baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=695, columns=1)
|
|
|
458 |
gr.Markdown(" Draw Mask means manually drawing a mask on the original image.")
|
459 |
gr.Markdown(" Upload with Mask means uploading a mask file.")
|
460 |
gr.Markdown(" Label to Mask means simply inputting a label to automatically extract the mask and obtain the result.")
|
461 |
+
gr.Markdown(" Both background and reference images now support all three masking options including automatic mask generation from labels.")
|
462 |
|
463 |
run_local_button = gr.Button(value="Run")
|
464 |
|
|
|
481 |
|
482 |
run_local_button.click(
|
483 |
fn=run_local,
|
484 |
+
inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option, base_text_prompt, ref_text_prompt],
|
485 |
outputs=[baseline_gallery]
|
486 |
)
|
487 |
demo.launch()
|