Merge branch 'main' of hf.co:spaces/ACE-Step/ACE-Step

Files changed:
- app.py (+1 -1)
- pipeline_ace_step.py (+38 -1)
app.py
CHANGED
@@ -12,7 +12,7 @@ parser.add_argument("--port", type=int, default=7860)
 parser.add_argument("--device_id", type=int, default=0)
 parser.add_argument("--share", action='store_true', default=False)
 parser.add_argument("--bf16", action='store_true', default=True)
-parser.add_argument("--torch_compile", type=bool, default=True)
+parser.add_argument("--torch_compile", type=bool, default=False)
 
 args = parser.parse_args()
 os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
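A note on this change: argparse's `type=bool` does not actually parse boolean strings. `bool(v)` is truthy for every non-empty string, so `--torch_compile False` would still yield `True`; flipping the default is the practical fix. A converter like the sketch below (not part of this commit) is the usual way to make such a flag parse both spellings:

import argparse

def str2bool(v: str) -> bool:
    # Accept the usual spellings; reject anything else.
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--torch_compile", type=str2bool, default=False)
print(parser.parse_args(["--torch_compile", "false"]).torch_compile)  # False
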
pipeline_ace_step.py
CHANGED
@@ -537,6 +537,27 @@ class ACEStepPipeline:
         target_latents = zt_edit if xt_tar is None else xt_tar
         return target_latents
 
+    def add_latents_noise(
+        self,
+        gt_latents,
+        variance,
+        noise,
+        scheduler,
+    ):
+
+        bsz = gt_latents.shape[0]
+        u = torch.tensor([variance] * bsz, dtype=gt_latents.dtype)
+        indices = (u * scheduler.config.num_train_timesteps).long()
+        timesteps = scheduler.timesteps.unsqueeze(1).to(gt_latents.dtype)
+        indices = indices.to(timesteps.device).to(gt_latents.dtype).unsqueeze(1)
+        nearest_idx = torch.argmin(torch.cdist(indices, timesteps), dim=1)
+        sigma = scheduler.sigmas[nearest_idx].flatten().to(gt_latents.device).to(gt_latents.dtype)
+        while len(sigma.shape) < gt_latents.ndim:
+            sigma = sigma.unsqueeze(-1)
+        noisy_image = sigma * noise + (1.0 - sigma) * gt_latents
+        init_timestep = indices[0]
+        return noisy_image, init_timestep
+
     @torch.no_grad()
     def text2music_diffusion_process(
         self,
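The new `add_latents_noise` helper is the core of the feature: it converts the requested `variance` into the nearest inference timestep, reads the matching `sigma` off the scheduler, and linearly blends `sigma * noise + (1 - sigma) * gt_latents`, returning both the noised latents and the timestep at which denoising should resume. A minimal self-contained sketch of the same arithmetic; the `DummyScheduler` below is a hypothetical flow-matching-style stand-in, not ACE-Step's real scheduler:

import torch
from types import SimpleNamespace

class DummyScheduler:
    # Stand-in: descending timesteps, sigma proportional to t.
    def __init__(self, num_train_timesteps=1000, num_inference_steps=50):
        self.config = SimpleNamespace(num_train_timesteps=num_train_timesteps)
        self.timesteps = torch.linspace(num_train_timesteps, 1, num_inference_steps)
        self.sigmas = self.timesteps / num_train_timesteps

def add_noise(gt_latents, variance, noise, scheduler):
    # Same steps as add_latents_noise: snap `variance` to the nearest
    # inference timestep, look up sigma, lerp between noise and reference.
    bsz = gt_latents.shape[0]
    u = torch.tensor([variance] * bsz, dtype=gt_latents.dtype)
    indices = (u * scheduler.config.num_train_timesteps).long()
    timesteps = scheduler.timesteps.unsqueeze(1).to(gt_latents.dtype)
    indices = indices.to(gt_latents.dtype).unsqueeze(1)
    nearest_idx = torch.argmin(torch.cdist(indices, timesteps), dim=1)
    sigma = scheduler.sigmas[nearest_idx].flatten()
    while sigma.ndim < gt_latents.ndim:
        sigma = sigma.unsqueeze(-1)
    noisy = sigma * noise + (1.0 - sigma) * gt_latents
    return noisy, indices[0]

ref = torch.zeros(1, 8, 16, 100)  # stand-in reference latents, shape hypothetical
noisy, t0 = add_noise(ref, variance=0.5, noise=torch.randn_like(ref), scheduler=DummyScheduler())
print(t0)  # ~500: denoising will resume roughly halfway down the schedule
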
@@ -569,6 +590,9 @@ class ACEStepPipeline:
         repaint_start=0,
         repaint_end=0,
         src_latents=None,
+        audio2audio_enable=False,
+        ref_audio_strength=0.5,
+        ref_latents=None,
     ):
 
         logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
@@ -600,6 +624,9 @@
         if src_latents is not None:
             frame_length = src_latents.shape[-1]
 
+        if ref_latents is not None:
+            frame_length = ref_latents.shape[-1]
+
         if len(oss_steps) > 0:
             infer_steps = max(oss_steps)
             scheduler.set_timesteps
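Because the `ref_latents` check runs after the `src_latents` one, the reference's length wins when both are supplied, and the generated clip inherits the reference's frame count (the last latent axis). Illustrative only; the latent shapes here are hypothetical, but `frame_length` is read from the last axis exactly as the hunk does:

import torch

src_latents = torch.randn(1, 8, 16, 431)
ref_latents = torch.randn(1, 8, 16, 862)

frame_length = src_latents.shape[-1]      # 431
if ref_latents is not None:
    frame_length = ref_latents.shape[-1]  # 862 -- the reference wins, checked second
print(frame_length)
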
@@ -695,6 +722,10 @@
             zt_edit = x0.clone()
             z0 = target_latents
 
+        init_timestep = 1000
+        if audio2audio_enable and ref_latents is not None:
+            target_latents, init_timestep = self.add_latents_noise(gt_latents=ref_latents, variance=(1-ref_audio_strength), noise=target_latents, scheduler=scheduler)
+
         attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
 
         # guidance interval
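The call site passes `variance=(1-ref_audio_strength)`, so a higher strength preserves more of the reference: with the flow-matching-style schedule assumed above, sigma tracks variance, the initial latents are `variance * noise + strength * ref`, and denoising starts near `t = variance * 1000`. Worked out for a few strengths (a sketch, assuming sigma ≈ t / 1000):

# Assumes sigma ~= variance via the nearest-timestep lookup on a linear schedule.
for strength in (0.0, 0.25, 0.5, 0.75, 1.0):
    variance = 1 - strength               # what the hunk passes to add_latents_noise
    init_timestep = int(variance * 1000)  # scheduler.config.num_train_timesteps = 1000
    print(f"ref_audio_strength={strength:.2f}: latents = "
          f"{variance:.2f}*noise + {strength:.2f}*ref, resume at t<={init_timestep}")
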
@@ -798,7 +829,10 @@
             return sample
 
         for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
-
+
+            if t > init_timestep:
+                continue
+
             if is_repaint:
                 if i < n_min:
                     continue
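Because `timesteps` is ordered from most to least noisy, the new guard simply fast-forwards the sampler: steps above `init_timestep` are skipped, and a plain text2music run (where `init_timestep` stays at 1000) is unaffected. In miniature, with stand-in schedule values rather than the real scheduler's:

timesteps = [1000, 800, 600, 400, 200, 1]  # stand-in descending schedule
init_timestep = 500                        # set by add_latents_noise for audio2audio
ran = [t for t in timesteps if not t > init_timestep]
print(ran)  # [400, 200, 1] -- the sampler resumes mid-schedule
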
@@ -1014,6 +1048,9 @@
 
         start_time = time.time()
 
+        if audio2audio_enable and ref_audio_input is not None:
+            task = "audio2audio"
+
         if not self.loaded:
             logger.warning("Checkpoint not loaded, loading checkpoint...")
             self.load_checkpoint(self.checkpoint_dir)
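Taken together, the hunks add an img2img-style audio2audio mode: the entry point flips `task` to "audio2audio" when a reference input is given, the reference latents are noised to the level implied by `ref_audio_strength`, and only the tail of the denoising loop runs. A stubbed end-to-end sketch of that flow; `encode_audio` and `decode_latents` are hypothetical stand-ins for the pipeline's VAE entry points, which this diff does not show, and only the three new keyword arguments come from the commit:

import torch

class StubPipeline:
    # Stand-in pipeline; real denoising and VAE calls are elided.
    def encode_audio(self, audio):                  # hypothetical VAE encode
        return torch.randn(1, 8, 16, 100)
    def text2music_diffusion_process(self, audio2audio_enable=False,
                                     ref_audio_strength=0.5, ref_latents=None):
        if audio2audio_enable and ref_latents is not None:
            # Real code: noise ref_latents via add_latents_noise, then run
            # only the timesteps at or below init_timestep.
            return ref_latents  # placeholder for the denoised result
        return torch.randn(1, 8, 16, 100)
    def decode_latents(self, latents):              # hypothetical VAE decode
        return latents.flatten()

pipe = StubPipeline()
audio_out = pipe.decode_latents(
    pipe.text2music_diffusion_process(
        audio2audio_enable=True,
        ref_audio_strength=0.5,
        ref_latents=pipe.encode_audio(torch.randn(2, 44100)),
    )
)
print(audio_out.shape)  # torch.Size([12800])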