import gradio as gr
import spaces
import yaml
import argparse
import os
import torch
import librosa
import time
from diffusers import DDIMScheduler
from solospeech.model.solospeech.conditioners import SoloSpeech_TSE
from solospeech.model.solospeech.conditioners import SoloSpeech_TSR
from solospeech.vae_modules.autoencoder_wrapper import Autoencoder
from speechbrain.pretrained.interfaces import Pretrained
from solospeech.corrector.fastgeco.model import ScoreModel
from solospeech.corrector.geco.util.other import pad_spec
from huggingface_hub import snapshot_download


class Encoder(Pretrained):
    """ECAPA-TDNN speaker encoder used to compare candidates against the enrollment voice."""

    MODULES_NEEDED = [
        "compute_features",
        "mean_var_norm",
        "embedding_model",
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        # Manage single waveforms in input
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)
        # Assign full length if wav_lens is not assigned
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)
        # Storing waveform in the specified device
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()
        # Computing features and embeddings
        feats = self.mods.compute_features(wavs)
        feats = self.mods.mean_var_norm(feats, wav_lens)
        embeddings = self.mods.embedding_model(feats, wav_lens)
        if normalize:
            embeddings = self.hparams.mean_var_norm_emb(
                embeddings,
                torch.ones(embeddings.shape[0], device=self.device)
            )
        return embeddings


parser = argparse.ArgumentParser()
# inference settings
parser.add_argument('--eta', type=float, default=0)
parser.add_argument("--num_infer_steps", type=int, default=200)
parser.add_argument('--sample-rate', type=int, default=16000)
# random seed
parser.add_argument('--random-seed', type=int, default=42, help="Fixed seed")
args = parser.parse_args()

print("Downloading model from Huggingface...")
local_dir = snapshot_download(repo_id="OpenSound/SoloSpeech-models")
args.tse_config = os.path.join(local_dir, "config_extractor.yaml")
args.tsr_config = os.path.join(local_dir, "config_tsr.yaml")
args.vae_config = os.path.join(local_dir, "config_compressor.json")
args.autoencoder_path = os.path.join(local_dir, "compressor.ckpt")
args.tse_ckpt = os.path.join(local_dir, "extractor.pt")
args.tsr_ckpt = os.path.join(local_dir, "tsr.pt")
args.geco_ckpt = os.path.join(local_dir, "corrector.ckpt")

device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# load configs
print("Loading models...")
with open(args.tse_config, 'r') as fp:
    args.tse_config = yaml.safe_load(fp)
with open(args.tsr_config, 'r') as fp:
    args.tsr_config = yaml.safe_load(fp)
args.v_prediction = args.tse_config["ddim"]["v_prediction"]

# load compressor
autoencoder = Autoencoder(args.autoencoder_path, args.vae_config, 'stft_vae', quantization_first=True)
autoencoder.eval()
autoencoder.to(device)

# load extractor
tse_model = SoloSpeech_TSE(
    args.tse_config['diffwrap']['UDiT'],
    args.tse_config['diffwrap']['ViT'],
).to(device)
tse_model.load_state_dict(torch.load(args.tse_ckpt)['model'])
tse_model.eval()

# load tsr model
tsr_model = SoloSpeech_TSR(
    args.tsr_config['diffwrap']['UDiT']
).to(device)
tsr_model.load_state_dict(torch.load(args.tsr_ckpt)['model'])
tsr_model.eval()

# load corrector
geco_model = ScoreModel.load_from_checkpoint(
    args.geco_ckpt,
    batch_size=1,
    num_workers=0,
    kwargs=dict(gpu=False)
)
geco_model.eval(no_ema=False)
geco_model.to(device)  # original used .cuda(); .to(device) also works on CPU-only hosts
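# Pipeline at a glance (summarizing the code below): the diffusion extractor
# (TSE) isolates the target voice from the mixture using the enrollment clip,
# the TSR model re-estimates it, the ECAPA-TDNN speaker model picks whichever
# candidate is closer to the enrollment, and the GeCo corrector cleans up
# residual artifacts before the audio is returned to the Gradio UI.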
# load sid model
ecapatdnn_model = Encoder.from_hparams(source="yangwang825/ecapa-tdnn-vox2")
cosine_sim = torch.nn.CosineSimilarity(dim=-1)

# load diffusion tools
noise_scheduler = DDIMScheduler(**args.tse_config["ddim"]['diffusers'])
# these steps reset dtype of noise_scheduler params
latents = torch.randn((1, 128, 128), device=device)
noise = torch.randn(latents.shape).to(device)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (noise.shape[0],), device=latents.device).long()
_ = noise_scheduler.add_noise(latents, noise, timesteps)


@spaces.GPU
def sample_diffusion(tse_model, tsr_model, autoencoder, std, scheduler, device,
                     mixture=None, reference=None, lengths=None, reference_lengths=None,
                     ddim_steps=50, eta=0, seed=2025):
    with torch.no_grad():
        generator = torch.Generator(device=device).manual_seed(seed)
        scheduler.set_timesteps(ddim_steps)
        tse_pred = torch.randn(mixture.shape, generator=generator, device=device)
        tsr_pred = torch.randn(mixture.shape, generator=generator, device=device)

        # stage 1: extractor, conditioned on the mixture and the enrollment latent
        for t in scheduler.timesteps:
            tse_pred = scheduler.scale_model_input(tse_pred, t)
            model_output, _ = tse_model(
                x=tse_pred,
                timesteps=t,
                mixture=mixture,
                reference=reference,
                x_len=lengths,
                ref_len=reference_lengths
            )
            tse_pred = scheduler.step(model_output=model_output, timestep=t, sample=tse_pred,
                                      eta=eta, generator=generator).prev_sample

        # stage 2: refinement, conditioned on the extractor output
        for t in scheduler.timesteps:
            tsr_pred = scheduler.scale_model_input(tsr_pred, t)
            model_output, _ = tsr_model(
                x=tsr_pred,
                timesteps=t,
                mixture=mixture,
                reference=tse_pred,
                x_len=lengths,
            )
            tsr_pred = scheduler.step(model_output=model_output, timestep=t, sample=tsr_pred,
                                      eta=eta, generator=generator).prev_sample

        # decode both candidate latents back to waveforms
        tse_pred = autoencoder(embedding=tse_pred.transpose(2, 1), std=std).squeeze(1)
        tsr_pred = autoencoder(embedding=tsr_pred.transpose(2, 1), std=std).squeeze(1)

    return tse_pred, tsr_pred


@spaces.GPU
def tse(test_wav, enroll_wav):
    print("Start Extraction...")
    start_time = time.time()
    mixture, _ = librosa.load(test_wav, sr=16000)
    reference, _ = librosa.load(enroll_wav, sr=16000)
    reference_wav = reference
    reference = torch.tensor(reference).unsqueeze(0).to(device)
    with torch.no_grad():
        # compressor
        reference, _ = autoencoder(audio=reference.unsqueeze(1))
        reference_lengths = torch.LongTensor([reference.shape[-1]]).to(device)
        mixture_input = torch.tensor(mixture).unsqueeze(0).to(device)
        mixture_wav = mixture_input
        mixture_input, std = autoencoder(audio=mixture_input.unsqueeze(1))
        lengths = torch.LongTensor([mixture_input.shape[-1]]).to(device)

        # extractor + refiner
        tse_pred, tsr_pred = sample_diffusion(tse_model, tsr_model, autoencoder, std, noise_scheduler, device,
                                              mixture_input.transpose(2, 1), reference.transpose(2, 1),
                                              lengths, reference_lengths,
                                              ddim_steps=args.num_infer_steps, eta=args.eta, seed=args.random_seed)

        # pick the candidate whose speaker embedding is closer to the enrollment
        ecapatdnn_embedding1 = ecapatdnn_model.encode_batch(tse_pred.squeeze()).squeeze()
        ecapatdnn_embedding2 = ecapatdnn_model.encode_batch(tsr_pred.squeeze()).squeeze()
        ecapatdnn_embedding3 = ecapatdnn_model.encode_batch(torch.tensor(reference_wav)).squeeze()
        sim1 = cosine_sim(ecapatdnn_embedding1, ecapatdnn_embedding3).item()
        sim2 = cosine_sim(ecapatdnn_embedding2, ecapatdnn_embedding3).item()
        pred = tse_pred if sim1 > sim2 else tsr_pred

        # corrector
        min_leng = min(pred.shape[-1], mixture_wav.shape[-1])
        x = pred[..., :min_leng]
        m = mixture_wav[..., :min_leng]
        norm_factor = m.abs().max()
        x = x / norm_factor
        m = m / norm_factor
        X = torch.unsqueeze(geco_model._forward_transform(geco_model._stft(x.to(device))), 0)
        X = pad_spec(X)
        M = torch.unsqueeze(geco_model._forward_transform(geco_model._stft(m.to(device))), 0)
        M = pad_spec(M)

        # one-step reverse-SDE correction (fastgeco): linspace with a single
        # point gives timesteps = [0.5], so the loop below runs exactly once
        timesteps = torch.linspace(0.5, 0.03, 1, device=M.device)
        std = geco_model.sde._std(0.5 * torch.ones((M.shape[0],), device=M.device))
        z = torch.randn_like(M)
        X_t = M + z * std[:, None, None, None]

        for idx in range(len(timesteps)):
            t = timesteps[idx]
            if idx != len(timesteps) - 1:
                dt = t - timesteps[idx + 1]
            else:
                dt = timesteps[-1]
            with torch.no_grad():
                f, g = geco_model.sde.sde(X_t, t, M)
                vec_t = torch.ones(M.shape[0], device=M.device) * t
                mean_x_tm1 = X_t - (f - g ** 2 * geco_model.forward(X_t, vec_t, M, X, vec_t[:, None, None, None])) * dt
                if idx == len(timesteps) - 1:
                    X_t = mean_x_tm1
                    break
                z = torch.randn_like(X)
                X_t = mean_x_tm1 + z * g * torch.sqrt(dt)

        sample = X_t
        sample = sample.squeeze()
        x_hat = geco_model.to_audio(sample.squeeze(), min_leng)
        x_hat = x_hat * norm_factor / x_hat.abs().max()
        x_hat = x_hat.detach().cpu().squeeze().numpy()

    end_time = time.time()
    audio_len = x_hat.shape[-1] / 16000
    rtf = (end_time - start_time) / audio_len
    print(f"RTF: {rtf:.4f}")

    return (16000, x_hat)


@spaces.GPU
def process_audio(test_wav, enroll_wav):
    result = tse(test_wav, enroll_wav)
    return result


# List of demo audio files
demo_audio_files = [
    ("Test Demo 1", "test1.wav", "test1_enroll.wav"),
    ("Test Demo 2", "test2.wav", "test2_enroll.wav")
]

# CSS styling (optional)
css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}
"""

# Gradio Blocks layout
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # 🎸 SoloSpeech: Enhancing Intelligibility and Quality in Target Speech Extraction through a Cascaded Generative Pipeline

        Extract the target voice from mixture speech given an enrollment speech. Learn more about **SoloSpeech** on the [SoloSpeech Repo](https://github.com/WangHelin1997/SoloSpeech/).
        """)

        with gr.Tab("Target Speech Extraction"):
            with gr.Row():
                mixture_input = gr.Audio(label="Upload Mixture Audio", type="filepath", value="test2.wav")
                enroll_input = gr.Audio(label="Upload Enrollment Audio", type="filepath", value="test2_enroll.wav")

            with gr.Row():
                demo_selector = gr.Dropdown(
                    label="Select Test Demo",
                    choices=[name for name, _, _ in demo_audio_files],
                    value="Test Demo 2"
                )
                extract_button = gr.Button("Extract", scale=1)

            with gr.Row():
                result = gr.Audio(label="Extracted Speech", type="numpy")

            # Update audio inputs when selecting from dropdown
            def update_audio_inputs(choice):
                for name, mixture_path, enroll_path in demo_audio_files:
                    if name == choice:
                        return mixture_path, enroll_path
                return None, None

            demo_selector.change(
                fn=update_audio_inputs,
                inputs=demo_selector,
                outputs=[mixture_input, enroll_input]
            )

            extract_button.click(
                fn=process_audio,
                inputs=[mixture_input, enroll_input],
                outputs=[result]
            )

# Launch the Gradio demo
demo.launch()
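# Local usage note (an assumption, not from the original app): the `spaces`
# decorators are documented to be no-ops outside HF Spaces, so this script
# should also run on a regular GPU box, provided the bundled demo files
# (test1.wav, test1_enroll.wav, test2.wav, test2_enroll.wav) sit next to it:
#   python app.py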