Gong Junmin committed on
Commit b0e2210 · 1 Parent(s): 5c24327

support audio2audio

Files changed (2)
  1. pipeline_ace_step.py +57 -3
  2. ui/components.py +47 -0
pipeline_ace_step.py CHANGED

@@ -2,7 +2,7 @@ import random
 import time
 import os
 import re
-import spaces
+# import spaces
 import torch
 import torch.nn as nn
 from loguru import logger
@@ -522,6 +522,27 @@ class ACEStepPipeline:
         target_latents = zt_edit if xt_tar is None else xt_tar
         return target_latents
 
+    def add_latents_noise(
+        self,
+        gt_latents,
+        variance,
+        noise,
+        scheduler,
+    ):
+
+        bsz = gt_latents.shape[0]
+        u = torch.tensor([variance] * bsz, dtype=gt_latents.dtype)
+        indices = (u * scheduler.config.num_train_timesteps).long()
+        timesteps = scheduler.timesteps.unsqueeze(1).to(gt_latents.dtype)
+        indices = indices.to(timesteps.device).to(gt_latents.dtype).unsqueeze(1)
+        nearest_idx = torch.argmin(torch.cdist(indices, timesteps), dim=1)
+        sigma = scheduler.sigmas[nearest_idx].flatten().to(gt_latents.device).to(gt_latents.dtype)
+        while len(sigma.shape) < gt_latents.ndim:
+            sigma = sigma.unsqueeze(-1)
+        noisy_image = sigma * noise + (1.0 - sigma) * gt_latents
+        init_timestep = indices[0]
+        return noisy_image, init_timestep
+
     @torch.no_grad()
     def text2music_diffusion_process(
         self,
@@ -554,6 +575,9 @@ class ACEStepPipeline:
         repaint_start=0,
         repaint_end=0,
         src_latents=None,
+        audio2audio_enable=False,
+        ref_audio_strength=0.5,
+        ref_latents=None,
     ):
 
         logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
@@ -585,6 +609,9 @@ class ACEStepPipeline:
         if src_latents is not None:
             frame_length = src_latents.shape[-1]
 
+        if ref_latents is not None:
+            frame_length = ref_latents.shape[-1]
+
         if len(oss_steps) > 0:
             infer_steps = max(oss_steps)
             scheduler.set_timesteps
@@ -680,6 +707,10 @@ class ACEStepPipeline:
             zt_edit = x0.clone()
             z0 = target_latents
 
+        init_timestep = 1000
+        if audio2audio_enable and ref_latents is not None:
+            target_latents, init_timestep = self.add_latents_noise(gt_latents=ref_latents, variance=(1-ref_audio_strength), noise=target_latents, scheduler=scheduler)
+
         attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
 
         # guidance interval
@@ -783,7 +814,10 @@ class ACEStepPipeline:
             return sample
 
         for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
-
+
+            if t > init_timestep:
+                continue
+
             if is_repaint:
                 if i < n_min:
                     continue
@@ -955,7 +989,7 @@ class ACEStepPipeline:
         latents, _ = self.music_dcae.encode(input_audio, sr=sr)
         return latents
 
-    @spaces.GPU
+    # @spaces.GPU
     def __call__(
         self,
         audio_duration: float = 60.0,
@@ -976,6 +1010,9 @@ class ACEStepPipeline:
         oss_steps: str = None,
         guidance_scale_text: float = 0.0,
         guidance_scale_lyric: float = 0.0,
+        audio2audio_enable: bool = False,
+        ref_audio_strength: float = 0.5,
+        ref_audio_input: str = None,
         retake_seeds: list = None,
         retake_variance: float = 0.5,
         task: str = "text2music",
@@ -995,6 +1032,9 @@ class ACEStepPipeline:
 
        start_time = time.time()
 
+        if audio2audio_enable and ref_audio_input is not None:
+            task = "audio2audio"
+
        if not self.loaded:
            logger.warning("Checkpoint not loaded, loading checkpoint...")
            self.load_checkpoint(self.checkpoint_dir)
@@ -1053,6 +1093,14 @@ class ACEStepPipeline:
            assert os.path.exists(src_audio_path), f"src_audio_path {src_audio_path} does not exist"
            src_latents = self.infer_latents(src_audio_path)
 
+        ref_latents = None
+        if ref_audio_input is not None and audio2audio_enable:
+            assert ref_audio_input is not None, "ref_audio_input is required for audio2audio task"
+            assert os.path.exists(
+                ref_audio_input
+            ), f"ref_audio_input {ref_audio_input} does not exist"
+            ref_latents = self.infer_latents(ref_audio_input)
+
        if task == "edit":
            texts = [edit_target_prompt]
            target_encoder_text_hidden_states, target_text_attention_mask = self.get_text_embeddings(texts, self.device)
@@ -1117,6 +1165,9 @@ class ACEStepPipeline:
            repaint_start=repaint_start,
            repaint_end=repaint_end,
            src_latents=src_latents,
+            audio2audio_enable=audio2audio_enable,
+            ref_audio_strength=ref_audio_strength,
+            ref_latents=ref_latents,
        )
 
        end_time = time.time()
@@ -1169,6 +1220,9 @@ class ACEStepPipeline:
            "src_audio_path": src_audio_path,
            "edit_target_prompt": edit_target_prompt,
            "edit_target_lyrics": edit_target_lyrics,
+            "audio2audio_enable": audio2audio_enable,
+            "ref_audio_strength": ref_audio_strength,
+            "ref_audio_input": ref_audio_input,
        }
        # save input_params_json
        for output_audio_path in output_paths:
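Note: the new add_latents_noise helper is an SDEdit-style partial-noising step. The reference latents are blended with noise at the sigma corresponding to variance = 1 - ref_audio_strength, and the denoising loop then skips every timestep above the returned init_timestep, so generation resumes part-way through the schedule instead of starting from pure noise. The sketch below illustrates the idea in isolation; it assumes a diffusers-style flow-matching scheduler (FlowMatchEulerDiscreteScheduler) and an illustrative latent shape, not the pipeline's actual scheduler or latent layout.

import torch
from diffusers import FlowMatchEulerDiscreteScheduler

# Stand-ins for the real objects (scheduler config and shapes are assumptions).
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=60)

ref_latents = torch.randn(1, 8, 16, 512)   # encoded reference audio (shape is illustrative)
noise = torch.randn_like(ref_latents)      # latents that would otherwise start the loop
ref_audio_strength = 0.5                   # 1.0 -> stay very close to the reference, 0.0 -> ignore it

# Map (1 - strength) onto the training timestep range and snap to the nearest
# inference timestep, mirroring the cdist/argmin lookup in add_latents_noise.
variance = 1.0 - ref_audio_strength
target = torch.tensor(variance * scheduler.config.num_train_timesteps)
nearest = torch.argmin((scheduler.timesteps - target).abs())
sigma = scheduler.sigmas[nearest]

# Flow-matching interpolation between data and noise: x_t = sigma * noise + (1 - sigma) * x_0.
noisy_latents = sigma * noise + (1.0 - sigma) * ref_latents
init_timestep = target  # the loop skips every t > init_timestep and starts from noisy_latents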
ui/components.py CHANGED

@@ -71,6 +71,32 @@ def create_text2music_ui(
         audio_duration = gr.Slider(-1, 240.0, step=0.00001, value=-1, label="Audio Duration", interactive=True, info="-1 means random duration (30 ~ 240).", scale=9)
         sample_bnt = gr.Button("Sample", variant="primary", scale=1)
 
+        # audio2audio
+        audio2audio_enable = gr.Checkbox(label="Enable Audio2Audio", value=False, info="Check to enable Audio-to-Audio generation using a reference audio.", elem_id="audio2audio_checkbox")
+        ref_audio_input = gr.Audio(type="filepath", label="Reference Audio (for Audio2Audio)", visible=False, elem_id="ref_audio_input", show_download_button=True)
+        ref_audio_strength = gr.Slider(
+            label="Refer audio strength",
+            minimum=0.0,
+            maximum=1.0,
+            step=0.01,
+            value=0.5,
+            elem_id="ref_audio_strength",
+            visible=False,
+            interactive=True,
+        )
+
+        def toggle_ref_audio_visibility(is_checked):
+            return (
+                gr.update(visible=is_checked, elem_id="ref_audio_input"),
+                gr.update(visible=is_checked, elem_id="ref_audio_strength"),
+            )
+
+        audio2audio_enable.change(
+            fn=toggle_ref_audio_visibility,
+            inputs=[audio2audio_enable],
+            outputs=[ref_audio_input, ref_audio_strength],
+        )
+
         prompt = gr.Textbox(lines=2, label="Tags", max_lines=4, value=TAG_DEFAULT, info="Support tags, descriptions, and scene. Use commas to separate different tags.\ntags and lyrics examples are from ai music generation community")
         lyrics = gr.Textbox(lines=9, label="Lyrics", max_lines=13, value=LYRIC_DEFAULT, info="Support lyric structure tags like [verse], [chorus], and [bridge] to separate different parts of the lyrics.\nUse [instrumental] or [inst] to generate instrumental music. Not support genre structure tag in lyrics")
 
@@ -533,6 +559,21 @@ def create_text2music_ui(
             ", ".join(map(str, json_data["oss_steps"])),
             json_data["guidance_scale_text"] if "guidance_scale_text" in json_data else 0.0,
             json_data["guidance_scale_lyric"] if "guidance_scale_lyric" in json_data else 0.0,
+            (
+                json_data["audio2audio_enable"]
+                if "audio2audio_enable" in json_data
+                else False
+            ),
+            (
+                json_data["ref_audio_strength"]
+                if "ref_audio_strength" in json_data
+                else 0.5
+            ),
+            (
+                json_data["ref_audio_input"]
+                if "ref_audio_input" in json_data
+                else None
+            ),
         )
 
         sample_bnt.click(
@@ -556,6 +597,9 @@ def create_text2music_ui(
                 oss_steps,
                 guidance_scale_text,
                 guidance_scale_lyric,
+                audio2audio_enable,
+                ref_audio_strength,
+                ref_audio_input,
             ],
         )
 
@@ -580,6 +624,9 @@ def create_text2music_ui(
                 oss_steps,
                 guidance_scale_text,
                 guidance_scale_lyric,
+                audio2audio_enable,
+                ref_audio_strength,
+                ref_audio_input,
             ], outputs=outputs + [input_params_json]
         )
 
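Note: a hypothetical usage sketch showing how the three new parameters might be passed when calling the pipeline directly rather than through the Gradio UI. Only audio2audio_enable, ref_audio_strength, and ref_audio_input are confirmed by this diff; the import path, constructor argument, and the prompt/lyrics parameters are assumptions for illustration.

from pipeline_ace_step import ACEStepPipeline  # module path assumed from the file listing above

pipeline = ACEStepPipeline(checkpoint_dir="./checkpoints")  # constructor signature is an assumption

outputs = pipeline(
    audio_duration=60.0,
    prompt="lofi, chill, mellow guitar",   # assumed pre-existing parameter
    lyrics="[inst]",                       # assumed pre-existing parameter
    audio2audio_enable=True,               # new in this commit; also switches task to "audio2audio"
    ref_audio_strength=0.5,                # higher values keep the output closer to the reference
    ref_audio_input="./reference.wav",     # reference audio is encoded via infer_latents()
)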