Spaces: Running on Zero

Commit: test

Files changed:
- __pycache__/dreamfuse_inference.cpython-310.pyc +0 -0
- app.py +25 -35
- dreamfuse/models/dreamfuse_flux/__pycache__/flux_processor.cpython-310.pyc +0 -0
- dreamfuse/models/dreamfuse_flux/__pycache__/transformer.cpython-310.pyc +0 -0
- dreamfuse/trains/utils/__pycache__/inference_utils.cpython-310.pyc +0 -0
- dreamfuse_inference.py +10 -83
__pycache__/dreamfuse_inference.cpython-310.pyc CHANGED

Binary files a/__pycache__/dreamfuse_inference.cpython-310.pyc and b/__pycache__/dreamfuse_inference.cpython-310.pyc differ
app.py CHANGED

@@ -400,36 +400,27 @@ class DreamblendGUI:
             canvas_size=400
         ), draggable_img
 
-    def save_image(self, save_path = "/mnt/bn/hjj-humanseg-lq/SubjectDriven/DreamFuse/debug"):
-        global generated_images
-        save_name = self.get_next_sequence(save_path)
-        generated_images[0].save(os.path.join(save_path, f"{save_name}_0_ori.png"))
-        generated_images[1].save(os.path.join(save_path, f"{save_name}_0.png"))
-        generated_images[2].save(os.path.join(save_path, f"{save_name}_1.png"))
-        generated_images[3].save(os.path.join(save_path, f"{save_name}_2.png"))
-        generated_images[4].save(os.path.join(save_path, f"{save_name}_0_mask.png"))
-        generated_images[5].save(os.path.join(save_path, f"{save_name}_0_mask_scale.png"))
-        generated_images[6].save(os.path.join(save_path, f"{save_name}_0_scale.png"))
-        generated_images[7].save(os.path.join(save_path, f"{save_name}_2_pasted.png"))
-
 
     def create_gui(self):
         config = InferenceConfig()
         config.lora_id = 'LL3RD/DreamFuse'
 
-        pipeline =
-        pipeline
+        pipeline = None
+        # pipeline = DreamFuseInference(config)
+        # pipeline.gradio_generate = spaces.GPU(duratioin=120)(pipeline.gradio_generate)
         """创建 Gradio 界面"""
         with gr.Blocks(css=self.css_style) as demo:
             modified_fg_state = gr.State()
-            gr.Markdown("#
-            gr.Markdown("
+            gr.Markdown("# DreamFuse: 3 Easy Steps to Create Your Fusion Image")
+            gr.Markdown("1. Upload the foreground and background images you want to blend.")
+            gr.Markdown("2. Click 'Generate Canvas' to preview the result. You can then drag and resize the foreground object to position it as you like.")
+            gr.Markdown("3. Click 'Run Model' to create the final fused image.")
             with gr.Row():
                 with gr.Column(scale=1):
-                    gr.Markdown("###
-                    background_img_in = gr.Image(label="
-                    draggable_img_in = gr.Image(label="
-                    generate_btn = gr.Button("
+                    gr.Markdown("### FG&BG Image Upload")
+                    background_img_in = gr.Image(label="Background Image", type="pil", height=240, width=240)
+                    draggable_img_in = gr.Image(label="Foreground Image", type="pil", image_mode="RGBA", height=240, width=240)
+                    generate_btn = gr.Button("Generate Canvas")
 
                     with gr.Row():
                         gr.Examples(
@@ -438,39 +429,38 @@
                            elem_id="small-examples"
                        )
                with gr.Column(scale=1):
-                    gr.Markdown("###
-                    html_out = gr.HTML(label="
+                    gr.Markdown("### Preview Region")
+                    html_out = gr.HTML(label="drag and resize", elem_id="canvas_preview")
 
            with gr.Row():
                with gr.Column(scale=1):
-                    gr.Markdown("###
-                    seed_slider = gr.Slider(minimum
+                    gr.Markdown("### Parameters")
+                    seed_slider = gr.Slider(minimum=-1, maximum=100000, step=1, label="Seed", value=12345)
                    cfg_slider = gr.Slider(minimum=1, maximum=10, step=0.1, label="CFG", value=3.5)
                    size_select = gr.Radio(
                        choices=["512", "768", "1024"],
                        value="512",
                        label="生成质量(512-差 1024-好)",
                    )
-                    prompt_text = gr.Textbox(label="Prompt", placeholder="
-                    text_strength = gr.Slider(minimum=1, maximum=10, step=1, label="Text Strength", value=1)
-                    enable_gui = gr.Checkbox(label="启用GUI", value=True)
-                    enable_truecfg = gr.Checkbox(label="
-                    enable_save = gr.Button("保存图片 (内部测试)", visible=True)
+                    prompt_text = gr.Textbox(label="Prompt", placeholder="text prompt", value="")
+                    text_strength = gr.Slider(minimum=1, maximum=10, step=1, label="Text Strength", value=1, visible=False)
+                    enable_gui = gr.Checkbox(label="启用GUI", value=True, visible=False)
+                    enable_truecfg = gr.Checkbox(label="TrueCFG", value=False, visible=False)
                with gr.Column(scale=1):
-                    gr.Markdown("###
-                    model_generate_btn = gr.Button("
+                    gr.Markdown("### Model Result")
+                    model_generate_btn = gr.Button("Run Model")
                    transformation_text = gr.Textbox(label="Transformation Info", elem_id="transformation_info", visible=False)
-                    model_output = gr.Image(label="
+                    model_output = gr.Image(label="Model Output", type="pil")
 
-
-            enable_save.click(fn=self.save_image, inputs=None, outputs=None)
+
            generate_btn.click(
                fn=self.on_upload,
                inputs=[background_img_in, draggable_img_in],
                outputs=[html_out, modified_fg_state],
            )
            model_generate_btn.click(
-                fn=pipeline.gradio_generate,
+                # fn=pipeline.gradio_generate,
+                fn=self.pil_to_base64,
                inputs=[background_img_in, modified_fg_state, transformation_text, seed_slider, \
                    prompt_text, enable_gui, cfg_slider, size_select, text_strength, enable_truecfg],
                outputs=model_output
dreamfuse/models/dreamfuse_flux/__pycache__/flux_processor.cpython-310.pyc CHANGED

Binary files a/dreamfuse/models/dreamfuse_flux/__pycache__/flux_processor.cpython-310.pyc and b/dreamfuse/models/dreamfuse_flux/__pycache__/flux_processor.cpython-310.pyc differ
dreamfuse/models/dreamfuse_flux/__pycache__/transformer.cpython-310.pyc CHANGED

Binary files a/dreamfuse/models/dreamfuse_flux/__pycache__/transformer.cpython-310.pyc and b/dreamfuse/models/dreamfuse_flux/__pycache__/transformer.cpython-310.pyc differ
dreamfuse/trains/utils/__pycache__/inference_utils.cpython-310.pyc CHANGED

Binary files a/dreamfuse/trains/utils/__pycache__/inference_utils.cpython-310.pyc and b/dreamfuse/trains/utils/__pycache__/inference_utils.cpython-310.pyc differ
dreamfuse_inference.py CHANGED

@@ -168,7 +168,6 @@ def make_image_grid(images, rows, cols, size=None):
 class DreamFuseInference:
     def __init__(self, config: InferenceConfig):
         self.config = config
-        print(config.device)
         self.device = torch.device(config.device)
         torch.backends.cuda.matmul.allow_tf32 = True
         seed_everything(config.seed)
@@ -348,16 +347,15 @@
 
     @torch.inference_mode()
     def gradio_generate(self, background_img, foreground_img, transformation_info, seed, prompt, enable_gui, cfg=3.5, size_select="1024", text_strength=1, truecfg=False):
-        print("!"*10)
-        """使用 DreamFuseInference 进行模型推理"""
         try:
             trans = json.loads(transformation_info)
         except:
             trans = {}
 
         size_select = int(size_select)
+        if size_select == 1024: text_strength = 5
+        if size_select == 768: text_strength = 3
 
-        # import pdb; pdb.set_trace()
         r, g, b, ori_a = foreground_img.split()
         fg_img_scale, fg_img = self.transform_foreground_original(foreground_img, background_img, trans)
 
@@ -370,9 +368,7 @@
         ori_a = ori_a.convert("L")
         new_a = new_a.convert("L")
         foreground_img.paste((255, 255, 255), mask=ImageOps.invert(ori_a))
-
-        print(foreground_img.size)
-        print(background_img.size)
+
         images = self.model_generate(foreground_img.copy(), background_img.copy(),
                                      ori_a, new_a,
                                      enable_mask_affine=enable_gui,
@@ -386,16 +382,15 @@
         images = Image.fromarray(images[0], "RGB")
 
         images = images.resize(background_img.size)
-        images_save = images.copy()
+        # images_save = images.copy()
 
-        images.thumbnail((640, 640), Image.LANCZOS)
+        # images.thumbnail((640, 640), Image.LANCZOS)
         return images
 
 
     @torch.inference_mode()
     def model_generate(self, fg_image, bg_image, ori_fg_mask, new_fg_mask, enable_mask_affine=True, prompt="", offset_cond=None, seed=None, cfg=3.5, size_select=1024, text_strength=1, truecfg=False):
         batch_size = 1
-        print("-3"*10)
         # Prepare images
         # adjust bg->fg size
         fg_image, ori_fg_mask = adjust_fg_to_bg(fg_image, ori_fg_mask, bg_image.size)
@@ -410,7 +405,6 @@
         new_fg_mask = new_fg_mask.resize(bucket_size)
         mask_affine = get_mask_affine(new_fg_mask, ori_fg_mask)
 
-        print("-2"*10)
         # Get embeddings
         prompt_embeds, pooled_prompt_embeds, text_ids = self._compute_text_embeddings(prompt)
 
@@ -428,7 +422,6 @@
         if seed is None:
             seed = self.config.seed
         generator = torch.Generator(device=self.device).manual_seed(seed)
-        print("-1"*10)
         # Prepare condition latents
         condition_image_latents = self._encode_images([fg_image, bg_image])
 
@@ -445,7 +438,6 @@
             )
         )
 
-        print(1)
         if mask_affine is not None:
             affine_H, affine_W = condition_image_latents.shape[2] // 2, condition_image_latents.shape[3] // 2
             scale_factor = 1 / 16
@@ -457,7 +449,7 @@
                 scale_factor=scale_factor, device=self.device,
             )
             cond_latent_image_ids = torch.stack(cond_latent_image_ids)
-
+
         # Pack condition latents
         cond_image_latents = self._pack_latents(condition_image_latents)
         cond_input = {
@@ -470,7 +462,7 @@
         latents, latent_image_ids = self._prepare_latents(
             batch_size, num_channels_latents, height, width, generator
         )
-
+
         # Setup timesteps
         sigmas = np.linspace(1.0, 1 / self.config.num_inference_steps, self.config.num_inference_steps)
         image_seq_len = latents.shape[1]
@@ -488,7 +480,7 @@
             sigmas=sigmas,
             mu=mu,
         )
-
+
         # Denoising loop
         for i, t in enumerate(timesteps):
             timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -537,12 +529,12 @@
 
         # Compute previous noisy sample
         latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
+
         # Decode latents
         latents = self._unpack_latents(latents, height, width)
         latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
         images = self.vae.decode(latents, return_dict=False)[0]
-
+
         # Post-process images
         images = images.add(1).mul(127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
         return images
@@ -575,68 +567,3 @@
             offset=None
         )
         return latents, latent_image_ids
-
-def main():
-    parser = transformers.HfArgumentParser(InferenceConfig)
-    config: InferenceConfig = parser.parse_args_into_dataclasses()[0]
-    model = DreamFuseInference(config)
-    os.makedirs(config.valid_output_dir, exist_ok=True)
-    for valid_root, valid_json in zip(config.valid_roots, config.valid_jsons):
-        with open(valid_json, 'r') as f:
-            valid_info = json.load(f)
-
-        # multi gpu
-        to_process = sorted(list(valid_info.keys()))
-
-        # debug
-        to_process = [k for k in to_process if "data_wear" in k and "pixelwave" in k]
-        # debug
-
-        sd_idx = len(to_process) // config.total_num * config.sub_idx
-        ed_idx = len(to_process) // config.total_num * (config.sub_idx + 1)
-        if config.sub_idx < config.total_num - 1:
-            print(config.sub_idx, sd_idx, ed_idx)
-            to_process = to_process[sd_idx:ed_idx]
-        else:
-            print(config.sub_idx, sd_idx)
-            to_process = to_process[sd_idx:]
-        valid_info = {k: valid_info[k] for k in to_process}
-
-        for meta_key, info in tqdm(valid_info.items()):
-            img_name = meta_key.split('/')[-1]
-
-            foreground_img = Image.open(os.path.join(valid_root, info['img_info']['000']))
-            background_img = Image.open(os.path.join(valid_root, info['img_info']['001']))
-
-            new_fg_mask = Image.open(os.path.join(valid_root, info['img_mask_info']['000_mask_scale']))
-            ori_fg_mask = Image.open(os.path.join(valid_root, info['img_mask_info']['000']))
-
-            # debug
-            foreground_img.save(os.path.join(config.valid_output_dir, f"{img_name}_0.png"))
-            background_img.save(os.path.join(config.valid_output_dir, f"{img_name}_1.png"))
-            ori_fg_mask.save(os.path.join(config.valid_output_dir, f"{img_name}_0_mask.png"))
-            new_fg_mask.save(os.path.join(config.valid_output_dir, f"{img_name}_0_mask_scale.png"))
-            # debug
-
-            foreground_img.paste((255, 255, 255), mask=ImageOps.invert(ori_fg_mask))
-
-            images = model(foreground_img.copy(), background_img.copy(),
-                           ori_fg_mask, new_fg_mask,
-                           prompt=config.ref_prompts,
-                           seed=config.seed,
-                           cfg=config.guidance_scale,
-                           size_select=config.inference_scale,
-                           text_strength=config.text_strength,
-                           truecfg=config.truecfg)
-
-            result_image = Image.fromarray(images[0], "RGB")
-            result_image = result_image.resize(background_img.size)
-            result_image.save(os.path.join(config.valid_output_dir, f"{img_name}_2.png"))
-            # Make grid
-            grid_image = [foreground_img, background_img] + [result_image]
-            result = make_image_grid(grid_image, 1, len(grid_image), size=result_image.size)
-
-            result.save(os.path.join(config.valid_output_dir, f"{img_name}.jpg"))
-
-if __name__ == "__main__":
-    main()
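
The dreamfuse_inference.py change removes the debug prints, couples text_strength to the selected resolution (5 at 1024, 3 at 768), skips the 640px thumbnail step, and deletes the main() batch-inference entry point, so the file no longer runs as a standalone script. Below is a minimal sketch of driving the pipeline programmatically instead, based on the gradio_generate signature shown in the diff above; the import path mirrors the file name and the image paths are illustrative placeholders, not values from this repo.

    # Sketch only: direct use of DreamFuseInference now that main() is removed.
    # Class names and the gradio_generate signature come from this file;
    # the input paths below are placeholders.
    from PIL import Image
    from dreamfuse_inference import DreamFuseInference, InferenceConfig

    config = InferenceConfig()
    config.lora_id = 'LL3RD/DreamFuse'
    model = DreamFuseInference(config)

    background = Image.open("background.png").convert("RGB")
    foreground = Image.open("foreground.png").convert("RGBA")  # gradio_generate splits off an alpha channel

    fused = model.gradio_generate(
        background, foreground,
        transformation_info="{}",  # empty JSON: no drag/resize transform from the GUI
        seed=12345,
        prompt="",
        enable_gui=True,
        cfg=3.5,
        size_select="1024",        # text_strength is forced to 5 at 1024 by the change above
        truecfg=False,
    )
    fused.save("fused.png")  # gradio_generate now returns a PIL image resized to the background
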