Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| from tqdm import tqdm | |
| from copy import deepcopy | |
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| import random | |
| random.seed(0) | |
| torch.manual_seed(0) | |
| np.random.seed(0) | |
| from scipy.io.wavfile import write as wavwrite | |
| from util import print_size, sampling | |
| from network import CleanUNet | |
| import torchaudio | |
def load_simple(filename):
    """Load an audio file and return only its waveform.

    ``torchaudio.load`` returns a ``(waveform, sample_rate)`` pair; the
    sample rate is dropped here, so callers that need it must obtain it
    some other way.
    """
    waveform, _sample_rate = torchaudio.load(filename)
    return waveform
CONFIG = "configs/DNS-large-full.json"
CHECKPOINT = "./exp/DNS-large-high/checkpoint/pretrained.pkl"

# Parse the JSON config once at import time. The sections are kept as
# module-level names because denoise() reads them directly. Assignment at
# module scope already creates globals, so no `global` statement is needed.
# JSON is UTF-8 by specification, so the encoding is given explicitly rather
# than left to the platform's locale default.
with open(CONFIG, encoding="utf-8") as f:
    data = f.read()
config = json.loads(data)
gen_config = config["gen_config"]
network_config = config["network_config"]  # keyword arguments for CleanUNet
train_config = config["train_config"]  # train config
trainset_config = config["trainset_config"]  # to read trainset configurations
def denoise(files, ckpt_path=None):
    """
    Denoise audio files with CleanUNet and save each result as a WAV file.

    Every output is written next to its input, named
    ``<input file name>_denoised.wav``.

    Parameters:
    files (str, os.PathLike, or iterable of those): path(s) of the noisy
        audio. A single path is accepted so that the function can be used
        directly as a Gradio callback, which passes one file path.
    ckpt_path (str or None): checkpoint file to load; falls back to the
        module-level CHECKPOINT when None.

    Returns:
    str when a single path was given, otherwise a list of str: the path(s)
    of the denoised file(s), in input order.

    Raises:
    ValueError: if ``files`` is None (e.g. nothing was uploaded).
    """
    if files is None:
        raise ValueError("No audio file was provided")
    if ckpt_path is None:
        ckpt_path = CHECKPOINT

    # A lone path must not be iterated: looping over a str would yield its
    # characters, not file paths.
    single_input = isinstance(files, (str, bytes, os.PathLike))
    paths = [files] if single_input else files

    # setup local experiment path
    exp_path = train_config["exp_path"]
    print('exp_path:', exp_path)

    # predefine model
    net = CleanUNet(**network_config)
    print_size(net)

    # load checkpoint
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only load checkpoints that come from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # inference
    batch_size = 1000000  # upper bound on samples per chunk fed to the model
    saved_paths = []
    for file_path in tqdm(paths):
        file_path = os.fspath(file_path)
        # Take the directory from the full path; the directory of a bare
        # file name is always empty.
        file_dir = os.path.dirname(file_path)
        new_file_name = os.path.basename(file_path) + "_denoised.wav"
        save_file = os.path.join(file_dir, new_file_name)

        noisy_audio = load_simple(file_path)
        # Split along dim 1 so long recordings are processed piecewise.
        num_samples = noisy_audio.shape[1]
        chunks = torch.chunk(noisy_audio, num_samples // batch_size + 1, dim=1)

        all_audio = []
        with torch.no_grad():
            for batch in tqdm(chunks):
                generated_audio = sampling(net, batch)
                all_audio.append(generated_audio.cpu().numpy().squeeze())
        all_audio = np.concatenate(all_audio, axis=0)

        # NOTE(review): the output rate is hard-coded and load_simple()
        # discards the input file's rate -- confirm 32000 Hz matches the
        # audio this model actually produces.
        wavwrite(save_file, 32000, all_audio.squeeze())
        print("saved to:", save_file)
        saved_paths.append(save_file)

    return saved_paths[0] if single_input else saved_paths


# Gradio UI wiring.
# Only real components may be listed as inputs; the checkpoint path is not a
# UI component, so denoise() falls back to CHECKPOINT by itself.
# NOTE(review): gr.inputs / gr.outputs and enable_queue are legacy Gradio
# APIs -- confirm the pinned Gradio version still provides them.
audio = gr.inputs.Audio(label = "Audio to denoise", type = 'filepath')
inputs = [audio]
outputs = gr.outputs.Audio(label = "Denoised audio", type = 'filepath')
title = "Speech Denoising in the Waveform Domain with Self-Attention from Nvidia"
gr.Interface(denoise, inputs, outputs, title=title, enable_queue=True).launch()