Merge branch 'main' of hf.co:spaces/ACE-Step/ACE-Step
Files changed:
- app.py +1 -1
- pipeline_ace_step.py +38 -1
app.py CHANGED
@@ -12,7 +12,7 @@ parser.add_argument("--port", type=int, default=7860)
 parser.add_argument("--device_id", type=int, default=0)
 parser.add_argument("--share", action='store_true', default=False)
 parser.add_argument("--bf16", action='store_true', default=True)
-parser.add_argument("--torch_compile", type=bool, default=
+parser.add_argument("--torch_compile", type=bool, default=False)

 args = parser.parse_args()
 os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
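Aside (not part of the diff): the new default fixes the out-of-the-box behavior, but `type=bool` remains an argparse footgun, because `bool(...)` of any non-empty string is `True`. A minimal demonstration:

import argparse

# bool("False") is True, so any explicit value still enables the flag
p = argparse.ArgumentParser()
p.add_argument("--torch_compile", type=bool, default=False)
print(p.parse_args(["--torch_compile", "False"]).torch_compile)  # True

# a store_true flag, as used for --share and --bf16 above, avoids this
p2 = argparse.ArgumentParser()
p2.add_argument("--torch_compile", action="store_true", default=False)
print(p2.parse_args([]).torch_compile)  # False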
pipeline_ace_step.py CHANGED
@@ -537,6 +537,27 @@ class ACEStepPipeline:
         target_latents = zt_edit if xt_tar is None else xt_tar
         return target_latents

+    def add_latents_noise(
+        self,
+        gt_latents,
+        variance,
+        noise,
+        scheduler,
+    ):
+
+        bsz = gt_latents.shape[0]
+        u = torch.tensor([variance] * bsz, dtype=gt_latents.dtype)
+        indices = (u * scheduler.config.num_train_timesteps).long()
+        timesteps = scheduler.timesteps.unsqueeze(1).to(gt_latents.dtype)
+        indices = indices.to(timesteps.device).to(gt_latents.dtype).unsqueeze(1)
+        nearest_idx = torch.argmin(torch.cdist(indices, timesteps), dim=1)
+        sigma = scheduler.sigmas[nearest_idx].flatten().to(gt_latents.device).to(gt_latents.dtype)
+        while len(sigma.shape) < gt_latents.ndim:
+            sigma = sigma.unsqueeze(-1)
+        noisy_image = sigma * noise + (1.0 - sigma) * gt_latents
+        init_timestep = indices[0]
+        return noisy_image, init_timestep
+
     @torch.no_grad()
     def text2music_diffusion_process(
         self,
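The new `add_latents_noise` helper is the img2img-style trick adapted to flow matching: snap the requested noise level to the nearest training timestep, look up that timestep's sigma, and linearly blend the clean latents with noise. It also returns the matched timestep index, which the denoising loop later uses as its starting point. A minimal standalone sketch of the same blend (names here are illustrative, not the pipeline's API):

import torch

def blend_with_noise(gt_latents, noise, sigma):
    # sigma = 0 keeps the latents untouched; sigma = 1 yields pure noise
    sigma = torch.as_tensor(sigma, dtype=gt_latents.dtype)
    while sigma.ndim < gt_latents.ndim:
        sigma = sigma.unsqueeze(-1)
    return sigma * noise + (1.0 - sigma) * gt_latents

gt = torch.randn(1, 8, 16, 100)  # hypothetical latent shape
half_noised = blend_with_noise(gt, torch.randn_like(gt), 0.5)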
@@ -569,6 +590,9 @@ class ACEStepPipeline:
         repaint_start=0,
         repaint_end=0,
         src_latents=None,
+        audio2audio_enable=False,
+        ref_audio_strength=0.5,
+        ref_latents=None,
     ):

         logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
@@ -600,6 +624,9 @@ class ACEStepPipeline:
         if src_latents is not None:
             frame_length = src_latents.shape[-1]

+        if ref_latents is not None:
+            frame_length = ref_latents.shape[-1]
+
         if len(oss_steps) > 0:
             infer_steps = max(oss_steps)
             scheduler.set_timesteps
@@ -695,6 +722,10 @@ class ACEStepPipeline:
         zt_edit = x0.clone()
         z0 = target_latents

+        init_timestep = 1000
+        if audio2audio_enable and ref_latents is not None:
+            target_latents, init_timestep = self.add_latents_noise(gt_latents=ref_latents, variance=(1-ref_audio_strength), noise=target_latents, scheduler=scheduler)
+
         attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)

         # guidance interval
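Reading of the call above: the reference latents are noised with `variance = 1 - ref_audio_strength`, so a higher strength injects less noise and leaves the output closer to the reference. A rough sketch of the mapping (assuming `scheduler.config.num_train_timesteps == 1000`, as the `init_timestep = 1000` default suggests):

num_train_timesteps = 1000  # assumption: scheduler.config.num_train_timesteps

for strength in (0.0, 0.5, 1.0):
    variance = 1 - strength
    init_timestep = int(variance * num_train_timesteps)
    print(f"ref_audio_strength={strength} -> denoise only t <= {init_timestep}")
# strength=0.0 denoises from pure noise (reference ignored);
# strength=1.0 skips essentially every step (reference passed through).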
@@ -798,7 +829,10 @@ class ACEStepPipeline:
             return sample

         for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
-
+
+            if t > init_timestep:
+                continue
+
             if is_repaint:
                 if i < n_min:
                     continue
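The `continue` above is what makes the partial noising cheap: with a descending timestep schedule, every step noisier than `init_timestep` is skipped outright rather than computed. Illustration with made-up numbers:

import torch

timesteps = torch.linspace(1000.0, 1.0, steps=10)  # assumed descending grid
init_timestep = 500
active = [float(t) for t in timesteps if t <= init_timestep]
print(active)  # only the low-noise tail of the schedule is denoised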
@@ -1014,6 +1048,9 @@ class ACEStepPipeline:

         start_time = time.time()

+        if audio2audio_enable and ref_audio_input is not None:
+            task = "audio2audio"
+
         if not self.loaded:
             logger.warning("Checkpoint not loaded, loading checkpoint...")
             self.load_checkpoint(self.checkpoint_dir)