Commit b0e2210 · Gong Junmin committed · 1 parent: 5c24327
support audio2audio

Files changed:
- pipeline_ace_step.py (+57 -3)
- ui/components.py (+47 -0)
pipeline_ace_step.py CHANGED

@@ -2,7 +2,7 @@ import random
 import time
 import os
 import re
-import spaces
+# import spaces
 import torch
 import torch.nn as nn
 from loguru import logger

@@ -522,6 +522,27 @@ class ACEStepPipeline:
         target_latents = zt_edit if xt_tar is None else xt_tar
         return target_latents
 
+    def add_latents_noise(
+        self,
+        gt_latents,
+        variance,
+        noise,
+        scheduler,
+    ):
+
+        bsz = gt_latents.shape[0]
+        u = torch.tensor([variance] * bsz, dtype=gt_latents.dtype)
+        indices = (u * scheduler.config.num_train_timesteps).long()
+        timesteps = scheduler.timesteps.unsqueeze(1).to(gt_latents.dtype)
+        indices = indices.to(timesteps.device).to(gt_latents.dtype).unsqueeze(1)
+        nearest_idx = torch.argmin(torch.cdist(indices, timesteps), dim=1)
+        sigma = scheduler.sigmas[nearest_idx].flatten().to(gt_latents.device).to(gt_latents.dtype)
+        while len(sigma.shape) < gt_latents.ndim:
+            sigma = sigma.unsqueeze(-1)
+        noisy_image = sigma * noise + (1.0 - sigma) * gt_latents
+        init_timestep = indices[0]
+        return noisy_image, init_timestep
+
     @torch.no_grad()
     def text2music_diffusion_process(
         self,

@@ -554,6 +575,9 @@ class ACEStepPipeline:
         repaint_start=0,
         repaint_end=0,
         src_latents=None,
+        audio2audio_enable=False,
+        ref_audio_strength=0.5,
+        ref_latents=None,
     ):
 
         logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))

@@ -585,6 +609,9 @@ class ACEStepPipeline:
         if src_latents is not None:
             frame_length = src_latents.shape[-1]
 
+        if ref_latents is not None:
+            frame_length = ref_latents.shape[-1]
+
         if len(oss_steps) > 0:
             infer_steps = max(oss_steps)
             scheduler.set_timesteps

@@ -680,6 +707,10 @@ class ACEStepPipeline:
         zt_edit = x0.clone()
         z0 = target_latents
 
+        init_timestep = 1000
+        if audio2audio_enable and ref_latents is not None:
+            target_latents, init_timestep = self.add_latents_noise(gt_latents=ref_latents, variance=(1-ref_audio_strength), noise=target_latents, scheduler=scheduler)
+
         attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
 
         # guidance interval

@@ -783,7 +814,10 @@ class ACEStepPipeline:
             return sample
 
         for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
-
+
+            if t > init_timestep:
+                continue
+
             if is_repaint:
                 if i < n_min:
                     continue

@@ -955,7 +989,7 @@ class ACEStepPipeline:
         latents, _ = self.music_dcae.encode(input_audio, sr=sr)
         return latents
 
-    @spaces.GPU
+    # @spaces.GPU
     def __call__(
         self,
         audio_duration: float = 60.0,

@@ -976,6 +1010,9 @@ class ACEStepPipeline:
         oss_steps: str = None,
         guidance_scale_text: float = 0.0,
         guidance_scale_lyric: float = 0.0,
+        audio2audio_enable: bool = False,
+        ref_audio_strength: float = 0.5,
+        ref_audio_input: str = None,
         retake_seeds: list = None,
         retake_variance: float = 0.5,
         task: str = "text2music",

@@ -995,6 +1032,9 @@ class ACEStepPipeline:
 
         start_time = time.time()
 
+        if audio2audio_enable and ref_audio_input is not None:
+            task = "audio2audio"
+
         if not self.loaded:
             logger.warning("Checkpoint not loaded, loading checkpoint...")
             self.load_checkpoint(self.checkpoint_dir)

@@ -1053,6 +1093,14 @@ class ACEStepPipeline:
             assert os.path.exists(src_audio_path), f"src_audio_path {src_audio_path} does not exist"
             src_latents = self.infer_latents(src_audio_path)
 
+        ref_latents = None
+        if ref_audio_input is not None and audio2audio_enable:
+            assert ref_audio_input is not None, "ref_audio_input is required for audio2audio task"
+            assert os.path.exists(
+                ref_audio_input
+            ), f"ref_audio_input {ref_audio_input} does not exist"
+            ref_latents = self.infer_latents(ref_audio_input)
+
         if task == "edit":
             texts = [edit_target_prompt]
             target_encoder_text_hidden_states, target_text_attention_mask = self.get_text_embeddings(texts, self.device)

@@ -1117,6 +1165,9 @@ class ACEStepPipeline:
             repaint_start=repaint_start,
             repaint_end=repaint_end,
             src_latents=src_latents,
+            audio2audio_enable=audio2audio_enable,
+            ref_audio_strength=ref_audio_strength,
+            ref_latents=ref_latents,
         )
 
         end_time = time.time()

@@ -1169,6 +1220,9 @@ class ACEStepPipeline:
             "src_audio_path": src_audio_path,
             "edit_target_prompt": edit_target_prompt,
             "edit_target_lyrics": edit_target_lyrics,
+            "audio2audio_enable": audio2audio_enable,
+            "ref_audio_strength": ref_audio_strength,
+            "ref_audio_input": ref_audio_input,
         }
         # save input_params_json
         for output_audio_path in output_paths:
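In this commit, audio2audio works by encoding the reference audio to latents (via infer_latents) and blending them with the initial noise in add_latents_noise: noisy_image = sigma * noise + (1 - sigma) * gt_latents, with the mixing variance set to 1 - ref_audio_strength. The returned init_timestep gates the sampling loop (steps with t > init_timestep are skipped), so a higher ref_audio_strength starts denoising later and preserves more of the reference. A minimal invocation sketch based on the new __call__ parameters; the constructor arguments and the return value handling are assumptions, not shown in this diff, and the remaining keyword arguments are left at their defaults for brevity:

# Hypothetical usage sketch for the new audio2audio path in this commit.
# ACEStepPipeline's constructor signature and return value are assumed here,
# not part of the diff above.
from pipeline_ace_step import ACEStepPipeline

pipeline = ACEStepPipeline(checkpoint_dir="./checkpoints")  # assumed constructor

outputs = pipeline(
    audio_duration=60.0,
    audio2audio_enable=True,           # switches task to "audio2audio" internally
    ref_audio_input="reference.wav",   # must exist on disk; encoded via infer_latents
    ref_audio_strength=0.7,            # closer to 1.0 keeps more of the reference audio
)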
ui/components.py CHANGED

@@ -71,6 +71,32 @@ def create_text2music_ui(
         audio_duration = gr.Slider(-1, 240.0, step=0.00001, value=-1, label="Audio Duration", interactive=True, info="-1 means random duration (30 ~ 240).", scale=9)
         sample_bnt = gr.Button("Sample", variant="primary", scale=1)
 
+        # audio2audio
+        audio2audio_enable = gr.Checkbox(label="Enable Audio2Audio", value=False, info="Check to enable Audio-to-Audio generation using a reference audio.", elem_id="audio2audio_checkbox")
+        ref_audio_input = gr.Audio(type="filepath", label="Reference Audio (for Audio2Audio)", visible=False, elem_id="ref_audio_input", show_download_button=True)
+        ref_audio_strength = gr.Slider(
+            label="Refer audio strength",
+            minimum=0.0,
+            maximum=1.0,
+            step=0.01,
+            value=0.5,
+            elem_id="ref_audio_strength",
+            visible=False,
+            interactive=True,
+        )
+
+        def toggle_ref_audio_visibility(is_checked):
+            return (
+                gr.update(visible=is_checked, elem_id="ref_audio_input"),
+                gr.update(visible=is_checked, elem_id="ref_audio_strength"),
+            )
+
+        audio2audio_enable.change(
+            fn=toggle_ref_audio_visibility,
+            inputs=[audio2audio_enable],
+            outputs=[ref_audio_input, ref_audio_strength],
+        )
+
         prompt = gr.Textbox(lines=2, label="Tags", max_lines=4, value=TAG_DEFAULT, info="Support tags, descriptions, and scene. Use commas to separate different tags.\ntags and lyrics examples are from ai music generation community")
         lyrics = gr.Textbox(lines=9, label="Lyrics", max_lines=13, value=LYRIC_DEFAULT, info="Support lyric structure tags like [verse], [chorus], and [bridge] to separate different parts of the lyrics.\nUse [instrumental] or [inst] to generate instrumental music. Not support genre structure tag in lyrics")
 

@@ -533,6 +559,21 @@ def create_text2music_ui(
         ", ".join(map(str, json_data["oss_steps"])),
         json_data["guidance_scale_text"] if "guidance_scale_text" in json_data else 0.0,
         json_data["guidance_scale_lyric"] if "guidance_scale_lyric" in json_data else 0.0,
+        (
+            json_data["audio2audio_enable"]
+            if "audio2audio_enable" in json_data
+            else False
+        ),
+        (
+            json_data["ref_audio_strength"]
+            if "ref_audio_strength" in json_data
+            else 0.5
+        ),
+        (
+            json_data["ref_audio_input"]
+            if "ref_audio_input" in json_data
+            else None
+        ),
     )
 
     sample_bnt.click(

@@ -556,6 +597,9 @@ def create_text2music_ui(
            oss_steps,
            guidance_scale_text,
            guidance_scale_lyric,
+           audio2audio_enable,
+           ref_audio_strength,
+           ref_audio_input,
         ],
     )
 

@@ -580,6 +624,9 @@ def create_text2music_ui(
            oss_steps,
            guidance_scale_text,
            guidance_scale_lyric,
+           audio2audio_enable,
+           ref_audio_strength,
+           ref_audio_input,
         ], outputs=outputs + [input_params_json]
     )
 
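On the UI side, the checkbox only toggles the visibility of the reference-audio components, and the three new inputs are appended to the existing click handlers and to the values restored from input_params_json. A self-contained sketch of that toggle pattern, with assumed standalone component definitions rather than the ones inside create_text2music_ui:

# Standalone sketch of the visibility-toggle pattern used in this commit.
# Illustrative only; the real components are created inside create_text2music_ui.
import gradio as gr

with gr.Blocks() as demo:
    audio2audio_enable = gr.Checkbox(label="Enable Audio2Audio", value=False)
    ref_audio_input = gr.Audio(type="filepath", label="Reference Audio", visible=False)
    ref_audio_strength = gr.Slider(0.0, 1.0, value=0.5, step=0.01,
                                   label="Refer audio strength", visible=False)

    def toggle_ref_audio_visibility(is_checked):
        # Show the reference-audio controls only when audio2audio is enabled.
        return gr.update(visible=is_checked), gr.update(visible=is_checked)

    audio2audio_enable.change(
        fn=toggle_ref_audio_visibility,
        inputs=[audio2audio_enable],
        outputs=[ref_audio_input, ref_audio_strength],
    )

demo.launch()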