Spaces:
Running
Running
| # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import argparse | |
| import json | |
| import re | |
| import sys | |
| import torch | |
| from common.text.symbols import get_symbols, get_pad_idx | |
| from common.utils import DefaultAttrDict, AttrDict | |
| from fastpitch.model import FastPitch | |
| from fastpitch.model_jit import FastPitchJIT | |
| from hifigan.models import Generator | |
# WaveGlow support is optional; importing it may fail if the package is absent.
try:
    from waveglow.model import WaveGlow
    from waveglow import model as glow
    from waveglow.denoiser import Denoiser
    # Alias the module under the bare name 'glow' so that torch.load can
    # unpickle checkpoints whose classes were pickled as `glow.*` —
    # presumably saved by an older layout of this project; verify if changed.
    sys.modules['glow'] = glow
except ImportError:
    print("WARNING: Couldn't import WaveGlow")
def parse_model_args(model_name, parser, add_help=False):
    """Extend `parser` with the command-line arguments of `model_name`.

    Imports the per-model argument parser lazily so that optional model
    packages are only required when actually selected.

    Raises:
        NotImplementedError: for an unknown `model_name`.
    """
    if model_name == 'FastPitch':
        from fastpitch import arg_parser
        return arg_parser.parse_fastpitch_args(parser, add_help)

    if model_name == 'HiFi-GAN':
        from hifigan import arg_parser
        return arg_parser.parse_hifigan_args(parser, add_help)

    if model_name == 'WaveGlow':
        from waveglow.arg_parser import parse_waveglow_args
        return parse_waveglow_args(parser, add_help)

    raise NotImplementedError(model_name)
def get_model(model_name, model_config, device, bn_uniform_init=False,
              forward_is_infer=False, jitable=False):
    """Instantiate the model named `model_name` and move it to `device`."""
    # Kept only for backward compatibility (old name:
    # uniform_initialize_bn_weight); no longer used.
    del bn_uniform_init

    if model_name == 'FastPitch':
        fastpitch_cls = FastPitchJIT if jitable else FastPitch
        model = fastpitch_cls(**model_config)
    elif model_name == 'HiFi-GAN':
        # Generator takes the whole config object rather than kwargs
        model = Generator(model_config)
    elif model_name == 'WaveGlow':
        model = WaveGlow(**model_config)
    else:
        raise NotImplementedError(model_name)

    # Expose `infer` as `forward` so generic code can just call model(...)
    if forward_is_infer and hasattr(model, 'infer'):
        model.forward = model.infer

    return model.to(device)
