SoloAudio / app.py
OpenSound's picture
Update app.py
fdeb5ca verified
raw
history blame
9.35 kB
import gradio as gr
import spaces
import yaml
import torch
# import librosa
import torchaudio
from diffusers import DDIMScheduler
from transformers import AutoProcessor, ClapModel, ClapConfig
from model.udit import UDiT
from vae_modules.autoencoder_wrapper import Autoencoder
import numpy as np
from huggingface_hub import hf_hub_download
# Download (or reuse from the local HF cache) the raw CLAP weight file; the
# CLAP model is constructed from its config further down and these weights
# are loaded into it via torch.load / load_state_dict.
clap_bin_path = hf_hub_download(
    repo_id="laion/larger_clap_general",
    filename="pytorch_model.bin",
)
# Locations of the SoloAudio config and checkpoints, relative to the working directory.
diffusion_config = './config/SoloAudio.yaml'
diffusion_ckpt = './pretrained_models/soloaudio_v2.pt'
autoencoder_path = './pretrained_models/audio-vae.pt'
uncond_path = './pretrained_models/uncond.npz'

# Working sample rate (Hz): uploads are resampled to it and results are returned at it.
sample_rate = 24_000

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

with open(diffusion_config, 'r') as config_file:
    diff_config = yaml.safe_load(config_file)
