Commit b0e2210 · Gong Junmin committed · 1 parent: 5c24327

support audio2audio

Files changed:
- pipeline_ace_step.py (+57 -3)
- ui/components.py (+47 -0)
pipeline_ace_step.py
CHANGED
@@ -2,7 +2,7 @@ import random
 import time
 import os
 import re
-import spaces
+# import spaces
 import torch
 import torch.nn as nn
 from loguru import logger
@@ -522,6 +522,27 @@ class ACEStepPipeline:
         target_latents = zt_edit if xt_tar is None else xt_tar
         return target_latents
 
+    def add_latents_noise(
+        self,
+        gt_latents,
+        variance,
+        noise,
+        scheduler,
+    ):
+
+        bsz = gt_latents.shape[0]
+        u = torch.tensor([variance] * bsz, dtype=gt_latents.dtype)
+        indices = (u * scheduler.config.num_train_timesteps).long()
+        timesteps = scheduler.timesteps.unsqueeze(1).to(gt_latents.dtype)
+        indices = indices.to(timesteps.device).to(gt_latents.dtype).unsqueeze(1)
+        nearest_idx = torch.argmin(torch.cdist(indices, timesteps), dim=1)
+        sigma = scheduler.sigmas[nearest_idx].flatten().to(gt_latents.device).to(gt_latents.dtype)
+        while len(sigma.shape) < gt_latents.ndim:
+            sigma = sigma.unsqueeze(-1)
+        noisy_image = sigma * noise + (1.0 - sigma) * gt_latents
+        init_timestep = indices[0]
+        return noisy_image, init_timestep
+
     @torch.no_grad()
     def text2music_diffusion_process(
         self,
@@ -554,6 +575,9 @@ class ACEStepPipeline:
         repaint_start=0,
         repaint_end=0,
         src_latents=None,
+        audio2audio_enable=False,
+        ref_audio_strength=0.5,
+        ref_latents=None,
     ):
 
         logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
@@ -585,6 +609,9 @@
         if src_latents is not None:
             frame_length = src_latents.shape[-1]
 
+        if ref_latents is not None:
+            frame_length = ref_latents.shape[-1]
+
         if len(oss_steps) > 0:
             infer_steps = max(oss_steps)
             scheduler.set_timesteps
@@ -680,6 +707,10 @@
             zt_edit = x0.clone()
             z0 = target_latents
 
+        init_timestep = 1000
+        if audio2audio_enable and ref_latents is not None:
+            target_latents, init_timestep = self.add_latents_noise(gt_latents=ref_latents, variance=(1-ref_audio_strength), noise=target_latents, scheduler=scheduler)
+
         attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
 
         # guidance interval
@@ -783,7 +814,10 @@
             return sample
 
         for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
-
+
+            if t > init_timestep:
+                continue
+
             if is_repaint:
                 if i < n_min:
                     continue
@@ -955,7 +989,7 @@
         latents, _ = self.music_dcae.encode(input_audio, sr=sr)
         return latents
 
-    @spaces.GPU
+    # @spaces.GPU
     def __call__(
         self,
         audio_duration: float = 60.0,
@@ -976,6 +1010,9 @@
        oss_steps: str = None,
        guidance_scale_text: float = 0.0,
        guidance_scale_lyric: float = 0.0,
+        audio2audio_enable: bool = False,
+        ref_audio_strength: float = 0.5,
+        ref_audio_input: str = None,
        retake_seeds: list = None,
        retake_variance: float = 0.5,
        task: str = "text2music",
@@ -995,6 +1032,9 @@
 
        start_time = time.time()
 
+        if audio2audio_enable and ref_audio_input is not None:
+            task = "audio2audio"
+
        if not self.loaded:
            logger.warning("Checkpoint not loaded, loading checkpoint...")
            self.load_checkpoint(self.checkpoint_dir)
@@ -1053,6 +1093,14 @@
            assert os.path.exists(src_audio_path), f"src_audio_path {src_audio_path} does not exist"
            src_latents = self.infer_latents(src_audio_path)
 
+        ref_latents = None
+        if ref_audio_input is not None and audio2audio_enable:
+            assert ref_audio_input is not None, "ref_audio_input is required for audio2audio task"
+            assert os.path.exists(
+                ref_audio_input
+            ), f"ref_audio_input {ref_audio_input} does not exist"
+            ref_latents = self.infer_latents(ref_audio_input)
+
        if task == "edit":
            texts = [edit_target_prompt]
            target_encoder_text_hidden_states, target_text_attention_mask = self.get_text_embeddings(texts, self.device)
@@ -1117,6 +1165,9 @@
            repaint_start=repaint_start,
            repaint_end=repaint_end,
            src_latents=src_latents,
+            audio2audio_enable=audio2audio_enable,
+            ref_audio_strength=ref_audio_strength,
+            ref_latents=ref_latents,
        )
 
        end_time = time.time()
@@ -1169,6 +1220,9 @@
            "src_audio_path": src_audio_path,
            "edit_target_prompt": edit_target_prompt,
            "edit_target_lyrics": edit_target_lyrics,
+            "audio2audio_enable": audio2audio_enable,
+            "ref_audio_strength": ref_audio_strength,
+            "ref_audio_input": ref_audio_input,
        }
        # save input_params_json
        for output_audio_path in output_paths:
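Note on the new add_latents_noise step above: the slider value ref_audio_strength is mapped to a noise level variance = 1 - ref_audio_strength, the scheduler sigma closest to variance * num_train_timesteps is looked up, and noise is blended into the reference latents as sigma * noise + (1 - sigma) * ref_latents; the denoising loop then skips every timestep above the returned init_timestep, so a higher strength keeps more of the reference audio. Below is a minimal sketch of that arithmetic on dummy tensors; the latent shape is made up, and sigma is taken to equal the normalized noise level purely for illustration (the real value comes from scheduler.sigmas).

import torch

ref_audio_strength = 0.7                    # UI slider value
variance = 1.0 - ref_audio_strength         # fraction of the schedule that is re-noised
num_train_timesteps = 1000                  # assumed scheduler config value

ref_latents = torch.randn(1, 8, 16, 256)    # stand-in for DCAE-encoded reference audio
noise = torch.randn_like(ref_latents)       # stand-in for the freshly sampled target_latents

# illustrative sigma; add_latents_noise reads the nearest entry of scheduler.sigmas instead
sigma = torch.tensor(variance)
noisy_latents = sigma * noise + (1.0 - sigma) * ref_latents

# denoising starts from this partially noised latent and skips all t > init_timestep
init_timestep = int(variance * num_train_timesteps)    # 300 for strength 0.7
print(init_timestep, noisy_latents.shape)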
ui/components.py
CHANGED
@@ -71,6 +71,32 @@ def create_text2music_ui(
                 audio_duration = gr.Slider(-1, 240.0, step=0.00001, value=-1, label="Audio Duration", interactive=True, info="-1 means random duration (30 ~ 240).", scale=9)
                 sample_bnt = gr.Button("Sample", variant="primary", scale=1)
 
+            # audio2audio
+            audio2audio_enable = gr.Checkbox(label="Enable Audio2Audio", value=False, info="Check to enable Audio-to-Audio generation using a reference audio.", elem_id="audio2audio_checkbox")
+            ref_audio_input = gr.Audio(type="filepath", label="Reference Audio (for Audio2Audio)", visible=False, elem_id="ref_audio_input", show_download_button=True)
+            ref_audio_strength = gr.Slider(
+                label="Refer audio strength",
+                minimum=0.0,
+                maximum=1.0,
+                step=0.01,
+                value=0.5,
+                elem_id="ref_audio_strength",
+                visible=False,
+                interactive=True,
+            )
+
+            def toggle_ref_audio_visibility(is_checked):
+                return (
+                    gr.update(visible=is_checked, elem_id="ref_audio_input"),
+                    gr.update(visible=is_checked, elem_id="ref_audio_strength"),
+                )
+
+            audio2audio_enable.change(
+                fn=toggle_ref_audio_visibility,
+                inputs=[audio2audio_enable],
+                outputs=[ref_audio_input, ref_audio_strength],
+            )
+
             prompt = gr.Textbox(lines=2, label="Tags", max_lines=4, value=TAG_DEFAULT, info="Support tags, descriptions, and scene. Use commas to separate different tags.\ntags and lyrics examples are from ai music generation community")
             lyrics = gr.Textbox(lines=9, label="Lyrics", max_lines=13, value=LYRIC_DEFAULT, info="Support lyric structure tags like [verse], [chorus], and [bridge] to separate different parts of the lyrics.\nUse [instrumental] or [inst] to generate instrumental music. Not support genre structure tag in lyrics")
 
@@ -533,6 +559,21 @@ def create_text2music_ui(
            ", ".join(map(str, json_data["oss_steps"])),
            json_data["guidance_scale_text"] if "guidance_scale_text" in json_data else 0.0,
            json_data["guidance_scale_lyric"] if "guidance_scale_lyric" in json_data else 0.0,
+            (
+                json_data["audio2audio_enable"]
+                if "audio2audio_enable" in json_data
+                else False
+            ),
+            (
+                json_data["ref_audio_strength"]
+                if "ref_audio_strength" in json_data
+                else 0.5
+            ),
+            (
+                json_data["ref_audio_input"]
+                if "ref_audio_input" in json_data
+                else None
+            ),
        )
 
        sample_bnt.click(
@@ -556,6 +597,9 @@ def create_text2music_ui(
            oss_steps,
            guidance_scale_text,
            guidance_scale_lyric,
+            audio2audio_enable,
+            ref_audio_strength,
+            ref_audio_input,
        ],
    )
 
@@ -580,6 +624,9 @@ def create_text2music_ui(
            oss_steps,
            guidance_scale_text,
            guidance_scale_lyric,
+            audio2audio_enable,
+            ref_audio_strength,
+            ref_audio_input,
        ], outputs=outputs + [input_params_json]
    )
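The three new controls are appended to the existing inputs lists above, so the UI callbacks receive audio2audio_enable, ref_audio_strength, and ref_audio_input alongside the old arguments and can forward them to the new keyword parameters of ACEStepPipeline.__call__. For reference, a hedged sketch of calling the pipeline directly in audio-to-audio mode with this commit applied; the import path, the constructor argument, and the prompt/lyrics parameter names are assumptions not shown in this diff.

from pipeline_ace_step import ACEStepPipeline   # assumed import path

# checkpoint_dir is an assumed constructor argument; adjust to the project's actual setup
pipeline = ACEStepPipeline(checkpoint_dir="./checkpoints")

outputs = pipeline(
    audio_duration=60.0,
    prompt="lofi, chill, mellow guitar",    # assumed name of the tags parameter
    lyrics="[inst]",                        # assumed name of the lyrics parameter
    audio2audio_enable=True,                # with a reference path set, task becomes "audio2audio"
    ref_audio_input="reference.wav",        # path to the reference audio file
    ref_audio_strength=0.5,                 # 1.0 stays close to the reference, 0.0 re-noises it fully
)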