import os
import sys
import time
import tqdm
import torch
import logging
import librosa
import argparse
import scipy.signal
import logging.handlers

import numpy as np
import soundfile as sf

from torch import inference_mode
from distutils.util import strtobool

sys.path.append(os.getcwd())

from main.configs.config import Config
from main.library.audioldm2.utils import load_audio
from main.library.audioldm2.models import load_model
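
# Project-wide configuration (device, half-precision flag, UI translations) and module logging:
# noisy third-party loggers are silenced, and this module logs to the console and a rotating file
# (handlers are only attached when none exist yet).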
config = Config()
translations = config.translations
logger = logging.getLogger(__name__)
logger.propagate = False

for l in ["torch", "httpx", "httpcore", "diffusers", "transformers"]:
    logging.getLogger(l).setLevel(logging.ERROR)

if logger.hasHandlers(): logger.handlers.clear()
else:
    console_handler = logging.StreamHandler()
    console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(logging.INFO)
    file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "audioldm2.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
    file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    file_handler.setFormatter(file_formatter)
    file_handler.setLevel(logging.DEBUG)
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    logger.setLevel(logging.DEBUG)
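
# CLI arguments: audio input/output paths, export format and sample rate, AudioLDM2 model name,
# source and target prompts, diffusion step count, source/target guidance scales, the skip point
# t_start (as a percentage of steps), and whether to batch the conditional/unconditional U-Net
# passes (--save_compute).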
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, default="./output.wav")
    parser.add_argument("--export_format", type=str, default="wav")
    parser.add_argument("--sample_rate", type=int, default=44100)
    parser.add_argument("--audioldm_model", type=str, default="audioldm2-music")
    parser.add_argument("--source_prompt", type=str, default="")
    parser.add_argument("--target_prompt", type=str, default="")
    parser.add_argument("--steps", type=int, default=200)
    parser.add_argument("--cfg_scale_src", type=float, default=3.5)
    parser.add_argument("--cfg_scale_tar", type=float, default=12)
    parser.add_argument("--t_start", type=int, default=45)
    parser.add_argument("--save_compute", type=lambda x: bool(strtobool(x)), default=False)

    return parser.parse_args()
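
# Entry point: logs the run parameters, writes this process's PID to assets/audioldm2_pid.txt
# (presumably so a caller can stop the job), runs the edit, and reports the elapsed time.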
def main():
    args = parse_arguments()
    input_path, output_path, export_format, sample_rate, audioldm_model, source_prompt, target_prompt, steps, cfg_scale_src, cfg_scale_tar, t_start, save_compute = args.input_path, args.output_path, args.export_format, args.sample_rate, args.audioldm_model, args.source_prompt, args.target_prompt, args.steps, args.cfg_scale_src, args.cfg_scale_tar, args.t_start, args.save_compute

    log_data = {translations['audio_path']: input_path, translations['output_path']: output_path.replace('wav', export_format), translations['model_name']: audioldm_model, translations['export_format']: export_format, translations['sample_rate']: sample_rate, translations['steps']: steps, translations['source_prompt']: source_prompt, translations['target_prompt']: target_prompt, translations['cfg_scale_src']: cfg_scale_src, translations['cfg_scale_tar']: cfg_scale_tar, translations['t_start']: t_start, translations['save_compute']: save_compute}

    for key, value in log_data.items():
        logger.debug(f"{key}: {value}")

    start_time = time.time()
    logger.info(translations["start_edit"].format(input_path=input_path))
    pid_path = os.path.join("assets", "audioldm2_pid.txt")
    with open(pid_path, "w") as pid_file:
        pid_file.write(str(os.getpid()))

    try:
        edit(input_path, output_path, audioldm_model, source_prompt, target_prompt, steps, cfg_scale_src, cfg_scale_tar, t_start, save_compute, sample_rate, config.device, export_format=export_format)
    except Exception as e:
        logger.error(translations["error_edit"].format(e=e))
        import traceback
        logger.debug(traceback.format_exc())

    logger.info(translations["edit_success"].format(time=f"{(time.time() - start_time):.2f}", output_path=output_path.replace('wav', export_format)))
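
# Encode the source audio into the VAE latent space and run the forward inversion pass, returning
# the per-step noise maps (zs), the intermediate latents (wts), and any extra per-step state.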
def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src, duration, save_compute):
    with inference_mode():
        w0 = ldm_stable.vae_encode(x0)

    _, zs, wts, extra_info = inversion_forward_process(ldm_stable, w0, etas=1, prompts=[prompt_src], cfg_scales=[cfg_scale_src], num_inference_steps=num_diffusion_steps, numerical_fix=True, duration=duration, save_compute=save_compute)
    return zs, wts, extra_info
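
# Zero-phase (filtfilt) 4th-order Butterworth low-pass filter used to tame high-frequency artifacts.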
def low_pass_filter(audio, cutoff=7500, sr=16000):
    b, a = scipy.signal.butter(4, cutoff / (sr / 2), btype='low')
    return scipy.signal.filtfilt(b, a, audio)
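
# Reverse pass with the target prompt, then decode: VAE latent -> mel spectrogram -> waveform at the
# model's native 16 kHz, optional resampling to the requested rate, low-pass filtering, and writing a
# two-channel (duplicated mono) file in the requested export format.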
def sample(output_audio, sr, ldm_stable, zs, wts, extra_info, prompt_tar, tstart, cfg_scale_tar, duration, save_compute, export_format = "wav"):
    tstart = torch.tensor(tstart, dtype=torch.int32)
    w0, _ = inversion_reverse_process(ldm_stable, xT=wts, tstart=tstart, etas=1., prompts=[prompt_tar], neg_prompts=[""], cfg_scales=[cfg_scale_tar], zs=zs[:int(tstart)], duration=duration, extra_info=extra_info, save_compute=save_compute)

    with inference_mode():
        x0_dec = ldm_stable.vae_decode(w0.to(torch.float16 if config.is_half else torch.float32))

    if x0_dec.dim() < 4: x0_dec = x0_dec[None, :, :, :]

    with torch.no_grad():
        audio = ldm_stable.decode_to_mel(x0_dec.to(torch.float16 if config.is_half else torch.float32))

    audio = audio.float().squeeze().cpu().numpy()
    orig_sr = 16000

    if sr != 16000 and sr > 0:
        audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr, res_type="soxr_vhq")
        orig_sr = sr

    audio = low_pass_filter(audio, 7500, orig_sr)

    sf.write(output_audio, np.tile(audio, (2, 1)).T, orig_sr, format=export_format)
    return output_audio
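
# High-level pipeline: load the AudioLDM2 model, invert the input audio under the source prompt,
# then regenerate it under the target prompt starting from step int(t_start / 100 * steps).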
def edit(input_audio, output_audio, model_id, source_prompt = "", target_prompt = "", steps = 200, cfg_scale_src = 3.5, cfg_scale_tar = 12, t_start = 45, save_compute = True, sr = 44100, device = "cpu", export_format = "wav"):
    ldm_stable = load_model(model_id, device=device)
    ldm_stable.model.scheduler.set_timesteps(steps, device=device)
    x0, duration = load_audio(input_audio, ldm_stable.get_melspectrogram(), device=device)
    zs_tensor, wts_tensor, extra_info_list = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src, duration=duration, save_compute=save_compute)

    return sample(output_audio, sr, ldm_stable, zs_tensor, wts_tensor, extra_info_list, prompt_tar=target_prompt, tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar, duration=duration, save_compute=save_compute, export_format=export_format)
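
# Forward (inversion) process: walks the scheduler timesteps over the clean latent, predicts noise with
# the U-Net (classifier-free guided when a source prompt is given) and extracts the per-step noise maps
# zs that let the reverse process retrace the same latent trajectory xts.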
def inversion_forward_process(model, x0, etas = None, prompts = [""], cfg_scales = [3.5], num_inference_steps = 50, numerical_fix = False, duration = None, first_order = False, save_compute = True):
    if len(prompts) > 1 or prompts[0] != "":
        text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
        uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text([""], negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1] if text_embeddings_class_labels is not None else None)
    else: uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text([""], negative=True, save_compute=False)

    timesteps = model.model.scheduler.timesteps.to(model.device)
    variance_noise_shape = model.get_noise_shape(x0, num_inference_steps)

    if type(etas) in [int, float]: etas = [etas]*model.model.scheduler.num_inference_steps

    xts = model.sample_xts_from_x0(x0, num_inference_steps=num_inference_steps)
    zs = torch.zeros(size=variance_noise_shape, device=model.device)
    extra_info = [None] * len(zs)

    if timesteps[0].dtype == torch.int64: t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
    elif timesteps[0].dtype == torch.float32: t_to_idx = {float(v): k for k, v in enumerate(timesteps)}

    xt = x0
    model.setup_extra_inputs(xt, init_timestep=timesteps[0], audio_end_in_s=duration, save_compute=save_compute and prompts[0] != "")

    for t in tqdm.tqdm(timesteps, desc=translations["inverting"], ncols=100, unit="a"):
        idx = num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - 1
        xt = xts[idx + 1][None]
        xt_inp = model.model.scheduler.scale_model_input(xt, t).to(torch.float16 if config.is_half else torch.float32)

        with torch.no_grad():
            if save_compute and prompts[0] != "":
                # Single batched U-Net pass: unconditional and conditional branches stacked along the batch dim.
                comb_out, _, _ = model.unet_forward(xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1), timestep=t, encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0) if uncond_embeddings_hidden_states is not None else None, class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0) if uncond_embeddings_class_lables is not None else None, encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0) if uncond_boolean_prompt_mask is not None else None)
                out, cond_out = comb_out.sample.chunk(2, dim=0)
            else:
                out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=uncond_embeddings_hidden_states, class_labels=uncond_embeddings_class_lables, encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
                if len(prompts) > 1 or prompts[0] != "": cond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=text_embeddings_hidden_states, class_labels=text_embeddings_class_labels, encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample

        # Classifier-free guidance on the noise prediction when a source prompt is provided.
        if len(prompts) > 1 or prompts[0] != "": noise_pred = out + (cfg_scales[0] * (cond_out - out)).sum(axis=0).unsqueeze(0)
        else: noise_pred = out

        xtm1 = xts[idx][None]
        z, xtm1, extra = model.get_zs_from_xts(xt, xtm1, noise_pred, t, eta=etas[idx], numerical_fix=numerical_fix, first_order=first_order)
        zs[idx] = z
        xts[idx] = xtm1
        extra_info[idx] = extra

    if zs is not None: zs[0] = torch.zeros_like(zs[0])
    return xt, zs, xts, extra_info
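
# Reverse (editing) process: starting from the saved latent at step tstart, denoises with the target
# prompt under classifier-free guidance while re-injecting the stored noise maps zs at each step.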
def inversion_reverse_process(model, xT, tstart, etas = 0, prompts = [""], neg_prompts = [""], cfg_scales = None, zs = None, duration = None, first_order = False, extra_info = None, save_compute = True):
    text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
    uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text(neg_prompts, negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1] if text_embeddings_class_labels is not None else None)
    xt = xT[tstart.max()].unsqueeze(0)

    if etas is None: etas = 0
    if type(etas) in [int, float]: etas = [etas]*model.model.scheduler.num_inference_steps

    assert len(etas) == model.model.scheduler.num_inference_steps
    timesteps = model.model.scheduler.timesteps.to(model.device)

    if timesteps[0].dtype == torch.int64: t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
    elif timesteps[0].dtype == torch.float32: t_to_idx = {float(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}

    model.setup_extra_inputs(xt, extra_info=extra_info, init_timestep=timesteps[-zs.shape[0]], audio_end_in_s=duration, save_compute=save_compute)

    for t in tqdm.tqdm(timesteps[-zs.shape[0]:], desc=translations["editing"], ncols=100, unit="a"):
        idx = model.model.scheduler.num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - (model.model.scheduler.num_inference_steps - zs.shape[0] + 1)
        xt_inp = model.model.scheduler.scale_model_input(xt, t).to(torch.float16 if config.is_half else torch.float32)

        with torch.no_grad():
            if save_compute:
                # Single batched U-Net pass: negative-prompt and target-prompt branches stacked along the batch dim.
                comb_out, _, _ = model.unet_forward(xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1), timestep=t, encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0) if uncond_embeddings_hidden_states is not None else None, class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0) if uncond_embeddings_class_lables is not None else None, encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0) if uncond_boolean_prompt_mask is not None else None)
                uncond_out, cond_out = comb_out.sample.chunk(2, dim=0)
            else:
                uncond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=uncond_embeddings_hidden_states, class_labels=uncond_embeddings_class_lables, encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
                cond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=text_embeddings_hidden_states, class_labels=text_embeddings_class_labels, encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample

        # Classifier-free guidance, then one reverse step that re-injects the stored noise map for this index.
        z = zs[idx] if zs is not None else None
        noise_pred = uncond_out + (cfg_scales[0] * (cond_out - uncond_out)).sum(axis=0).unsqueeze(0)
        xt = model.reverse_step_with_custom_noise(noise_pred, t, xt, variance_noise=z.unsqueeze(0), eta=etas[idx], first_order=first_order)

    return xt, zs
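
# Example invocation (hypothetical paths; adjust the script path to where this file lives in the repo):
#   python audioldm2.py --input_path song.wav --output_path edited.wav --target_prompt "a jazz piano cover" --steps 150 --t_start 45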

if __name__ == "__main__": main()