Commit f6e3a92 by LL3RD
Parent(s): f96f677
__pycache__/dreamfuse_inference.cpython-310.pyc CHANGED
Binary files a/__pycache__/dreamfuse_inference.cpython-310.pyc and b/__pycache__/dreamfuse_inference.cpython-310.pyc differ
 
app.py CHANGED
@@ -400,36 +400,27 @@ class DreamblendGUI:
             canvas_size=400
         ), draggable_img
 
-    def save_image(self, save_path = "/mnt/bn/hjj-humanseg-lq/SubjectDriven/DreamFuse/debug"):
-        global generated_images
-        save_name = self.get_next_sequence(save_path)
-        generated_images[0].save(os.path.join(save_path, f"{save_name}_0_ori.png"))
-        generated_images[1].save(os.path.join(save_path, f"{save_name}_0.png"))
-        generated_images[2].save(os.path.join(save_path, f"{save_name}_1.png"))
-        generated_images[3].save(os.path.join(save_path, f"{save_name}_2.png"))
-        generated_images[4].save(os.path.join(save_path, f"{save_name}_0_mask.png"))
-        generated_images[5].save(os.path.join(save_path, f"{save_name}_0_mask_scale.png"))
-        generated_images[6].save(os.path.join(save_path, f"{save_name}_0_scale.png"))
-        generated_images[7].save(os.path.join(save_path, f"{save_name}_2_pasted.png"))
-
 
     def create_gui(self):
         config = InferenceConfig()
         config.lora_id = 'LL3RD/DreamFuse'
 
-        pipeline = DreamFuseInference(config)
-        pipeline.gradio_generate = spaces.GPU(duratioin=120)(pipeline.gradio_generate)
+        pipeline = None
+        # pipeline = DreamFuseInference(config)
+        # pipeline.gradio_generate = spaces.GPU(duratioin=120)(pipeline.gradio_generate)
         """创建 Gradio 界面"""
         with gr.Blocks(css=self.css_style) as demo:
             modified_fg_state = gr.State()
-            gr.Markdown("# Dreamblend-GUI-dirtydata")
-            gr.Markdown("通过上传背景图与前景图生成带有可拖拽/缩放预览的合成图像,同时支持 Seed 设置和 Prompt 文本输入。")
+            gr.Markdown("# DreamFuse: 3 Easy Steps to Create Your Fusion Image")
+            gr.Markdown("1. Upload the foreground and background images you want to blend.")
+            gr.Markdown("2. Click 'Generate Canvas' to preview the result. You can then drag and resize the foreground object to position it as you like.")
+            gr.Markdown("3. Click 'Run Model' to create the final fused image.")
             with gr.Row():
                 with gr.Column(scale=1):
-                    gr.Markdown("### 上传图片")
-                    background_img_in = gr.Image(label="背景图片", type="pil", height=240, width=240)
-                    draggable_img_in = gr.Image(label="前景图片", type="pil", image_mode="RGBA", height=240, width=240)
-                    generate_btn = gr.Button("生成可拖拽画布")
+                    gr.Markdown("### FG&BG Image Upload")
+                    background_img_in = gr.Image(label="Background Image", type="pil", height=240, width=240)
+                    draggable_img_in = gr.Image(label="Foreground Image", type="pil", image_mode="RGBA", height=240, width=240)
+                    generate_btn = gr.Button("Generate Canvas")
 
                     with gr.Row():
                         gr.Examples(
@@ -438,39 +429,38 @@ class DreamblendGUI:
                             elem_id="small-examples"
                         )
                 with gr.Column(scale=1):
-                    gr.Markdown("### 预览区域")
-                    html_out = gr.HTML(label="预览与拖拽", elem_id="canvas_preview")
+                    gr.Markdown("### Preview Region")
+                    html_out = gr.HTML(label="drag and resize", elem_id="canvas_preview")
 
             with gr.Row():
                 with gr.Column(scale=1):
-                    gr.Markdown("### 参数设置")
-                    seed_slider = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
+                    gr.Markdown("### Parameters")
+                    seed_slider = gr.Slider(minimum=-1, maximum=100000, step=1, label="Seed", value=12345)
                     cfg_slider = gr.Slider(minimum=1, maximum=10, step=0.1, label="CFG", value=3.5)
                     size_select = gr.Radio(
                         choices=["512", "768", "1024"],
                         value="512",
                         label="生成质量(512-差 1024-好)",
                     )
-                    prompt_text = gr.Textbox(label="Prompt", placeholder="输入文本提示", value="")
-                    text_strength = gr.Slider(minimum=1, maximum=10, step=1, label="Text Strength", value=1)
-                    enable_gui = gr.Checkbox(label="启用GUI", value=True)
-                    enable_truecfg = gr.Checkbox(label="启用TrueCFG", value=False)
-                    enable_save = gr.Button("保存图片 (内部测试)", visible=True)
+                    prompt_text = gr.Textbox(label="Prompt", placeholder="text prompt", value="")
+                    text_strength = gr.Slider(minimum=1, maximum=10, step=1, label="Text Strength", value=1, visible=False)
+                    enable_gui = gr.Checkbox(label="启用GUI", value=True, visible=False)
+                    enable_truecfg = gr.Checkbox(label="TrueCFG", value=False, visible=False)
                 with gr.Column(scale=1):
-                    gr.Markdown("### 模型生成结果")
-                    model_generate_btn = gr.Button("模型生成")
+                    gr.Markdown("### Model Result")
+                    model_generate_btn = gr.Button("Run Model")
                     transformation_text = gr.Textbox(label="Transformation Info", elem_id="transformation_info", visible=False)
-                    model_output = gr.Image(label="模型输出", type="pil")
+                    model_output = gr.Image(label="Model Output", type="pil")
 
-            # 交互事件绑定
-            enable_save.click(fn=self.save_image, inputs=None, outputs=None)
+
            generate_btn.click(
                 fn=self.on_upload,
                 inputs=[background_img_in, draggable_img_in],
                 outputs=[html_out, modified_fg_state],
             )
             model_generate_btn.click(
-                fn=pipeline.gradio_generate,
+                # fn=pipeline.gradio_generate,
+                fn=self.pil_to_base64,
                 inputs=[background_img_in, modified_fg_state, transformation_text, seed_slider, \
                         prompt_text, enable_gui, cfg_slider, size_select, text_strength, enable_truecfg],
                 outputs=model_output
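
Note on the wiring this commit disables: `spaces.GPU` is the Hugging Face ZeroGPU decorator, which schedules each call of the wrapped function onto a GPU worker for at most `duration` seconds. A minimal sketch of that pattern (assuming the standard `spaces` package; the committed line spells the keyword `duratioin`, whereas the documented keyword is `duration`):

import spaces

def wrap_for_zerogpu(pipeline, seconds=120):
    # Run each gradio_generate call on a ZeroGPU worker,
    # holding the GPU for at most `seconds` per call.
    return spaces.GPU(duration=seconds)(pipeline.gradio_generate)

# Mirrors the pre-commit wiring:
#   pipeline = DreamFuseInference(config)
#   pipeline.gradio_generate = wrap_for_zerogpu(pipeline, seconds=120)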
dreamfuse/models/dreamfuse_flux/__pycache__/flux_processor.cpython-310.pyc CHANGED
Binary files a/dreamfuse/models/dreamfuse_flux/__pycache__/flux_processor.cpython-310.pyc and b/dreamfuse/models/dreamfuse_flux/__pycache__/flux_processor.cpython-310.pyc differ
 
dreamfuse/models/dreamfuse_flux/__pycache__/transformer.cpython-310.pyc CHANGED
Binary files a/dreamfuse/models/dreamfuse_flux/__pycache__/transformer.cpython-310.pyc and b/dreamfuse/models/dreamfuse_flux/__pycache__/transformer.cpython-310.pyc differ
 
dreamfuse/trains/utils/__pycache__/inference_utils.cpython-310.pyc CHANGED
Binary files a/dreamfuse/trains/utils/__pycache__/inference_utils.cpython-310.pyc and b/dreamfuse/trains/utils/__pycache__/inference_utils.cpython-310.pyc differ
 
dreamfuse_inference.py CHANGED
@@ -168,7 +168,6 @@ def make_image_grid(images, rows, cols, size=None):
 class DreamFuseInference:
     def __init__(self, config: InferenceConfig):
         self.config = config
-        print(config.device)
         self.device = torch.device(config.device)
         torch.backends.cuda.matmul.allow_tf32 = True
         seed_everything(config.seed)
@@ -348,16 +347,15 @@ class DreamFuseInference:
 
     @torch.inference_mode()
     def gradio_generate(self, background_img, foreground_img, transformation_info, seed, prompt, enable_gui, cfg=3.5, size_select="1024", text_strength=1, truecfg=False):
-        print("!"*10)
-        """使用 DreamFuseInference 进行模型推理"""
         try:
             trans = json.loads(transformation_info)
         except:
             trans = {}
 
         size_select = int(size_select)
+        if size_select == 1024: text_strength = 5
+        if size_select == 768: text_strength = 3
 
-        # import pdb; pdb.set_trace()
         r, g, b, ori_a = foreground_img.split()
         fg_img_scale, fg_img = self.transform_foreground_original(foreground_img, background_img, trans)
 
@@ -370,9 +368,7 @@ class DreamFuseInference:
         ori_a = ori_a.convert("L")
         new_a = new_a.convert("L")
         foreground_img.paste((255, 255, 255), mask=ImageOps.invert(ori_a))
-        print("0"*10)
-        print(foreground_img.size)
-        print(background_img.size)
+
         images = self.model_generate(foreground_img.copy(), background_img.copy(),
                                      ori_a, new_a,
                                      enable_mask_affine=enable_gui,
@@ -386,16 +382,15 @@ class DreamFuseInference:
         images = Image.fromarray(images[0], "RGB")
 
         images = images.resize(background_img.size)
-        images_save = images.copy()
+        # images_save = images.copy()
 
-        images.thumbnail((640, 640), Image.LANCZOS)
+        # images.thumbnail((640, 640), Image.LANCZOS)
         return images
 
 
     @torch.inference_mode()
     def model_generate(self, fg_image, bg_image, ori_fg_mask, new_fg_mask, enable_mask_affine=True, prompt="", offset_cond=None, seed=None, cfg=3.5, size_select=1024, text_strength=1, truecfg=False):
         batch_size = 1
-        print("-3"*10)
         # Prepare images
         # adjust bg->fg size
         fg_image, ori_fg_mask = adjust_fg_to_bg(fg_image, ori_fg_mask, bg_image.size)
@@ -410,7 +405,6 @@ class DreamFuseInference:
         new_fg_mask = new_fg_mask.resize(bucket_size)
         mask_affine = get_mask_affine(new_fg_mask, ori_fg_mask)
 
-        print("-2"*10)
         # Get embeddings
         prompt_embeds, pooled_prompt_embeds, text_ids = self._compute_text_embeddings(prompt)
 
@@ -428,7 +422,6 @@ class DreamFuseInference:
         if seed is None:
             seed = self.config.seed
         generator = torch.Generator(device=self.device).manual_seed(seed)
-        print("-1"*10)
         # Prepare condition latents
         condition_image_latents = self._encode_images([fg_image, bg_image])
 
@@ -445,7 +438,6 @@ class DreamFuseInference:
             )
         )
 
-        print(1)
         if mask_affine is not None:
             affine_H, affine_W = condition_image_latents.shape[2] // 2, condition_image_latents.shape[3] // 2
             scale_factor = 1 / 16
@@ -457,7 +449,7 @@ class DreamFuseInference:
             scale_factor=scale_factor, device=self.device,
         )
         cond_latent_image_ids = torch.stack(cond_latent_image_ids)
-        print(2)
+
         # Pack condition latents
         cond_image_latents = self._pack_latents(condition_image_latents)
         cond_input = {
@@ -470,7 +462,7 @@ class DreamFuseInference:
         latents, latent_image_ids = self._prepare_latents(
             batch_size, num_channels_latents, height, width, generator
         )
-        print(3)
+
         # Setup timesteps
         sigmas = np.linspace(1.0, 1 / self.config.num_inference_steps, self.config.num_inference_steps)
         image_seq_len = latents.shape[1]
@@ -488,7 +480,7 @@ class DreamFuseInference:
             sigmas=sigmas,
             mu=mu,
         )
-        print(4)
+
         # Denoising loop
         for i, t in enumerate(timesteps):
             timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -537,12 +529,12 @@ class DreamFuseInference:
 
             # Compute previous noisy sample
             latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-        print(5)
+
         # Decode latents
         latents = self._unpack_latents(latents, height, width)
         latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
         images = self.vae.decode(latents, return_dict=False)[0]
-        print(6)
+
         # Post-process images
         images = images.add(1).mul(127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
         return images
@@ -575,68 +567,3 @@ class DreamFuseInference:
             offset=None
         )
         return latents, latent_image_ids
-
-def main():
-    parser = transformers.HfArgumentParser(InferenceConfig)
-    config: InferenceConfig = parser.parse_args_into_dataclasses()[0]
-    model = DreamFuseInference(config)
-    os.makedirs(config.valid_output_dir, exist_ok=True)
-    for valid_root, valid_json in zip(config.valid_roots, config.valid_jsons):
-        with open(valid_json, 'r') as f:
-            valid_info = json.load(f)
-
-        # multi gpu
-        to_process = sorted(list(valid_info.keys()))
-
-        # debug
-        to_process = [k for k in to_process if "data_wear" in k and "pixelwave" in k]
-        # debug
-
-        sd_idx = len(to_process) // config.total_num * config.sub_idx
-        ed_idx = len(to_process) // config.total_num * (config.sub_idx + 1)
-        if config.sub_idx < config.total_num - 1:
-            print(config.sub_idx, sd_idx, ed_idx)
-            to_process = to_process[sd_idx:ed_idx]
-        else:
-            print(config.sub_idx, sd_idx)
-            to_process = to_process[sd_idx:]
-        valid_info = {k: valid_info[k] for k in to_process}
-
-        for meta_key, info in tqdm(valid_info.items()):
-            img_name = meta_key.split('/')[-1]
-
-            foreground_img = Image.open(os.path.join(valid_root, info['img_info']['000']))
-            background_img = Image.open(os.path.join(valid_root, info['img_info']['001']))
-
-            new_fg_mask = Image.open(os.path.join(valid_root, info['img_mask_info']['000_mask_scale']))
-            ori_fg_mask = Image.open(os.path.join(valid_root, info['img_mask_info']['000']))
-
-            # debug
-            foreground_img.save(os.path.join(config.valid_output_dir, f"{img_name}_0.png"))
-            background_img.save(os.path.join(config.valid_output_dir, f"{img_name}_1.png"))
-            ori_fg_mask.save(os.path.join(config.valid_output_dir, f"{img_name}_0_mask.png"))
-            new_fg_mask.save(os.path.join(config.valid_output_dir, f"{img_name}_0_mask_scale.png"))
-            # debug
-
-            foreground_img.paste((255, 255, 255), mask=ImageOps.invert(ori_fg_mask))
-
-            images = model(foreground_img.copy(), background_img.copy(),
-                           ori_fg_mask, new_fg_mask,
-                           prompt=config.ref_prompts,
-                           seed=config.seed,
-                           cfg=config.guidance_scale,
-                           size_select=config.inference_scale,
-                           text_strength=config.text_strength,
-                           truecfg=config.truecfg)
-
-            result_image = Image.fromarray(images[0], "RGB")
-            result_image = result_image.resize(background_img.size)
-            result_image.save(os.path.join(config.valid_output_dir, f"{img_name}_2.png"))
-            # Make grid
-            grid_image = [foreground_img, background_img] + [result_image]
-            result = make_image_grid(grid_image, 1, len(grid_image), size=result_image.size)
-
-            result.save(os.path.join(config.valid_output_dir, f"{img_name}.jpg"))
-
-if __name__ == "__main__":
-    main()
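
The one functional change in `gradio_generate` above is the resolution-dependent `text_strength` override (1024 maps to 5, 768 to 3, and other sizes keep the caller's value). The same mapping as a standalone helper; `resolve_text_strength` is an illustrative name, not a function in this repo:

_TEXT_STRENGTH_BY_SIZE = {1024: 5, 768: 3}

def resolve_text_strength(size_select, text_strength):
    # Higher resolutions get stronger text conditioning by default;
    # sizes without an override (e.g. 512) keep the user-supplied value.
    return _TEXT_STRENGTH_BY_SIZE.get(int(size_select), text_strength)

# resolve_text_strength("1024", 1) -> 5
# resolve_text_strength("512", 1)  -> 1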