def get_model_config(model_name, args, ckpt_config=None):
    """Assemble the constructor kwargs needed to instantiate `model_name`.

    `ckpt_config` (a config dict saved inside a checkpoint) has priority;
    values derived from `args` only fill the keys it does not define.

    Args:
        model_name: 'FastPitch', 'HiFi-GAN' or 'WaveGlow'.
        args: argparse.Namespace with a (possibly incomplete) set of
            model arguments.
        ckpt_config: optional config dict loaded from a checkpoint.

    Returns:
        dict of kwargs ready to pass to the model constructor.

    Raises:
        NotImplementedError: for an unknown `model_name`.
        AssertionError: if the merged config still contains unset values,
            or if both --hifigan-config and a checkpoint config are given.
    """
    # Mark keys missing in `args` with a unique sentinel (None is ambiguous:
    # it may be a legitimate argument value)
    _missing = object()
    args = DefaultAttrDict(lambda: _missing, vars(args))

    if model_name == 'FastPitch':
        model_config = dict(
            # io
            n_mel_channels=args.n_mel_channels,
            # symbols
            n_symbols=(len(get_symbols(args.symbol_set))
                       if args.symbol_set is not _missing else _missing),
            padding_idx=(get_pad_idx(args.symbol_set)
                         if args.symbol_set is not _missing else _missing),
            symbols_embedding_dim=args.symbols_embedding_dim,
            # input FFT
            in_fft_n_layers=args.in_fft_n_layers,
            in_fft_n_heads=args.in_fft_n_heads,
            in_fft_d_head=args.in_fft_d_head,
            in_fft_conv1d_kernel_size=args.in_fft_conv1d_kernel_size,
            in_fft_conv1d_filter_size=args.in_fft_conv1d_filter_size,
            in_fft_output_size=args.in_fft_output_size,
            p_in_fft_dropout=args.p_in_fft_dropout,
            p_in_fft_dropatt=args.p_in_fft_dropatt,
            p_in_fft_dropemb=args.p_in_fft_dropemb,
            # output FFT
            out_fft_n_layers=args.out_fft_n_layers,
            out_fft_n_heads=args.out_fft_n_heads,
            out_fft_d_head=args.out_fft_d_head,
            out_fft_conv1d_kernel_size=args.out_fft_conv1d_kernel_size,
            out_fft_conv1d_filter_size=args.out_fft_conv1d_filter_size,
            out_fft_output_size=args.out_fft_output_size,
            p_out_fft_dropout=args.p_out_fft_dropout,
            p_out_fft_dropatt=args.p_out_fft_dropatt,
            p_out_fft_dropemb=args.p_out_fft_dropemb,
            # duration predictor
            dur_predictor_kernel_size=args.dur_predictor_kernel_size,
            dur_predictor_filter_size=args.dur_predictor_filter_size,
            p_dur_predictor_dropout=args.p_dur_predictor_dropout,
            dur_predictor_n_layers=args.dur_predictor_n_layers,
            # pitch predictor
            pitch_predictor_kernel_size=args.pitch_predictor_kernel_size,
            pitch_predictor_filter_size=args.pitch_predictor_filter_size,
            p_pitch_predictor_dropout=args.p_pitch_predictor_dropout,
            pitch_predictor_n_layers=args.pitch_predictor_n_layers,
            # pitch conditioning
            pitch_embedding_kernel_size=args.pitch_embedding_kernel_size,
            # speakers parameters
            n_speakers=args.n_speakers,
            speaker_emb_weight=args.speaker_emb_weight,
            n_languages=args.n_languages,
            # energy predictor
            energy_predictor_kernel_size=args.energy_predictor_kernel_size,
            energy_predictor_filter_size=args.energy_predictor_filter_size,
            p_energy_predictor_dropout=args.p_energy_predictor_dropout,
            energy_predictor_n_layers=args.energy_predictor_n_layers,
            # energy conditioning
            energy_conditioning=args.energy_conditioning,
            energy_embedding_kernel_size=args.energy_embedding_kernel_size,
        )

    elif model_name == 'HiFi-GAN':
        # Guard against the sentinel too: if --hifigan-config was never
        # defined, fall through and let the final assert report it clearly
        # instead of crashing inside open().
        if args.hifigan_config is not None and args.hifigan_config is not _missing:
            assert ckpt_config is None, (
                "Supplied --hifigan-config, but the checkpoint has a config. "
                "Drop the flag or remove the config from the checkpoint file.")
            print(f'HiFi-GAN: Reading model config from {args.hifigan_config}')
            with open(args.hifigan_config) as f:
                # NOTE: replaces `args`; keys absent from the JSON file will
                # raise on attribute access instead of yielding the sentinel
                args = AttrDict(json.load(f))

        model_config = dict(
            # generator architecture
            upsample_rates=args.upsample_rates,
            upsample_kernel_sizes=args.upsample_kernel_sizes,
            upsample_initial_channel=args.upsample_initial_channel,
            resblock=args.resblock,
            resblock_kernel_sizes=args.resblock_kernel_sizes,
            resblock_dilation_sizes=args.resblock_dilation_sizes,
        )

    elif model_name == 'WaveGlow':
        model_config = dict(
            n_mel_channels=args.n_mel_channels,
            n_flows=args.flows,
            n_group=args.groups,
            n_early_every=args.early_every,
            n_early_size=args.early_size,
            WN_config=dict(
                n_layers=args.wn_layers,
                kernel_size=args.wn_kernel_size,
                n_channels=args.wn_channels
            )
        )

    else:
        raise NotImplementedError(model_name)

    # Start with ckpt_config, and fill missing keys from model_config
    final_config = {} if ckpt_config is None else ckpt_config.copy()
    missing_keys = set(model_config.keys()) - set(final_config.keys())
    final_config.update({k: model_config[k] for k in missing_keys})

    # If there was a ckpt_config, it should have had all args
    if ckpt_config is not None and len(missing_keys) > 0:
        print(f'WARNING: Keys {missing_keys} missing from the loaded config; '
              'using args instead.')

    # Every slot must be filled either from the checkpoint or from args;
    # name the offenders so the failure is actionable.
    unset_keys = [k for k, v in final_config.items() if v is _missing]
    assert not unset_keys, f'Model config keys left unset: {unset_keys}'
    return final_config
