OpenSound committed
Commit ad315e1 · verified · 1 Parent(s): db6a74e

Create app.py

Files changed (1)
  1. app.py +308 -0
app.py ADDED
@@ -0,0 +1,308 @@
+ import gradio as gr
+ import spaces
+ import yaml
+ import random
+ import argparse
+ import os
+ import torch
+ import librosa
+ from tqdm import tqdm
+ from diffusers import DDIMScheduler
+ from solospeech.model.solospeech.conditioners import SoloSpeech_TSE
+ from solospeech.model.solospeech.conditioners import SoloSpeech_TSR
+ from solospeech.scripts.solospeech.utils import save_audio
+ import shutil
+ from solospeech.vae_modules.autoencoder_wrapper import Autoencoder
+ import pandas as pd
+ from speechbrain.pretrained.interfaces import Pretrained
+ from solospeech.corrector.fastgeco.model import ScoreModel
+ from solospeech.corrector.geco.util.other import pad_spec
+ from huggingface_hub import snapshot_download
+ import time
+
+ parser = argparse.ArgumentParser()
+ # inference settings
+ parser.add_argument('--eta', type=int, default=0)
+ parser.add_argument("--num_infer_steps", type=int, default=200)
+ parser.add_argument('--sample-rate', type=int, default=16000)
+ # random seed
+ parser.add_argument('--random-seed', type=int, default=42, help="Fixed seed")
+ args = parser.parse_args()
+
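+ # The OpenSound/SoloSpeech-models snapshot bundles the compressor (audio VAE), the
+ # TSE extractor, the TSR model, and the corrector checkpoints together with their
+ # config files; their paths are attached to `args` below.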
+ print("Downloading model from Huggingface...")
+ local_dir = snapshot_download(
+     repo_id="OpenSound/SoloSpeech-models"
+ )
+ args.tse_config = os.path.join(local_dir, "config_extractor.yaml")
+ args.tsr_config = os.path.join(local_dir, "config_tsr.yaml")
+ args.vae_config = os.path.join(local_dir, "config_compressor.json")
+ args.autoencoder_path = os.path.join(local_dir, "compressor.ckpt")
+ args.tse_ckpt = os.path.join(local_dir, "extractor.pt")
+ args.tsr_ckpt = os.path.join(local_dir, "tsr.pt")
+ args.geco_ckpt = os.path.join(local_dir, "corrector.ckpt")
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ print(f"Device: {device}")
+ # load config
+ print("Loading models...")
+ with open(args.tse_config, 'r') as fp:
+     args.tse_config = yaml.safe_load(fp)
+ with open(args.tsr_config, 'r') as fp:
+     args.tsr_config = yaml.safe_load(fp)
+ args.v_prediction = args.tse_config["ddim"]["v_prediction"]
+ # load compressor
+ autoencoder = Autoencoder(args.autoencoder_path, args.vae_config, 'stft_vae', quantization_first=True)
+ autoencoder.eval()
+ autoencoder.to(device)
+ # load extractor
+ tse_model = SoloSpeech_TSE(
+     args.tse_config['diffwrap']['UDiT'],
+     args.tse_config['diffwrap']['ViT'],
+ ).to(device)
+ tse_model.load_state_dict(torch.load(args.tse_ckpt)['model'])
+ tse_model.eval()
+ # load tsr model
+ tsr_model = SoloSpeech_TSR(
+     args.tsr_config['diffwrap']['UDiT']
+ ).to(device)
+ tsr_model.load_state_dict(torch.load(args.tsr_ckpt)['model'])
+ tsr_model.eval()
+ # load corrector
+ geco_model = ScoreModel.load_from_checkpoint(
+     args.geco_ckpt,
+     batch_size=1, num_workers=0, kwargs=dict(gpu=False)
+ )
+ geco_model.eval(no_ema=False)
+ geco_model.cuda()
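+ # Speaker-identification components: a SpeechBrain ECAPA-TDNN speaker encoder plus
+ # cosine similarity, used later to keep whichever diffusion output better matches
+ # the enrollment speaker.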
+
+
+ class Encoder(Pretrained):
+
+     MODULES_NEEDED = [
+         "compute_features",
+         "mean_var_norm",
+         "embedding_model"
+     ]
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def encode_batch(self, wavs, wav_lens=None, normalize=False):
+         # Manage single waveforms in input
+         if len(wavs.shape) == 1:
+             wavs = wavs.unsqueeze(0)
+
+         # Assign full length if wav_lens is not assigned
+         if wav_lens is None:
+             wav_lens = torch.ones(wavs.shape[0], device=self.device)
+
+         # Storing waveform in the specified device
+         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+         wavs = wavs.float()
+
+         # Computing features and embeddings
+         feats = self.mods.compute_features(wavs)
+         feats = self.mods.mean_var_norm(feats, wav_lens)
+         embeddings = self.mods.embedding_model(feats, wav_lens)
+         if normalize:
+             embeddings = self.hparams.mean_var_norm_emb(
+                 embeddings,
+                 torch.ones(embeddings.shape[0], device=self.device)
+             )
+         return embeddings
+
+
+ # load sid model
+ ecapatdnn_model = Encoder.from_hparams(source="yangwang825/ecapa-tdnn-vox2")
+ cosine_sim = torch.nn.CosineSimilarity(dim=-1)
+ # load diffusion tools
+ noise_scheduler = DDIMScheduler(**args.tse_config["ddim"]['diffusers'])
+ # these steps reset dtype of noise_scheduler params
+ latents = torch.randn((1, 128, 128),
+                       device=device)
+ noise = torch.randn(latents.shape).to(device)
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                           (noise.shape[0],),
+                           device=latents.device).long()
+ _ = noise_scheduler.add_noise(latents, noise, timesteps)
+
+
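+ # Two-pass DDIM sampling in the compressed latent space: the TSE pass extracts the
+ # target latent from the mixture conditioned on the enrollment reference, and the TSR
+ # pass re-estimates it using the TSE output as its reference. Both latents are then
+ # decoded back to waveforms by the autoencoder.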
+ @spaces.GPU
+ def sample_diffusion(tse_model, tsr_model, autoencoder, std, scheduler, device,
+                      mixture=None, reference=None, lengths=None, reference_lengths=None,
+                      ddim_steps=50, eta=0, seed=2025
+                      ):
+     with torch.no_grad():
+         generator = torch.Generator(device=device).manual_seed(seed)
+         scheduler.set_timesteps(ddim_steps)
+         tse_pred = torch.randn(mixture.shape, generator=generator, device=device)
+         tsr_pred = torch.randn(mixture.shape, generator=generator, device=device)
+
+         for t in scheduler.timesteps:
+             tse_pred = scheduler.scale_model_input(tse_pred, t)
+             model_output, _ = tse_model(
+                 x=tse_pred,
+                 timesteps=t,
+                 mixture=mixture,
+                 reference=reference,
+                 x_len=lengths,
+                 ref_len=reference_lengths
+             )
+             tse_pred = scheduler.step(model_output=model_output, timestep=t, sample=tse_pred,
+                                       eta=eta, generator=generator).prev_sample
+
+         for t in scheduler.timesteps:
+             tsr_pred = scheduler.scale_model_input(tsr_pred, t)
+             model_output, _ = tsr_model(
+                 x=tsr_pred,
+                 timesteps=t,
+                 mixture=mixture,
+                 reference=tse_pred,
+                 x_len=lengths,
+             )
+             tsr_pred = scheduler.step(model_output=model_output, timestep=t, sample=tsr_pred,
+                                       eta=eta, generator=generator).prev_sample
+
+         tse_pred = autoencoder(embedding=tse_pred.transpose(2,1), std=std).squeeze(1)
+         tsr_pred = autoencoder(embedding=tsr_pred.transpose(2,1), std=std).squeeze(1)
+
+     return tse_pred, tsr_pred
+
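+ # End-to-end extraction for one (mixture, enrollment) pair: encode both signals with
+ # the compressor, run the diffusion extractor and TSR model, keep whichever output is
+ # closer to the enrollment speaker by ECAPA-TDNN cosine similarity, then apply one
+ # reverse step of the GECO-based corrector and resynthesize the waveform.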
+ @spaces.GPU
+ def tse(test_wav, enroll_wav):
+     print("Start Extraction...")
+     start_time = time.time()
+     mixture, _ = librosa.load(test_wav, sr=16000)
+     reference, _ = librosa.load(enroll_wav, sr=16000)
+     reference_wav = reference
+     reference = torch.tensor(reference).unsqueeze(0).to(device)
+     with torch.no_grad():
+         # compressor
+         reference, _ = autoencoder(audio=reference.unsqueeze(1))
+         reference_lengths = torch.LongTensor([reference.shape[-1]]).to(device)
+         mixture_input = torch.tensor(mixture).unsqueeze(0).to(device)
+         mixture_wav = mixture_input
+         mixture_input, std = autoencoder(audio=mixture_input.unsqueeze(1))
+         lengths = torch.LongTensor([mixture_input.shape[-1]]).to(device)
+         # extractor
+         tse_pred, tsr_pred = sample_diffusion(tse_model, tsr_model, autoencoder, std, noise_scheduler, device, mixture_input.transpose(2,1), reference.transpose(2,1), lengths, reference_lengths, ddim_steps=args.num_infer_steps, eta=args.eta, seed=args.random_seed)
+         ecapatdnn_embedding1 = ecapatdnn_model.encode_batch(tse_pred.squeeze()).squeeze()
+         ecapatdnn_embedding2 = ecapatdnn_model.encode_batch(tsr_pred.squeeze()).squeeze()
+         ecapatdnn_embedding3 = ecapatdnn_model.encode_batch(torch.tensor(reference_wav)).squeeze()
+         sim1 = cosine_sim(ecapatdnn_embedding1, ecapatdnn_embedding3).item()
+         sim2 = cosine_sim(ecapatdnn_embedding2, ecapatdnn_embedding3).item()
+         pred = tse_pred if sim1 > sim2 else tsr_pred
+         # corrector: a single reverse-diffusion step in the spectrogram domain,
+         # starting from the noised mixture spectrogram
+         min_leng = min(pred.shape[-1], mixture_wav.shape[-1])
+         x = pred[...,:min_leng]
+         m = mixture_wav[...,:min_leng]
+         norm_factor = m.abs().max()
+         x = x / norm_factor
+         m = m / norm_factor
+         X = torch.unsqueeze(geco_model._forward_transform(geco_model._stft(x.cuda())), 0)
+         X = pad_spec(X)
+         M = torch.unsqueeze(geco_model._forward_transform(geco_model._stft(m.cuda())), 0)
+         M = pad_spec(M)
+         timesteps = torch.linspace(0.5, 0.03, 1, device=M.device)
+         std = geco_model.sde._std(0.5*torch.ones((M.shape[0],), device=M.device))
+         z = torch.randn_like(M)
+         X_t = M + z * std[:, None, None, None]
+
+         for idx in range(len(timesteps)):
+             t = timesteps[idx]
+             if idx != len(timesteps) - 1:
+                 dt = t - timesteps[idx+1]
+             else:
+                 dt = timesteps[-1]
+             with torch.no_grad():
+                 f, g = geco_model.sde.sde(X_t, t, M)
+                 vec_t = torch.ones(M.shape[0], device=M.device) * t
+                 mean_x_tm1 = X_t - (f - g**2*geco_model.forward(X_t, vec_t, M, X, vec_t[:,None,None,None]))*dt
+                 if idx == len(timesteps) - 1:
+                     X_t = mean_x_tm1
+                     break
+                 z = torch.randn_like(X)
+                 X_t = mean_x_tm1 + z*g*torch.sqrt(dt)
+
+         sample = X_t
+         sample = sample.squeeze()
+         x_hat = geco_model.to_audio(sample.squeeze(), min_leng)
+         x_hat = x_hat * norm_factor / x_hat.abs().max()
+         x_hat = x_hat.detach().cpu()
+
+     end_time = time.time()
+     audio_len = x_hat.shape[-1] / 16000
+     rtf = (end_time-start_time)/audio_len
+     print(f"RTF: {rtf:.4f}")
+     return (16000, x_hat)
+
+
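+ # Thin wrapper around tse() used as the Gradio click callback.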
+ @spaces.GPU
+ def process_audio(test_wav, enroll_wav):
+     result = tse(test_wav, enroll_wav)
+     return result
+
+
+ # List of demo audio files
+ demo_audio_files = [
+     ("Demo Pair 1", "demo/test1.wav", "demo/test1_enroll.wav"),
+     ("Demo Pair 2", "demo/test2.wav", "demo/test2_enroll.wav")
+ ]
+
+ def update_audio_input(choice):
+     return choice
+
+ # CSS styling (optional)
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 1280px;
+ }
+ """
+
+ # Gradio Blocks layout
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("""
+         # SoloSpeech: Enhancing Intelligibility and Quality in Target Speech Extraction through a Cascaded Generative Pipeline
+         Extract the target voice from a speech mixture, given an enrollment recording of the target speaker.
+
+         Learn more about 🎸**SoloSpeech** on the [SoloSpeech Repo](https://github.com/WangHelin1997/SoloSpeech/).
+         """)
+
+         with gr.Tab("Target Speech Extraction"):
+             with gr.Row():
+                 mixture_input = gr.Audio(label="Upload Mixture Audio", type="filepath", value="demo/test1.wav")
+                 enroll_input = gr.Audio(label="Upload Enrollment Audio", type="filepath", value="demo/enroll1.wav")
+
+             with gr.Row():
+                 demo_selector = gr.Dropdown(
+                     label="Select Demo Pair",
+                     choices=[name for name, _, _ in demo_audio_files],
+                     value="Demo Pair 1"
+                 )
+                 extract_button = gr.Button("Extract", scale=1)
+
+             with gr.Row():
+                 result = gr.Audio(label="Extracted Speech", type="numpy")
+
+             # Update audio inputs when selecting from dropdown
+             def update_audio_inputs(choice):
+                 for name, mixture_path, enroll_path in demo_audio_files:
+                     if name == choice:
+                         return mixture_path, enroll_path
+                 return None, None
+
+             demo_selector.change(
+                 fn=update_audio_inputs,
+                 inputs=demo_selector,
+                 outputs=[mixture_input, enroll_input]
+             )
+
+             extract_button.click(
+                 fn=process_audio,
+                 inputs=[mixture_input, enroll_input],
+                 outputs=[result]
+             )
+
+ # Launch the Gradio demo
+ demo.launch()