Zhikang Niu committed on
Commit
830a2fe
·
unverified ·
1 Parent(s): 2d26fba

Update inference-cli.py: add support for loading Vocos from a local path

Browse files
Files changed (1) hide show
  1. inference-cli.py +35 -21
inference-cli.py CHANGED
@@ -1,26 +1,24 @@
 
 
1
  import re
 
 
 
 
 
 
2
  import torch
3
  import torchaudio
4
- import numpy as np
5
- import tempfile
6
  from einops import rearrange
7
- from vocos import Vocos
8
  from pydub import AudioSegment, silence
9
- from model import CFM, UNetT, DiT, MMDiT
10
- from cached_path import cached_path
11
- from model.utils import (
12
- load_checkpoint,
13
- get_tokenizer,
14
- convert_char_to_pinyin,
15
- save_spectrogram,
16
- )
17
  from transformers import pipeline
18
- import soundfile as sf
19
- import tomli
20
- import argparse
21
- import tqdm
22
- from pathlib import Path
23
- import codecs
24
 
25
  parser = argparse.ArgumentParser(
26
  prog="python3 inference-cli.py",
@@ -73,6 +71,11 @@ parser.add_argument(
73
  "--remove_silence",
74
  help="Remove silence.",
75
  )
 
 
 
 
 
76
  args = parser.parse_args()
77
 
78
  config = tomli.load(open(args.config, "rb"))
@@ -88,6 +91,7 @@ model = args.model if args.model else config["model"]
88
  remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
89
  wave_path = Path(output_dir)/"out.wav"
90
  spectrogram_path = Path(output_dir)/"out.png"
 
91
 
92
  SPLIT_WORDS = [
93
  "but", "however", "nevertheless", "yet", "still",
@@ -105,7 +109,16 @@ device = (
105
  if torch.cuda.is_available()
106
  else "mps" if torch.backends.mps.is_available() else "cpu"
107
  )
108
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
 
 
 
 
 
 
 
 
109
 
110
  print(f"Using {device} device")
111
 
@@ -124,8 +137,9 @@ speed = 1.0
124
  fix_duration = None
125
 
126
  def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
127
- ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
128
- # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
 
129
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
130
  model = CFM(
131
  transformer=model_cls(
@@ -385,4 +399,4 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
385
  return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
386
 
387
 
388
- infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
 
1
+ import argparse
2
+ import codecs
3
  import re
4
+ import tempfile
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import soundfile as sf
9
+ import tomli
10
  import torch
11
  import torchaudio
12
+ import tqdm
13
+ from cached_path import cached_path
14
  from einops import rearrange
 
15
  from pydub import AudioSegment, silence
 
 
 
 
 
 
 
 
16
  from transformers import pipeline
17
+ from vocos import Vocos
18
+
19
+ from model import CFM, DiT, MMDiT, UNetT
20
+ from model.utils import (convert_char_to_pinyin, get_tokenizer,
21
+ load_checkpoint, save_spectrogram)
 
22
 
23
  parser = argparse.ArgumentParser(
24
  prog="python3 inference-cli.py",
 
71
  "--remove_silence",
72
  help="Remove silence.",
73
  )
74
+ parser.add_argument(
75
+ "--load_vocoder_from_local",
76
+ action="store_true",
77
+ help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
78
+ )
79
  args = parser.parse_args()
80
 
81
  config = tomli.load(open(args.config, "rb"))
 
91
  remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
92
  wave_path = Path(output_dir)/"out.wav"
93
  spectrogram_path = Path(output_dir)/"out.png"
94
+ vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
95
 
96
  SPLIT_WORDS = [
97
  "but", "however", "nevertheless", "yet", "still",
 
109
  if torch.cuda.is_available()
110
  else "mps" if torch.backends.mps.is_available() else "cpu"
111
  )
112
+
113
+ if args.load_vocoder_from_local:
114
+ print(f"Load vocos from local path {vocos_local_path}")
115
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
116
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
117
+ vocos.load_state_dict(state_dict)
118
+ vocos.eval()
119
+ else:
120
+ print("Donwload Vocos from huggingface charactr/vocos-mel-24khz")
121
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
122
 
123
  print(f"Using {device} device")
124
 
 
137
  fix_duration = None
138
 
139
  def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
140
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
141
+ if not Path(ckpt_path).exists():
142
+ ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
143
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
144
  model = CFM(
145
  transformer=model_cls(
 
399
  return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
400
 
401
 
402
+ infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))