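"""Command-line script for prompt-guided audio editing with AudioLDM2.

The input audio is encoded to the latent space, inverted through the diffusion
model under a source prompt, then regenerated under a target prompt and written
back to disk in the requested format and sample rate.
"""
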
import os
import sys
import time
import tqdm
import torch
import logging
import librosa
import argparse
import scipy.signal
import logging.handlers

import numpy as np
import soundfile as sf

from torch import inference_mode
from distutils.util import strtobool

sys.path.append(os.getcwd())

from main.configs.config import Config
from main.library.audioldm2.utils import load_audio
from main.library.audioldm2.models import load_model

config = Config()
translations = config.translations
logger = logging.getLogger(__name__)
logger.propagate = False

# Silence noisy third-party loggers
for l in ["torch", "httpx", "httpcore", "diffusers", "transformers"]:
    logging.getLogger(l).setLevel(logging.ERROR)

if logger.hasHandlers(): logger.handlers.clear()
else:
    console_handler = logging.StreamHandler()
    console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(logging.INFO)

    file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "audioldm2.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
    file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    file_handler.setFormatter(file_formatter)
    file_handler.setLevel(logging.DEBUG)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    logger.setLevel(logging.DEBUG)

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, default="./output.wav")
    parser.add_argument("--export_format", type=str, default="wav")
    parser.add_argument("--sample_rate", type=int, default=44100)
    parser.add_argument("--audioldm_model", type=str, default="audioldm2-music")
    parser.add_argument("--source_prompt", type=str, default="")
    parser.add_argument("--target_prompt", type=str, default="")
    parser.add_argument("--steps", type=int, default=200)
    parser.add_argument("--cfg_scale_src", type=float, default=3.5)
    parser.add_argument("--cfg_scale_tar", type=float, default=12)
    parser.add_argument("--t_start", type=int, default=45)
    parser.add_argument("--save_compute", type=lambda x: bool(strtobool(x)), default=False)

    return parser.parse_args()

def main():
    args = parse_arguments()
    input_path, output_path, export_format, sample_rate, audioldm_model, source_prompt, target_prompt, steps, cfg_scale_src, cfg_scale_tar, t_start, save_compute = args.input_path, args.output_path, args.export_format, args.sample_rate, args.audioldm_model, args.source_prompt, args.target_prompt, args.steps, args.cfg_scale_src, args.cfg_scale_tar, args.t_start, args.save_compute

    log_data = {translations['audio_path']: input_path, translations['output_path']: output_path.replace('wav', export_format), translations['model_name']: audioldm_model, translations['export_format']: export_format, translations['sample_rate']: sample_rate, translations['steps']: steps, translations['source_prompt']: source_prompt, translations['target_prompt']: target_prompt, translations['cfg_scale_src']: cfg_scale_src, translations['cfg_scale_tar']: cfg_scale_tar, translations['t_start']: t_start, translations['save_compute']: save_compute}

    for key, value in log_data.items():
        logger.debug(f"{key}: {value}")

    start_time = time.time()
    logger.info(translations["start_edit"].format(input_path=input_path))
    pid_path = os.path.join("assets", "audioldm2_pid.txt")
    with open(pid_path, "w") as pid_file:
        pid_file.write(str(os.getpid()))

    try:
        edit(input_path, output_path, audioldm_model, source_prompt, target_prompt, steps, cfg_scale_src, cfg_scale_tar, t_start, save_compute, sample_rate, config.device, export_format=export_format)
    except Exception as e:
        logger.error(translations["error_edit"].format(e=e))
        import traceback
        logger.debug(traceback.format_exc())

    logger.info(translations["edit_success"].format(time=f"{(time.time() - start_time):.2f}", output_path=output_path.replace('wav', export_format)))

def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src, duration, save_compute):
    with inference_mode():
        w0 = ldm_stable.vae_encode(x0)

    _, zs, wts, extra_info = inversion_forward_process(ldm_stable, w0, etas=1, prompts=[prompt_src], cfg_scales=[cfg_scale_src], num_inference_steps=num_diffusion_steps, numerical_fix=True, duration=duration, save_compute=save_compute)
    return zs, wts, extra_info

def low_pass_filter(audio, cutoff=7500, sr=16000):
    # 4th-order Butterworth low-pass; the cutoff is normalized to the Nyquist frequency
    b, a = scipy.signal.butter(4, cutoff / (sr / 2), btype='low')
    return scipy.signal.filtfilt(b, a, audio)

def sample(output_audio, sr, ldm_stable, zs, wts, extra_info, prompt_tar, tstart, cfg_scale_tar, duration, save_compute, export_format = "wav"):
    tstart = torch.tensor(tstart, dtype=torch.int32)
    w0, _ = inversion_reverse_process(ldm_stable, xT=wts, tstart=tstart, etas=1., prompts=[prompt_tar], neg_prompts=[""], cfg_scales=[cfg_scale_tar], zs=zs[:int(tstart)], duration=duration, extra_info=extra_info, save_compute=save_compute)

    with inference_mode():
        x0_dec = ldm_stable.vae_decode(w0.to(torch.float16 if config.is_half else torch.float32))

    if x0_dec.dim() < 4: x0_dec = x0_dec[None, :, :, :]

    with torch.no_grad():
        audio = ldm_stable.decode_to_mel(x0_dec.to(torch.float16 if config.is_half else torch.float32))

    # The model decodes audio at 16 kHz; resample only when a different output rate is requested
    audio = audio.float().squeeze().cpu().numpy()
    orig_sr = 16000

    if sr != 16000 and sr > 0:
        audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr, res_type="soxr_vhq")
        orig_sr = sr

    audio = low_pass_filter(audio, 7500, orig_sr)

    sf.write(output_audio, np.tile(audio, (2, 1)).T, orig_sr, format=export_format)
    return output_audio

def edit(input_audio, output_audio, model_id, source_prompt = "", target_prompt = "", steps = 200, cfg_scale_src = 3.5, cfg_scale_tar = 12, t_start = 45, save_compute = True, sr = 44100, device = "cpu", export_format = "wav"):
    ldm_stable = load_model(model_id, device=device)
    ldm_stable.model.scheduler.set_timesteps(steps, device=device)
    x0, duration = load_audio(input_audio, ldm_stable.get_melspectrogram(), device=device)
    zs_tensor, wts_tensor, extra_info_list = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src, duration=duration, save_compute=save_compute)

    # t_start is given as a percentage and converted to an absolute number of diffusion steps
    return sample(output_audio, sr, ldm_stable, zs_tensor, wts_tensor, extra_info_list, prompt_tar=target_prompt, tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar, duration=duration, save_compute=save_compute, export_format=export_format)

def inversion_forward_process(model, x0, etas = None, prompts = [""], cfg_scales = [3.5], num_inference_steps = 50, numerical_fix = False, duration = None, first_order = False, save_compute = True):
    if len(prompts) > 1 or prompts[0] != "":
        text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
        uncond_embeddings_hidden_states, uncond_embeddings_class_labels, uncond_boolean_prompt_mask = model.encode_text([""], negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1] if text_embeddings_class_labels is not None else None)
    else: uncond_embeddings_hidden_states, uncond_embeddings_class_labels, uncond_boolean_prompt_mask = model.encode_text([""], negative=True, save_compute=False)

    timesteps = model.model.scheduler.timesteps.to(model.device)
    variance_noise_shape = model.get_noise_shape(x0, num_inference_steps)

    if type(etas) in [int, float]: etas = [etas]*model.model.scheduler.num_inference_steps

    xts = model.sample_xts_from_x0(x0, num_inference_steps=num_inference_steps)
    zs = torch.zeros(size=variance_noise_shape, device=model.device)
    extra_info = [None] * len(zs)

    if timesteps[0].dtype == torch.int64: t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
    elif timesteps[0].dtype == torch.float32: t_to_idx = {float(v): k for k, v in enumerate(timesteps)}

    xt = x0
    model.setup_extra_inputs(xt, init_timestep=timesteps[0], audio_end_in_s=duration, save_compute=save_compute and prompts[0] != "")

    for t in tqdm.tqdm(timesteps, desc=translations["inverting"], ncols=100, unit="a"):
        idx = num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - 1
        xt = xts[idx + 1][None]
        xt_inp = model.model.scheduler.scale_model_input(xt, t).to(torch.float16 if config.is_half else torch.float32)

        with torch.no_grad():
            if save_compute and prompts[0] != "":
                # Run the unconditional and conditional passes as one batched forward call
                comb_out, _, _ = model.unet_forward(xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1), timestep=t, encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0) if uncond_embeddings_hidden_states is not None else None, class_labels=torch.cat([uncond_embeddings_class_labels, text_embeddings_class_labels], dim=0) if uncond_embeddings_class_labels is not None else None, encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0) if uncond_boolean_prompt_mask is not None else None)
                out, cond_out = comb_out.sample.chunk(2, dim=0)
            else:
                out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=uncond_embeddings_hidden_states, class_labels=uncond_embeddings_class_labels, encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
                if len(prompts) > 1 or prompts[0] != "": cond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=text_embeddings_hidden_states, class_labels=text_embeddings_class_labels, encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample

        # Classifier-free guidance on the predicted noise
        if len(prompts) > 1 or prompts[0] != "": noise_pred = out + (cfg_scales[0] * (cond_out - out)).sum(axis=0).unsqueeze(0)
        else: noise_pred = out

        xtm1 = xts[idx][None]
        z, xtm1, extra = model.get_zs_from_xts(xt, xtm1, noise_pred, t, eta=etas[idx], numerical_fix=numerical_fix, first_order=first_order)
        zs[idx] = z
        xts[idx] = xtm1
        extra_info[idx] = extra

    if zs is not None: zs[0] = torch.zeros_like(zs[0])
    return xt, zs, xts, extra_info

def inversion_reverse_process(model, xT, tstart, etas = 0, prompts = [""], neg_prompts = [""], cfg_scales = None, zs = None, duration = None, first_order = False, extra_info = None, save_compute = True):
    text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
    uncond_embeddings_hidden_states, uncond_embeddings_class_labels, uncond_boolean_prompt_mask = model.encode_text(neg_prompts, negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1] if text_embeddings_class_labels is not None else None)
    xt = xT[tstart.max()].unsqueeze(0)

    if etas is None: etas = 0
    if type(etas) in [int, float]: etas = [etas]*model.model.scheduler.num_inference_steps

    assert len(etas) == model.model.scheduler.num_inference_steps
    timesteps = model.model.scheduler.timesteps.to(model.device)

    if timesteps[0].dtype == torch.int64: t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
    elif timesteps[0].dtype == torch.float32: t_to_idx = {float(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}

    model.setup_extra_inputs(xt, extra_info=extra_info, init_timestep=timesteps[-zs.shape[0]], audio_end_in_s=duration, save_compute=save_compute)

    for t in tqdm.tqdm(timesteps[-zs.shape[0]:], desc=translations["editing"], ncols=100, unit="a"):
        idx = model.model.scheduler.num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - (model.model.scheduler.num_inference_steps - zs.shape[0] + 1)
        xt_inp = model.model.scheduler.scale_model_input(xt, t).to(torch.float16 if config.is_half else torch.float32)

        with torch.no_grad():
            if save_compute:
                # Run the unconditional and conditional passes as one batched forward call
                comb_out, _, _ = model.unet_forward(xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1), timestep=t, encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0) if uncond_embeddings_hidden_states is not None else None, class_labels=torch.cat([uncond_embeddings_class_labels, text_embeddings_class_labels], dim=0) if uncond_embeddings_class_labels is not None else None, encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0) if uncond_boolean_prompt_mask is not None else None)
                uncond_out, cond_out = comb_out.sample.chunk(2, dim=0)
            else:
                uncond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=uncond_embeddings_hidden_states, class_labels=uncond_embeddings_class_labels, encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
                cond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=text_embeddings_hidden_states, class_labels=text_embeddings_class_labels, encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample

        # Classifier-free guidance, then a reverse step that reuses the stored noise z from inversion
        z = zs[idx] if zs is not None else None
        noise_pred = uncond_out + (cfg_scales[0] * (cond_out - uncond_out)).sum(axis=0).unsqueeze(0)
        xt = model.reverse_step_with_custom_noise(noise_pred, t, xt, variance_noise=z.unsqueeze(0), eta=etas[idx], first_order=first_order)

    return xt, zs

if __name__ == "__main__": main()
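
# Example invocation (the script filename below is illustrative; adjust paths and prompts to your setup):
#   python audioldm2.py --input_path input.wav --output_path output.wav --source_prompt "a piano melody" --target_prompt "a guitar melody" --steps 200 --t_start 45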