import os
import argparse
import json
from tqdm import tqdm
from copy import deepcopy

import numpy as np
import torch
import random

# fix random seeds for reproducibility
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

from scipy.io.wavfile import write as wavwrite

from dataset import load_CleanNoisyPairDataset
from util import find_max_epoch, print_size, sampling
from network import CleanUNet


def denoise(output_directory, ckpt_iter, subset, dump=False):
    """
    Denoise audio.

    Parameters:
    output_directory (str):                 save the enhanced (denoised) speech to this path
    ckpt_iter (int, 'max', or 'pretrained'): the pretrained checkpoint to be loaded;
                                            automatically selects the maximum iteration if 'max' is given
    subset (str):                           one of 'training', 'testing', 'validation'
    dump (bool):                            whether to save the enhanced (denoised) audio to disk
    """

    # setup local experiment path
    exp_path = train_config["exp_path"]
    print('exp_path:', exp_path)

    # load data
    loader_config = deepcopy(trainset_config)
    loader_config["crop_length_sec"] = 0
    dataloader = load_CleanNoisyPairDataset(
        **loader_config,
        subset=subset,
        batch_size=1,
        num_gpus=1
    )

    # predefine model
    net = CleanUNet(**network_config).cuda()
    print_size(net)

    # load checkpoint
    ckpt_directory = os.path.join(train_config["log"]["directory"], exp_path, 'checkpoint')
    if ckpt_iter == 'max':
        ckpt_iter = find_max_epoch(ckpt_directory)
    if ckpt_iter != 'pretrained':
        ckpt_iter = int(ckpt_iter)
    model_path = os.path.join(ckpt_directory, '{}.pkl'.format(ckpt_iter))
    checkpoint = torch.load(model_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # get output directory ready
    if ckpt_iter == "pretrained":
        speech_directory = os.path.join(output_directory, exp_path, 'speech', ckpt_iter)
    else:
        speech_directory = os.path.join(output_directory, exp_path, 'speech', '{}k'.format(ckpt_iter // 1000))
    if dump and not os.path.isdir(speech_directory):
        os.makedirs(speech_directory)
        os.chmod(speech_directory, 0o775)
    print("speech_directory: ", speech_directory, flush=True)

    # inference
    all_generated_audio = []
    all_clean_audio = []
    # basename with its first underscore-delimited token removed,
    # e.g. 'clean_fileid_0.wav' -> 'fileid_0.wav'
    sortkey = lambda name: '_'.join(name.split('/')[-1].split('_')[1:])
    for clean_audio, noisy_audio, fileid in tqdm(dataloader):
        filename = sortkey(fileid[0][0])
        noisy_audio = noisy_audio.cuda()
        generated_audio = sampling(net, noisy_audio)
        if dump:
            wavwrite(os.path.join(speech_directory, 'enhanced_{}'.format(filename)),
                     trainset_config["sample_rate"],
                     generated_audio[0].squeeze().cpu().numpy())
        else:
            all_clean_audio.append(clean_audio[0].squeeze().cpu().numpy())
            all_generated_audio.append(generated_audio[0].squeeze().cpu().numpy())

    return all_clean_audio, all_generated_audio


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='config.json',
                        help='JSON file for configuration')
    parser.add_argument('-ckpt_iter', '--ckpt_iter', default='max',
                        help='which checkpoint to use; assign a number or "max" or "pretrained"')
    parser.add_argument('-subset', '--subset', type=str, choices=['training', 'testing', 'validation'],
                        default='testing', help='subset for denoising')
    args = parser.parse_args()
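
    # Example invocation (assuming this script is saved as denoise.py; the file
    # name and flag values below are illustrative, not prescribed by the code):
    #   python denoise.py --config config.json --ckpt_iter max --subset testing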

    # Parse configs; keeping them as module-level globals is convenient here
    with open(args.config) as f:
        data = f.read()
    config = json.loads(data)
    gen_config = config["gen_config"]
    network_config = config["network_config"]      # to define the CleanUNet model
    train_config = config["train_config"]          # training configuration (paths, logging)
    trainset_config = config["trainset_config"]    # to load the dataset

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    if args.subset == "testing":
        denoise(gen_config["output_directory"],
                subset=args.subset,
                ckpt_iter=args.ckpt_iter,
                dump=True)
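

# For reference, a minimal sketch of the config.json layout this script reads.
# Only keys referenced above are listed; the values are illustrative assumptions,
# not the official configuration shipped with the model:
#
# {
#     "gen_config":      {"output_directory": "./exp"},
#     "network_config":  {"...": "hyperparameters passed to CleanUNet(**network_config)"},
#     "train_config":    {"exp_path": "exp1", "log": {"directory": "./logs"}},
#     "trainset_config": {"crop_length_sec": 10, "sample_rate": 16000,
#                         "...": "remaining dataset arguments for load_CleanNoisyPairDataset"}
# }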