# Flag from the config's `ddim` section; in this file it only selects the
# message printed when the scheduler is created.
v_prediction = diff_config["ddim"]["v_prediction"]
# CLAP text encoder. The processor (tokenizer) and the architecture config come
# from the Hub repo; the weights come from the file downloaded into clap_bin_path.
processor = AutoProcessor.from_pretrained('laion/larger_clap_general')
clap_config = ClapConfig.from_pretrained("laion/larger_clap_general")
clapmodel = ClapModel(clap_config)
clap_ckpt = torch.load(clap_bin_path, map_location='cpu')
clapmodel.load_state_dict(clap_ckpt)
# A model constructed directly from a config is an nn.Module in training mode
# (``from_pretrained`` would have switched it to eval for us). Put it in eval
# mode so dropout is disabled and the text embeddings used as conditioning are
# deterministic across requests.
clapmodel.eval()
clapmodel.to(device)
# Audio VAE wrapper. tse() calls it as autoencoder(audio=...) to turn the input
# waveform into the latent that conditions diffusion, and sample_diffusion()
# calls it as autoencoder(embedding=...) to turn the sampled latent back into
# the returned audio. The meaning of 'stable_vae' / quantization_first is
# defined by vae_modules.autoencoder_wrapper (not visible here).
autoencoder = Autoencoder(autoencoder_path, 'stable_vae', quantization_first=True)
autoencoder.eval()
autoencoder.to(device)
# Diffusion backbone; constructor kwargs come from the YAML config.
unet = UDiT(
    **diff_config['diffwrap']['UDiT']
).to(device)
# Deserialize onto the CPU (like the CLAP checkpoint above) so a checkpoint saved
# from GPU tensors also loads on CPU-only hosts; load_state_dict then copies the
# values into the parameters that already live on ``device``.
unet_ckpt = torch.load(diffusion_ckpt, map_location='cpu')
unet.load_state_dict(unet_ckpt['model'])
unet.eval()
# Both prediction modes build the scheduler from the same `diffusers` kwargs;
# only the log line differs. (Presumably `prediction_type` is set inside that
# config section — confirm in config/SoloAudio.yaml.)
print('v prediction' if v_prediction else 'noise prediction')
scheduler = DDIMScheduler(**diff_config["ddim"]['diffusers'])
# Throwaway warm-up call to scheduler.add_noise on `device`. Per the original
# author's note it resets the dtype of the scheduler's internal params (in
# diffusers, add_noise moves/casts alphas_cumprod to match the samples —
# confirm for the installed version). The result is discarded.
latents = torch.randn(1, 128, 128, device=device)
noise = torch.randn(*latents.shape).to(device=latents.device)
timesteps = torch.randint(
    low=0,
    high=scheduler.config.num_train_timesteps,
    size=(len(noise),),
    device=latents.device,
).long()
_ = scheduler.add_noise(latents, noise, timesteps)
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend a classifier-free-guidance prediction with a std-matched copy of itself.

    Following "Common Diffusion Noise Schedules and Sample Steps are Flawed"
    (https://arxiv.org/pdf/2305.08891.pdf, Section 3.4), ``noise_cfg`` is scaled
    so its per-sample standard deviation (over every axis except axis 0) matches
    that of ``noise_pred_text``, then mixed with the unscaled prediction:

        out = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg

    Args:
        noise_cfg: guided model output, batch along axis 0.
        noise_pred_text: conditional-branch model output, batch along axis 0.
        guidance_rescale: mixing weight; 0.0 keeps ``noise_cfg`` as is,
            1.0 returns the fully rescaled prediction.

    Returns:
        Tensor shaped like ``noise_cfg`` (for inputs of matching rank).
    """
    text_axes = list(range(1, noise_pred_text.ndim))
    cfg_axes = list(range(1, noise_cfg.ndim))
    # Per-sample std ratio; keepdim=True so it broadcasts back over noise_cfg.
    # Matching the conditional branch's spread counteracts over-exposure.
    ratio = noise_pred_text.std(dim=text_axes, keepdim=True) / noise_cfg.std(dim=cfg_axes, keepdim=True)
    rescaled = noise_cfg * ratio
    # Only partially apply the rescale to avoid overly "plain" outputs.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
@spaces.GPU
def sample_diffusion(mixture, timbre, ddim_steps=50, eta=0, seed=2023, guidance_scale=False, guidance_rescale=0.0,):
    """
    Run DDIM sampling conditioned on a mixture latent and a text embedding, then decode.

    Args:
        mixture: latent of the input mixture, as produced by
            ``autoencoder(audio=...)`` in ``tse``; the initial noise has its shape.
        timbre: conditioning embedding (``clapmodel.get_text_features`` output in ``tse``).
        ddim_steps: number of DDIM inference steps.
        eta: DDIM ``eta`` forwarded to ``scheduler.step``.
        seed: seed of the torch generator used for the initial noise and for
            ``scheduler.step``.
        guidance_scale: classifier-free guidance weight. A falsy value
            (``False``/``0``) disables guidance and runs one conditional pass per step.
        guidance_rescale: when > 0, the guided output is blended via
            ``rescale_noise_cfg``.

    Returns:
        ``autoencoder(embedding=latent)`` with axis 1 squeezed.
    """
    with torch.no_grad():
        # int() is defensive: UI sliders may deliver whole numbers as floats.
        scheduler.set_timesteps(int(ddim_steps))
        generator = torch.Generator(device=device).manual_seed(int(seed))
        # init noise
        pred = torch.randn(mixture.shape, generator=generator, device=device)

        if guidance_scale:
            # Loop-invariant CFG inputs: read the unconditional embedding once per
            # call instead of once per step, and close the .npz handle afterwards
            # (torch.tensor copies the data, so closing is safe).
            with np.load(uncond_path) as uncond_npz:
                uncond = torch.tensor(uncond_npz['arr_0']).unsqueeze(0).to(device)
            mixture_combined = torch.cat([mixture, mixture], dim=0)
            timbre_combined = torch.cat([timbre, uncond], dim=0)

        for t in scheduler.timesteps:
            # Scale a separate model input: scheduler.step must receive the
            # unscaled sample. (Identity for DDIM, but correct for any scheduler.)
            model_input = scheduler.scale_model_input(pred, t)
            if guidance_scale:
                # Conditional (text) and unconditional passes in one batch.
                pred_combined = torch.cat([model_input, model_input], dim=0)
                output_combined = unet(x=pred_combined, timesteps=t, mixture=mixture_combined, timbre=timbre_combined)
                output_pos, output_neg = torch.chunk(output_combined, 2, dim=0)
                model_output = output_neg + guidance_scale * (output_pos - output_neg)
                if guidance_rescale > 0.0:
                    # avoid over-exposure from large guidance scales
                    model_output = rescale_noise_cfg(model_output, output_pos,
                                                     guidance_rescale=guidance_rescale)
            else:
                model_output = unet(x=model_input, timesteps=t, mixture=mixture, timbre=timbre)
            pred = scheduler.step(model_output=model_output, timestep=t, sample=pred,
                                  eta=eta, generator=generator).prev_sample

        # Decode the final latent back to audio.
        pred = autoencoder(embedding=pred).squeeze(1)
    return pred
def _load_mixture(path):
    """
    Load an audio file as a mono, 10-second waveform at ``sample_rate``.

    ``torchaudio.load`` returns a ``(channels, frames)`` tensor, so channels are
    averaged first (the same downmix ``librosa.load`` applied by default in the
    earlier version); the length logic below therefore counts samples, not
    channels.

    Args:
        path: path to an audio file readable by ``torchaudio.load``.

    Returns:
        1-D tensor of exactly ``sample_rate * 10`` samples: longer audio is
        truncated, shorter audio is zero-padded at the end.
    """
    waveform, sr = torchaudio.load(path)
    waveform = waveform.mean(dim=0)  # (channels, frames) -> (frames,)
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)

    target_length = sample_rate * 10
    current_length = waveform.shape[-1]
    if current_length > target_length:
        waveform = waveform[:target_length]
    elif current_length < target_length:
        waveform = torch.nn.functional.pad(waveform, (0, target_length - current_length))
    return waveform


@spaces.GPU
def tse(gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale):
    """
    Gradio handler: extract the sound described by ``text_input`` from an audio file.

    Args:
        gt_file_input: path to the uploaded mixture (``gr.Audio(type="filepath")``).
        text_input: text prompt describing the target sound.
        num_infer_steps, eta, seed, guidance_scale, guidance_rescale:
            forwarded to ``sample_diffusion``.

    Returns:
        ``(sample_rate, waveform)`` with ``waveform`` a NumPy array, the format
        consumed by ``gr.Audio(type="numpy")``.

    Raises:
        gr.Error: if no audio file was provided.
    """
    if gt_file_input is None:
        raise gr.Error("Please upload an audio file first.")
    with torch.no_grad():
        mixture = _load_mixture(gt_file_input)
        # Add batch and channel axes: (samples,) -> (1, 1, samples).
        mixture = mixture.unsqueeze(0).unsqueeze(1).to(device)
        mixture = autoencoder(audio=mixture)

        text_inputs = processor(
            text=[text_input],
            max_length=10,  # Fixed length for text
            padding='max_length',  # Pad text to max length
            truncation=True,  # Truncate text if it's longer than max length
            return_tensors="pt",
        )
        # Single-prompt batch; keep only the ids and attention mask.
        inputs = {key: text_inputs[key].to(device) for key in ("input_ids", "attention_mask")}
        timbre = clapmodel.get_text_features(**inputs)

        pred = sample_diffusion(mixture, timbre, num_infer_steps, eta, seed, guidance_scale, guidance_rescale)
    return sample_rate, pred.squeeze().cpu().numpy()
# Custom CSS passed to gr.Blocks below: centers the column whose
# elem_id is "col-container" and caps its width at 1280px.
css = """
#col-container {
margin: 0 auto;
max-width: 1280px;
}
"""
# Gradio Blocks layout
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# SoloAudio: Target Sound Extraction with Language-oriented Audio Diffusion Transformer
Adjust advanced settings for more control. This space only supports a 10-second audio input now.
Learn more about 🟣**SoloAudio** on the [SoloAudio Homepage](https://wanghelin1997.github.io/SoloAudio-Demo/).
""")
        with gr.Tab("Target Sound Extraction"):
            # Inputs side by side (relative widths 3:2:1): the mixture to
            # process, a prompt describing the target sound, and the run button.
            with gr.Row():
                gt_file_input = gr.Audio(label="Upload Audio to Extract", type="filepath", value="demo/0_mix.wav", scale=3)
                text_input = gr.Textbox(
                    label="Text Prompt",
                    show_label=True,
                    max_lines=2,
                    placeholder="Enter your prompt",
                    container=True,
                    value="The sound of gunshot",
                    scale=2,
                )
                run_button = gr.Button("Extract", scale=1)

            # Output: tse() returns (sample_rate, waveform), matching type="numpy".
            result = gr.Audio(label="Extracted Audio", type="numpy")

            # Sampling hyper-parameters forwarded to tse() / sample_diffusion().
            with gr.Accordion("Advanced Settings", open=False):
                guidance_scale = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=3.0, label="Guidance Scale")
                guidance_rescale = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0., label="Guidance Rescale")
                num_infer_steps = gr.Slider(minimum=25, maximum=200, step=5, value=50, label="DDIM Steps")
                eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="Eta")
                seed = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Seed")

            # Clicking "Extract" and pressing Enter in the prompt box both run
            # tse(); the list order must match tse()'s parameter order.
            tse_inputs = [gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale]
            run_button.click(fn=tse, inputs=tse_inputs, outputs=[result])
            text_input.submit(fn=tse, inputs=tse_inputs, outputs=[result])

# Launch the Gradio demo
demo.launch()