Sayoyo committed
Commit 5800fe0
2 Parent(s): 96ec844 a9fd5eb

Merge branch 'main' of hf.co:spaces/ACE-Step/ACE-Step

Files changed (2)
  1. app.py +1 -1
  2. pipeline_ace_step.py +38 -1
app.py CHANGED
@@ -12,7 +12,7 @@ parser.add_argument("--port", type=int, default=7860)
 parser.add_argument("--device_id", type=int, default=0)
 parser.add_argument("--share", action='store_true', default=False)
 parser.add_argument("--bf16", action='store_true', default=True)
-parser.add_argument("--torch_compile", type=bool, default=True)
+parser.add_argument("--torch_compile", type=bool, default=False)
 
 args = parser.parse_args()
 os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
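The app.py change flips the default rather than fixing the flag's type, and a likely motivation (my inference, not stated in the commit) is a well-known argparse pitfall: type=bool applies bool() to the raw argument string, and any non-empty string, including "False", is truthy. With default=True the flag could never be switched off from the command line; with default=False, compilation stays off unless the flag is passed. A minimal demonstration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--torch_compile", type=bool, default=False)

print(parser.parse_args([]).torch_compile)                             # False (the new default)
print(parser.parse_args(["--torch_compile", "False"]).torch_compile)   # True: bool("False") is truthy
print(parser.parse_args(["--torch_compile", ""]).torch_compile)        # False: only the empty string is falsy

An action='store_true' flag, as already used for --share and --bf16 above, would sidestep the pitfall entirely.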
pipeline_ace_step.py CHANGED
@@ -537,6 +537,27 @@ class ACEStepPipeline:
         target_latents = zt_edit if xt_tar is None else xt_tar
         return target_latents
 
+    def add_latents_noise(
+        self,
+        gt_latents,
+        variance,
+        noise,
+        scheduler,
+    ):
+
+        bsz = gt_latents.shape[0]
+        u = torch.tensor([variance] * bsz, dtype=gt_latents.dtype)
+        indices = (u * scheduler.config.num_train_timesteps).long()
+        timesteps = scheduler.timesteps.unsqueeze(1).to(gt_latents.dtype)
+        indices = indices.to(timesteps.device).to(gt_latents.dtype).unsqueeze(1)
+        nearest_idx = torch.argmin(torch.cdist(indices, timesteps), dim=1)
+        sigma = scheduler.sigmas[nearest_idx].flatten().to(gt_latents.device).to(gt_latents.dtype)
+        while len(sigma.shape) < gt_latents.ndim:
+            sigma = sigma.unsqueeze(-1)
+        noisy_image = sigma * noise + (1.0 - sigma) * gt_latents
+        init_timestep = indices[0]
+        return noisy_image, init_timestep
+
     @torch.no_grad()
     def text2music_diffusion_process(
         self,
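The new add_latents_noise helper is the flow-matching forward process used by the audio2audio path: it converts a noising fraction (variance, which the caller sets to 1 - ref_audio_strength) into the nearest discrete scheduler timestep, looks up the corresponding sigma, and blends the reference latents with Gaussian noise as sigma * noise + (1 - sigma) * gt_latents, returning both the noised latents and the timestep from which denoising should resume. A self-contained sketch of the same computation, using a toy linear sigma schedule in place of the pipeline's real scheduler (the schedule and tensor shapes here are assumptions for illustration only):

import torch

num_train_timesteps = 1000
sigmas = torch.linspace(1.0, 0.0, 11)            # toy schedule: 1.0, 0.9, ..., 0.0
timesteps = sigmas[:-1] * num_train_timesteps    # 1000, 900, ..., 100

ref_audio_strength = 0.5
variance = 1.0 - ref_audio_strength              # fraction of noise to inject

gt_latents = torch.randn(1, 8, 16, 256)          # stand-in reference latents
noise = torch.randn_like(gt_latents)

# Snap the requested variance to the nearest discrete timestep ...
target_t = variance * num_train_timesteps
nearest = torch.argmin((timesteps - target_t).abs())
sigma = sigmas[nearest]

# ... then interpolate: pure noise at sigma = 1, the reference intact at sigma = 0.
noisy = sigma * noise + (1.0 - sigma) * gt_latents
init_timestep = timesteps[nearest]
print(f"sigma={sigma.item():.2f}, resume denoising from t={init_timestep.item():.0f}")  # sigma=0.50, t=500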
 
@@ -569,6 +590,9 @@ class ACEStepPipeline:
         repaint_start=0,
         repaint_end=0,
         src_latents=None,
+        audio2audio_enable=False,
+        ref_audio_strength=0.5,
+        ref_latents=None,
     ):
 
         logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
 
@@ -600,6 +624,9 @@ class ACEStepPipeline:
         if src_latents is not None:
             frame_length = src_latents.shape[-1]
 
+        if ref_latents is not None:
+            frame_length = ref_latents.shape[-1]
+
         if len(oss_steps) > 0:
             infer_steps = max(oss_steps)
             scheduler.set_timesteps
 
@@ -695,6 +722,10 @@ class ACEStepPipeline:
             zt_edit = x0.clone()
             z0 = target_latents
 
+        init_timestep = 1000
+        if audio2audio_enable and ref_latents is not None:
+            target_latents, init_timestep = self.add_latents_noise(gt_latents=ref_latents, variance=(1-ref_audio_strength), noise=target_latents, scheduler=scheduler)
+
         attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
 
         # guidance interval
 
@@ -798,7 +829,10 @@ class ACEStepPipeline:
             return sample
 
         for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
-
+
+            if t > init_timestep:
+                continue
+
             if is_repaint:
                 if i < n_min:
                     continue
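Inside the denoising loop, any step whose timestep exceeds init_timestep is skipped, so sampling resumes exactly where the forward noising stopped: the higher ref_audio_strength, the smaller init_timestep, the fewer steps run, and the closer the output stays to the reference. Continuing the toy numbers from the sketch above (values assumed for illustration):

timesteps = [1000, 900, 800, 700, 600, 500, 400, 300, 200, 100]
init_timestep = 500  # from ref_audio_strength = 0.5
executed = [t for t in timesteps if not t > init_timestep]
print(executed)  # [500, 400, 300, 200, 100]: the first half of the schedule is skipped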
 
@@ -1014,6 +1048,9 @@ class ACEStepPipeline:
 
         start_time = time.time()
 
+        if audio2audio_enable and ref_audio_input is not None:
+            task = "audio2audio"
+
         if not self.loaded:
             logger.warning("Checkpoint not loaded, loading checkpoint...")
             self.load_checkpoint(self.checkpoint_dir)
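Taken together, the pipeline switches its task to "audio2audio" whenever the feature is enabled and a reference is supplied. Assuming these keyword arguments are forwarded through the pipeline's public entry point (the surrounding signature, the prompt parameter, and the constructor call below are assumptions not shown in this diff), an invocation might look like:

pipeline = ACEStepPipeline(checkpoint_dir="./checkpoints")  # hypothetical constructor argument
audio = pipeline(
    prompt="mellow lofi piano",     # hypothetical text-conditioning argument
    audio2audio_enable=True,        # switches task to "audio2audio"
    ref_audio_strength=0.5,         # 0 = ignore the reference, 1 = reproduce it almost verbatim
    ref_audio_input="ref.wav",      # reference audio, encoded to ref_latents internally
)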