Spaces:
Sleeping
Sleeping
| import sys,os | |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
| import torch | |
| import argparse | |
| from omegaconf import OmegaConf | |
| from vits.models import SynthesizerInfer | |
def load_model(checkpoint_path, model):
    """Copy generator weights from a checkpoint file into ``model``.

    The checkpoint is read with ``torch.load`` onto the CPU and its
    ``"model_g"`` entry is used as the source of weights. Every key of the
    model's own state dict is taken from the checkpoint when present; keys
    the checkpoint lacks keep the model's current value, so a partial
    checkpoint still loads. The number of such missing keys is printed so
    a config/checkpoint mismatch does not go unnoticed.

    Args:
        checkpoint_path: Path to an existing checkpoint file.
        model: Module to update in place. If it exposes a ``module``
            attribute (e.g. a DataParallel-style wrapper), the wrapped
            module is the one that is read and updated.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
        KeyError: If the checkpoint has no ``"model_g"`` entry.
    """
    assert os.path.isfile(checkpoint_path), checkpoint_path
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    saved_state_dict = checkpoint_dict["model_g"]

    # Unwrap once instead of branching on hasattr() for both read and write.
    target = model.module if hasattr(model, "module") else model

    new_state_dict = {}
    missing_keys = []
    for k, v in target.state_dict().items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            # Only a genuinely absent key falls back to the model's value;
            # any other error (e.g. a malformed checkpoint) now propagates.
            new_state_dict[k] = v
            missing_keys.append(k)

    if missing_keys:
        print(f"{len(missing_keys)} key(s) not found in checkpoint; "
              f"keeping the model's current values for them")

    target.load_state_dict(new_state_dict)
    return model
def save_pretrain(checkpoint_path, save_path):
    """Write a copy of a checkpoint that keeps only two of its entries.

    The checkpoint at ``checkpoint_path`` is loaded onto the CPU, and a new
    file containing just its ``'model_g'`` and ``'model_d'`` entries is
    saved to ``save_path``.

    Args:
        checkpoint_path: Path to an existing checkpoint file.
        save_path: Destination path for the reduced checkpoint.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
        KeyError: If either entry is absent from the checkpoint.
    """
    assert os.path.isfile(checkpoint_path)
    full_checkpoint = torch.load(checkpoint_path, map_location="cpu")
    kept_entries = ('model_g', 'model_d')
    reduced = {entry: full_checkpoint[entry] for entry in kept_entries}
    torch.save(reduced, save_path)
def save_model(model, checkpoint_path):
    """Save the model's state dict to ``checkpoint_path`` under ``'model_g'``.

    Args:
        model: Object providing ``state_dict()``. If it has a ``module``
            attribute, the state dict of that wrapped object is saved
            instead.
        checkpoint_path: Destination path passed to ``torch.save``.
    """
    source = model.module if hasattr(model, 'module') else model
    payload = {'model_g': source.state_dict()}
    torch.save(payload, checkpoint_path)
def main(args):
    """Build the inference model, load checkpoint weights, and export them.

    Loads the YAML config named by ``args.config``, constructs a
    ``SynthesizerInfer`` from it, fills it with the weights found at
    ``args.checkpoint_path`` and writes the result to ``"sovits5.0.pth"``
    in the current working directory.

    Args:
        args: Parsed command-line namespace with ``config`` and
            ``checkpoint_path`` attributes.
    """
    hp = OmegaConf.load(args.config)

    # Positional constructor arguments derived from the data config.
    # NOTE(review): these look like the spectrogram bin count and the
    # segment length in frames - confirm against SynthesizerInfer.
    first_arg = hp.data.filter_length // 2 + 1
    second_arg = hp.data.segment_size // hp.data.hop_length
    model = SynthesizerInfer(first_arg, second_arg, hp)

    # save_pretrain(args.checkpoint_path, "sovits5.0.pretrain.pth")
    load_model(args.checkpoint_path, model)
    save_model(model, "sovits5.0.pth")
if __name__ == '__main__':
    # Command-line entry point. Both options are mandatory, so the help
    # text must not promise a fallback when --config is omitted.
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, required=True,
                        help="yaml file for config.")
    parser.add_argument('-p', '--checkpoint_path', type=str, required=True,
                        help="path of checkpoint pt file for evaluation")
    args = parser.parse_args()
    main(args)