Spaces:
Runtime error
Runtime error
Update app.py
Browse filesUpdated the app to accept mask generation using prompt for both original and reference images
app.py
CHANGED
|
@@ -281,13 +281,16 @@ image_mask_list = sorted([os.path.join(image_mask_dir, f) for f in os.listdir(im
|
|
| 281 |
|
| 282 |
|
| 283 |
@spaces.GPU
|
| 284 |
-
def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option,
|
| 285 |
if base_mask_option == "Draw Mask":
|
| 286 |
tar_image = base_image["background"]
|
| 287 |
tar_mask = base_image["layers"][0]
|
| 288 |
-
|
| 289 |
tar_image = base_image["background"]
|
| 290 |
tar_mask = base_mask["background"]
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
if ref_mask_option == "Draw Mask":
|
| 293 |
ref_image = reference_image["background"]
|
|
@@ -295,9 +298,9 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
|
|
| 295 |
elif ref_mask_option == "Upload with Mask":
|
| 296 |
ref_image = reference_image["background"]
|
| 297 |
ref_mask = ref_mask["background"]
|
| 298 |
-
else:
|
| 299 |
ref_image = reference_image["background"]
|
| 300 |
-
ref_mask = get_mask(ref_image,
|
| 301 |
|
| 302 |
tar_image = tar_image.convert("RGB")
|
| 303 |
tar_mask = tar_mask.convert("L")
|
|
@@ -393,10 +396,16 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
|
|
| 393 |
edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
|
| 394 |
edited_image = Image.fromarray(edited_image)
|
| 395 |
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
|
| 401 |
|
| 402 |
def update_ui(option):
|
|
@@ -409,6 +418,7 @@ def update_ui(option):
|
|
| 409 |
with gr.Blocks() as demo:
|
| 410 |
gr.Markdown("# Insert-Anything")
|
| 411 |
gr.Markdown("### Make sure to select the correct mask button!!")
|
|
|
|
| 412 |
gr.Markdown("### Click the output image to toggle between Diptych and final results!!")
|
| 413 |
|
| 414 |
with gr.Row():
|
|
@@ -420,8 +430,11 @@ with gr.Blocks() as demo:
|
|
| 420 |
base_mask = gr.ImageEditor(label="Background Mask", sources="upload", type="pil",
|
| 421 |
layers=False, brush=False, eraser=False)
|
| 422 |
with gr.Row():
|
| 423 |
-
base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Background Mask Input Option",
|
| 424 |
value="Upload with Mask")
|
|
|
|
|
|
|
|
|
|
| 425 |
|
| 426 |
with gr.Row():
|
| 427 |
ref_image = gr.ImageEditor(label="Reference Image", sources="upload", type="pil",
|
|
@@ -434,8 +447,8 @@ with gr.Blocks() as demo:
|
|
| 434 |
ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"],
|
| 435 |
label="Reference Mask Input Option", value="Upload with Mask")
|
| 436 |
with gr.Row():
|
| 437 |
-
|
| 438 |
-
|
| 439 |
|
| 440 |
with gr.Column(scale=1):
|
| 441 |
baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=695, columns=1)
|
|
@@ -446,6 +459,7 @@ with gr.Blocks() as demo:
|
|
| 446 |
gr.Markdown(" Draw Mask means manually drawing a mask on the original image.")
|
| 447 |
gr.Markdown(" Upload with Mask means uploading a mask file.")
|
| 448 |
gr.Markdown(" Label to Mask means simply inputting a label to automatically extract the mask and obtain the result.")
|
|
|
|
| 449 |
|
| 450 |
run_local_button = gr.Button(value="Run")
|
| 451 |
|
|
@@ -468,7 +482,7 @@ with gr.Blocks() as demo:
|
|
| 468 |
|
| 469 |
run_local_button.click(
|
| 470 |
fn=run_local,
|
| 471 |
-
inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option,
|
| 472 |
outputs=[baseline_gallery]
|
| 473 |
)
|
| 474 |
demo.launch()
|
|
|
|
| 281 |
|
| 282 |
|
| 283 |
@spaces.GPU
|
| 284 |
+
def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option, base_text_prompt, ref_text_prompt):
|
| 285 |
if base_mask_option == "Draw Mask":
|
| 286 |
tar_image = base_image["background"]
|
| 287 |
tar_mask = base_image["layers"][0]
|
| 288 |
+
elif base_mask_option == "Upload with Mask":
|
| 289 |
tar_image = base_image["background"]
|
| 290 |
tar_mask = base_mask["background"]
|
| 291 |
+
else: # Label to Mask
|
| 292 |
+
tar_image = base_image["background"]
|
| 293 |
+
tar_mask = get_mask(tar_image, base_text_prompt)
|
| 294 |
|
| 295 |
if ref_mask_option == "Draw Mask":
|
| 296 |
ref_image = reference_image["background"]
|
|
|
|
| 298 |
elif ref_mask_option == "Upload with Mask":
|
| 299 |
ref_image = reference_image["background"]
|
| 300 |
ref_mask = ref_mask["background"]
|
| 301 |
+
else: # Label to Mask
|
| 302 |
ref_image = reference_image["background"]
|
| 303 |
+
ref_mask = get_mask(ref_image, ref_text_prompt)
|
| 304 |
|
| 305 |
tar_image = tar_image.convert("RGB")
|
| 306 |
tar_mask = tar_mask.convert("L")
|
|
|
|
| 396 |
edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
|
| 397 |
edited_image = Image.fromarray(edited_image)
|
| 398 |
|
| 399 |
+
# Determine which masks to show as "generated" in output
|
| 400 |
+
masks_to_return = []
|
| 401 |
+
if base_mask_option == "Label to Mask":
|
| 402 |
+
masks_to_return.append(received_tar_mask) # Show generated background mask
|
| 403 |
+
if ref_mask_option == "Label to Mask":
|
| 404 |
+
masks_to_return.append(received_ref_mask) # Show generated reference mask
|
| 405 |
+
|
| 406 |
+
# Build return list: generated_masks + diptych + final_image + received_masks
|
| 407 |
+
return_list = masks_to_return + [show_diptych_ref_tar, edited_image, received_tar_mask, received_ref_mask]
|
| 408 |
+
return return_list
|
| 409 |
|
| 410 |
|
| 411 |
def update_ui(option):
|
|
|
|
| 418 |
with gr.Blocks() as demo:
|
| 419 |
gr.Markdown("# Insert-Anything")
|
| 420 |
gr.Markdown("### Make sure to select the correct mask button!!")
|
| 421 |
+
gr.Markdown("### Both background and reference images support automatic mask generation from text labels!!")
|
| 422 |
gr.Markdown("### Click the output image to toggle between Diptych and final results!!")
|
| 423 |
|
| 424 |
with gr.Row():
|
|
|
|
| 430 |
base_mask = gr.ImageEditor(label="Background Mask", sources="upload", type="pil",
|
| 431 |
layers=False, brush=False, eraser=False)
|
| 432 |
with gr.Row():
|
| 433 |
+
base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"], label="Background Mask Input Option",
|
| 434 |
value="Upload with Mask")
|
| 435 |
+
with gr.Row():
|
| 436 |
+
base_text_prompt = gr.Textbox(label="Background Label",
|
| 437 |
+
placeholder="Enter the category to mask in background, e.g., sofa, table, person, etc.")
|
| 438 |
|
| 439 |
with gr.Row():
|
| 440 |
ref_image = gr.ImageEditor(label="Reference Image", sources="upload", type="pil",
|
|
|
|
| 447 |
ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"],
|
| 448 |
label="Reference Mask Input Option", value="Upload with Mask")
|
| 449 |
with gr.Row():
|
| 450 |
+
ref_text_prompt = gr.Textbox(label="Reference Label",
|
| 451 |
+
placeholder="Enter the category of the reference object, e.g., car, dress, toy, etc.")
|
| 452 |
|
| 453 |
with gr.Column(scale=1):
|
| 454 |
baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=695, columns=1)
|
|
|
|
| 459 |
gr.Markdown(" Draw Mask means manually drawing a mask on the original image.")
|
| 460 |
gr.Markdown(" Upload with Mask means uploading a mask file.")
|
| 461 |
gr.Markdown(" Label to Mask means simply inputting a label to automatically extract the mask and obtain the result.")
|
| 462 |
+
gr.Markdown(" Both background and reference images now support all three masking options including automatic mask generation from labels.")
|
| 463 |
|
| 464 |
run_local_button = gr.Button(value="Run")
|
| 465 |
|
|
|
|
| 482 |
|
| 483 |
run_local_button.click(
|
| 484 |
fn=run_local,
|
| 485 |
+
inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option, base_text_prompt, ref_text_prompt],
|
| 486 |
outputs=[baseline_gallery]
|
| 487 |
)
|
| 488 |
demo.launch()
|