Spaces: [Major] Add list generation
app.py CHANGED
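
The commit adds a second generation path: a multi-line "Input List" textbox whose lines are each treated as one object description and applied to the image sequentially. A minimal sketch of how that textbox value is parsed, mirroring the first lines of generate_list in the diff below (the sample strings are illustrative, taken from the new dufu.png example row):

# Hypothetical sample value for the new "Input List" textbox: one edit per line.
list_input = "black and white suit\nsunglasses\nblue medical mask"

# generate_list splits on newlines and drops empty lines before editing sequentially.
instructions = [line for line in list_input.split('\n') if line]
print(instructions)  # ['black and white suit', 'sunglasses', 'blue medical mask']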
@@ -282,22 +282,138 @@ def generate(
 
     return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image_copy, mix_result_with_red_mask]
 
+@spaces.GPU(duration=30)
+def generate_list(
+    input_image: Image.Image,
+    generate_list: str,
+    steps: int,
+    randomize_seed: bool,
+    seed: int,
+    randomize_cfg: bool,
+    text_cfg_scale: float,
+    image_cfg_scale: float,
+    weather_close_video: bool,
+    decode_image_batch: int
+):
+    generate_list = generate_list.split('\n')
+    # Remove the empty element
+    generate_list = [element for element in generate_list if element]
+
+    seed = random.randint(0, 100000) if randomize_seed else seed
+    text_cfg_scale = round(random.uniform(6.0, 9.0), ndigits=2) if randomize_cfg else text_cfg_scale
+    image_cfg_scale = round(random.uniform(1.2, 1.8), ndigits=2) if randomize_cfg else image_cfg_scale
+
+    width, height = input_image.size
+    factor = args.resolution / max(width, height)
+    factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
+    width = int((width * factor) // 64) * 64
+    height = int((height * factor) // 64) * 64
+    input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
+
+    if len(generate_list) == 0:
+        return [input_image, seed]
+
+    model.cuda()
+    image_video = [np.array(input_image).astype(np.uint8)]
+    generate_index = 0
+    input_image_copy = input_image.convert("RGB")
+    while generate_index < len(generate_list):
+        print(f'generate_index: {str(generate_index)}')
+        instruction = generate_list[generate_index]
+        with torch.no_grad(), autocast("cuda"), model.ema_scope():
+            cond = {}
+            input_image_torch = 2 * torch.tensor(np.array(input_image_copy.copy())).float() / 255 - 1
+            input_image_torch = rearrange(input_image_torch, "h w c -> 1 c h w").to(model.device)
+            cond["c_crossattn"] = [model.get_learned_conditioning([instruction]).to(model.device)]
+            cond["c_concat"] = [model.encode_first_stage(input_image_torch).mode().to(model.device)]
+
+            uncond = {}
+            uncond["c_crossattn"] = [null_token.to(model.device)]
+            uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
+
+            sigmas = model_wrap.get_sigmas(steps).to(model.device)
+
+            extra_args = {
+                "cond": cond,
+                "uncond": uncond,
+                "text_cfg_scale": text_cfg_scale,
+                "image_cfg_scale": image_cfg_scale,
+            }
+            torch.manual_seed(seed)
+            z_0 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
+            z_1 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
+
+            z_0, z_1, _, _ = sample_euler_ancestral(model_wrap_cfg, z_0, z_1, sigmas, height, width, extra_args=extra_args)
+
+            x_0 = model.decode_first_stage(z_0)
+
+            x_1 = nn.functional.interpolate(z_1, size=(height, width), mode="bilinear", align_corners=False)
+            x_1 = torch.where(x_1 > 0, 1, -1) # Thresholding step
+
+            if torch.sum(x_1).item()/x_1.numel() < -0.99:
+                seed += 1
+                continue
+            else:
+                generate_index += 1
+
+            x_0 = torch.clamp((x_0 + 1.0) / 2.0, min=0.0, max=1.0)
+            x_1 = torch.clamp((x_1 + 1.0) / 2.0, min=0.0, max=1.0)
+            x_0 = 255.0 * rearrange(x_0, "1 c h w -> h w c")
+            x_1 = 255.0 * rearrange(x_1, "1 c h w -> h w c")
+            x_1 = torch.cat([x_1, x_1, x_1], dim=-1)
+            edited_image = Image.fromarray(x_0.type(torch.uint8).cpu().numpy())
+            edited_mask = Image.fromarray(x_1.type(torch.uint8).cpu().numpy())
+
+            # Dilate edited_mask
+            edited_mask_copy = edited_mask.copy()
+            kernel = np.ones((3, 3), np.uint8)
+            edited_mask = cv2.dilate(np.array(edited_mask), kernel, iterations=3)
+            edited_mask = Image.fromarray(edited_mask)
+
+            m_img = edited_mask.filter(ImageFilter.GaussianBlur(radius=3))
+            m_img = np.asarray(m_img).astype('float') / 255.0
+            img_np = np.asarray(input_image_copy).astype('float') / 255.0
+            ours_np = np.asarray(edited_image).astype('float') / 255.0
+
+            mix_image_np = m_img * ours_np + (1 - m_img) * img_np
+
+            image_video.append((mix_image_np * 255).astype(np.uint8))
+            mix_image = Image.fromarray((mix_image_np * 255).astype(np.uint8)).convert('RGB')
+            input_image_copy = mix_image
+
+    mix_result_with_red_mask = None
+    mask_video_path = None
+    edited_mask_copy = None
+
+    image_video_path = "image.mp4"
+    fps = 2
+    with imageio.get_writer(image_video_path, fps=fps) as video:
+        for image in image_video:
+            video.append_data(image)
+
+
+    return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image, mix_result_with_red_mask]
+
+
 def reset():
     return [100, "Randomize Seed", 1372, "Fix CFG", 7.5, 1.5, None, None, None, None, None, None, None, "Close Image Video", 10]
 
+
 def get_example():
     return [
-        ["example_images/dufu.png", "black and white suit", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
-        ["example_images/girl.jpeg", "reflective sunglasses", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
-        ["example_images/
-        ["example_images/
-        ["example_images/
-        ["example_images/
-        ["example_images/
-        ["example_images/girl.jpeg", "
-        ["example_images/
-        ["example_images/girl.jpeg", "
-        ["example_images/
+        ["example_images/dufu.png", "", "black and white suit\nsunglasses\nblue medical mask\nyellow schoolbag\nred bow tie\nbrown high-top hat", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
+        ["example_images/girl.jpeg", "reflective sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
+        ["example_images/dufu.png", "black and white suit", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
+        ["example_images/girl.jpeg", "reflective sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
+        ["example_images/road_sign.png", "stop sign", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
+        ["example_images/dufu.png", "blue medical mask", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
+        ["example_images/people_standing.png", "dark green pleated skirt", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
+        ["example_images/girl.jpeg", "shiny golden crown", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
+        ["example_images/dufu.png", "sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
+        ["example_images/girl.jpeg", "diamond necklace", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
+        ["example_images/iron_man.jpg", "sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
+        ["example_images/girl.jpeg", "the queen's crown", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
+        ["example_images/girl.jpeg", "gorgeous yellow gown", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
     ]
 
 with gr.Blocks(css="footer {visibility: hidden}") as demo:

@@ -325,7 +441,14 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
             with gr.Row():
                 input_image = gr.Image(label="Input Image", type="pil", interactive=True)
             with gr.Row():
-                instruction = gr.Textbox(lines=1, label="
+                instruction = gr.Textbox(lines=1, label="Single object description", interactive=True)
+            with gr.Row():
+                reset_button = gr.Button("Reset")
+                generate_button = gr.Button("Generate")
+            with gr.Row():
+                list_input = gr.Textbox(label="Input List", placeholder="Enter one item per line", lines=10)
+            with gr.Row():
+                list_generate_button = gr.Button("List Generate")
             with gr.Row():
                 steps = gr.Number(value=100, precision=0, label="Steps", interactive=True)
                 randomize_seed = gr.Radio(

@@ -347,17 +470,13 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
                 )
                 text_cfg_scale = gr.Number(value=7.5, label=f"Text CFG", interactive=True)
                 image_cfg_scale = gr.Number(value=1.5, label=f"Image CFG", interactive=True)
-            with gr.Row():
-                reset_button = gr.Button("Reset")
-                generate_button = gr.Button("Generate")
         with gr.Column(scale=1, min_width=100):
             with gr.Column():
                 mix_image = gr.Image(label=f"Mix Image", type="pil", interactive=False)
             with gr.Column():
                 edited_mask = gr.Image(label=f"Output Mask", type="pil", interactive=False)
 
-
-    with gr.Accordion('More outputs', open=False):
+    with gr.Accordion('Click to see more (includes generation process per object for list generation and per step for single generation)', open=False):
         with gr.Row():
             weather_close_video = gr.Radio(
                 ["Show Image Video", "Close Image Video"],

@@ -374,15 +493,11 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
             original_image = gr.Image(label=f"Original Image", type="pil", interactive=False)
             edited_image = gr.Image(label=f"Output Image", type="pil", interactive=False)
             mix_result_with_red_mask = gr.Image(label=f"Mix Image With Red Mask", type="pil", interactive=False)
-
 
     with gr.Row():
         gr.Examples(
            examples=get_example(),
-
-            inputs=[input_image, instruction, steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale],
-            outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
-            cache_examples=False,
+            inputs=[input_image, instruction, list_input, steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale],
         )
 
     generate_button.click(

@@ -401,6 +516,24 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
         ],
         outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
     )
+
+    list_generate_button.click(
+        fn=generate_list,
+        inputs=[
+            input_image,
+            list_input,
+            steps,
+            randomize_seed,
+            seed,
+            randomize_cfg,
+            text_cfg_scale,
+            image_cfg_scale,
+            weather_close_video,
+            decode_image_batch
+        ],
+        outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
+    )
+
     reset_button.click(
         fn=reset,
         inputs=[],
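
One behavior in generate_list worth noting when reading the first hunk: after sampling, the mask latent z_1 is upsampled and thresholded to -1/+1, and an edit is only accepted when the mask is not essentially empty; otherwise the seed is bumped and the same instruction is retried. A standalone sketch of that acceptance check (NumPy is used here purely for illustration; the app applies the same test to the torch tensor x_1):

import numpy as np

def mask_is_empty(mask_pm1: np.ndarray, tol: float = -0.99) -> bool:
    # mask_pm1 holds -1 (background) and +1 (edited region); a mean below -0.99
    # means fewer than about 0.5% of pixels were selected, i.e. the edit missed.
    return float(mask_pm1.mean()) < tol

mask = np.full((64, 64), -1.0)
mask[0, 0] = 1.0                   # only one selected pixel
print(mask_is_empty(mask))         # True -> generate_list does seed += 1 and retries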