def get_model_train_setup(model_name, args):
    """Dump train setup for documentation purposes."""
    if model_name == 'FastPitch':
        return {}

    if model_name == 'WaveGlow':
        return {}

    if model_name == 'HiFi-GAN':
        return {
            # audio
            'segment_size': args.segment_size,
            'filter_length': args.filter_length,
            'num_mels': args.num_mels,
            'hop_length': args.hop_length,
            'win_length': args.win_length,
            'sampling_rate': args.sampling_rate,
            'mel_fmin': args.mel_fmin,
            'mel_fmax': args.mel_fmax,
            'mel_fmax_loss': args.mel_fmax_loss,
            'max_wav_value': args.max_wav_value,
            # other
            'seed': args.seed,
            # optimization
            'base_lr': args.learning_rate,
            'lr_decay': args.lr_decay,
            'epochs_all': args.epochs,
        }

    raise NotImplementedError(model_name)
def load_model_from_ckpt(checkpoint_data, model, key='state_dict'):
    """Load weights from `checkpoint_data[key]` into `model`.

    Args:
        checkpoint_data: dict deserialized from a checkpoint file.
        model: model instance to load the weights into.
        key: key under which the state dict is stored. If None, the whole
            pickled model stored under 'model' is returned instead.

    Returns:
        (model, status) tuple; `status` is the value returned by
        `model.load_state_dict` (None when `key` is None).
    """
    if key is None:
        return checkpoint_data['model'], None

    sd = checkpoint_data[key]
    # Strip the 'module.' prefix that (Distributed)DataParallel prepends;
    # raw string avoids the invalid-escape warning of '^module\.'
    sd = {re.sub(r'^module\.', '', k): v for k, v in sd.items()}
    status = model.load_state_dict(sd, strict=False)
    return model, status
def load_and_setup_model(model_name, parser, checkpoint, amp, device,
                         unk_args=None, forward_is_infer=False, jitable=False):
    """Build a model, optionally restore a checkpoint, and prep for inference.

    Args:
        model_name: 'FastPitch', 'HiFi-GAN' or 'WaveGlow'.
        parser: argparse parser extended with model-specific arguments.
        checkpoint: path to a checkpoint file, or None to build from args only.
        amp: cast the model to FP16 when True.
        device: device the checkpoint is mapped to and the model moved to.
        unk_args: optional list updated IN PLACE to the subset of its entries
            still unrecognized by the model parser.
        forward_is_infer: alias `model.infer` as `model.forward`.
        jitable: build the TorchScript-compatible FastPitch variant.

    Returns:
        (model, model_config, train_setup) tuple; `train_setup` comes from
        the checkpoint, or is {} when absent.
    """
    # Was a mutable default argument ([]) shared across calls; use a fresh
    # list per call instead. Callers passing a list still see it updated.
    if unk_args is None:
        unk_args = []

    if checkpoint is not None:
        print(f'{model_name}: Loading {checkpoint}...')
        ckpt_data = torch.load(checkpoint, map_location=device)

        ckpt_config = ckpt_data.get('config')
        if ckpt_config is None:
            print(f'{model_name}: No model config in the checkpoint; using args.')
        else:
            print(f'{model_name}: Found model config saved in the checkpoint.')
    else:
        ckpt_config = None
        ckpt_data = {}

    model_parser = parse_model_args(model_name, parser, add_help=False)
    model_args, model_unk_args = model_parser.parse_known_args()
    unk_args[:] = list(set(unk_args) & set(model_unk_args))

    model_config = get_model_config(model_name, model_args, ckpt_config)
    model = get_model(model_name, model_config, device,
                      forward_is_infer=forward_is_infer,
                      jitable=jitable)

    if checkpoint is not None:
        # HiFi-GAN checkpoints store the generator weights under 'generator'
        key = 'generator' if model_name == 'HiFi-GAN' else 'state_dict'
        model, status = load_model_from_ckpt(ckpt_data, model, key)

        missing = [] if status is None else status.missing_keys
        unexpected = [] if status is None else status.unexpected_keys

        # Attention is only used during training, we won't miss it
        if model_name == 'FastPitch':
            missing = [k for k in missing if not k.startswith('attention.')]
            unexpected = [k for k in unexpected
                          if not k.startswith('attention.')]

        assert len(missing) == 0 and len(unexpected) == 0, (
            f'Mismatched keys when loading parameters. Missing: {missing}, '
            f'unexpected: {unexpected}.')

    if model_name == "WaveGlow":
        for k, m in model.named_modules():
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatability
        model = model.remove_weightnorm(model)

    elif model_name == 'HiFi-GAN':
        assert model_args.hifigan_config is not None or ckpt_config is not None, (
            'Use a HiFi-GAN checkpoint from NVIDIA DeepLearningExamples with '
            'saved config or supply --hifigan-config <json_file>.')
        # Weight norm is a training-time reparametrization; fold it away
        model.remove_weight_norm()

    if amp:
        model.half()
    model.eval()
    return model.to(device), model_config, ckpt_data.get('train_setup', {})
def load_and_setup_ts_model(model_name, checkpoint, amp, device=None):
    """Load a TorchScript model and prepare it for inference.

    Args:
        model_name: model name, used only for logging.
        checkpoint: path or file-like object holding a TorchScript archive.
        amp: cast the model to FP16 when True.
        device: optional device to move the model to.

    Returns:
        The loaded model in eval mode.

    Raises:
        ValueError: when `amp` is False but the checkpoint holds FP16 weights.
    """
    print(f'{model_name}: Loading TorchScript checkpoint {checkpoint}...')
    model = torch.jit.load(checkpoint).eval()

    if device is not None:
        model = model.to(device)

    if amp:
        model.half()
    elif next(model.parameters()).dtype == torch.float16:
        # Fixed message: the two literals previously concatenated without
        # a space ("model,TS checkpoint")
        raise ValueError('Trying to load FP32 model, '
                         'TS checkpoint is in FP16 precision.')
    return model
def convert_ts_to_trt(model_name, ts_model, parser, amp, unk_args=[]):
    """Compile a TorchScript model with Torch-TensorRT.

    Only HiFi-GAN is currently supported. `unk_args` is updated in place
    to the subset of its entries the TRT parser did not recognize.
    """
    trt_parser = _parse_trt_compilation_args(model_name, parser, add_help=False)
    trt_args, trt_unk_args = trt_parser.parse_known_args()
    unk_args[:] = list(set(unk_args) & set(trt_unk_args))

    if model_name != 'HiFi-GAN':
        raise NotImplementedError

    return _convert_ts_to_trt_hifigan(
        ts_model, amp, trt_args.trt_min_opt_max_batch,
        trt_args.trt_min_opt_max_hifigan_length)
| def _parse_trt_compilation_args(model_name, parent, add_help=False): | |
| """ | |
| Parse model and inference specific commandline arguments. | |
| """ | |
| parser = argparse.ArgumentParser(parents=[parent], add_help=add_help, | |
| allow_abbrev=False) | |
| trt = parser.add_argument_group(f'{model_name} Torch-TensorRT compilation parameters') | |
| trt.add_argument('--trt-min-opt-max-batch', nargs=3, type=int, | |
| default=(1, 8, 16), | |
| help='Torch-TensorRT min, optimal and max batch size') | |
| if model_name == 'HiFi-GAN': | |
| trt.add_argument('--trt-min-opt-max-hifigan-length', nargs=3, type=int, | |
| default=(100, 800, 1200), | |
| help='Torch-TensorRT min, optimal and max audio length (in frames)') | |
| return parser | |
def _convert_ts_to_trt_hifigan(ts_model, amp, trt_min_opt_max_batch,
                               trt_min_opt_max_hifigan_length, num_mels=80):
    """Compile a TorchScript HiFi-GAN generator with Torch-TensorRT.

    The input spec is (batch, num_mels, length); the min/opt/max values
    for batch and length come from the two 3-tuples, while the mel axis
    is fixed at `num_mels`.
    """
    import torch_tensorrt

    trt_dtype = torch.half if amp else torch.float
    print(f'Torch TensorRT: compiling HiFi-GAN for dtype {trt_dtype}.')

    mel_dims = (num_mels,) * 3
    min_shp, opt_shp, max_shp = zip(trt_min_opt_max_batch, mel_dims,
                                    trt_min_opt_max_hifigan_length)
    input_spec = torch_tensorrt.Input(min_shape=min_shp,
                                      opt_shape=opt_shp,
                                      max_shape=max_shp,
                                      dtype=trt_dtype)
    trt_model = torch_tensorrt.compile(ts_model,
                                       inputs=[input_spec],
                                       enabled_precisions={trt_dtype},
                                       require_full_compilation=True)
    print('Torch TensorRT: compilation successful.')
    return trt_model