Spaces:
Configuration error
Configuration error
Zhikang Niu
committed on
Update inference-cli.py: add option to load Vocos from a local path
Browse files- inference-cli.py +35 -21
inference-cli.py
CHANGED
@@ -1,26 +1,24 @@
|
|
|
|
|
|
1 |
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch
|
3 |
import torchaudio
|
4 |
-
import
|
5 |
-
import
|
6 |
from einops import rearrange
|
7 |
-
from vocos import Vocos
|
8 |
from pydub import AudioSegment, silence
|
9 |
-
from model import CFM, UNetT, DiT, MMDiT
|
10 |
-
from cached_path import cached_path
|
11 |
-
from model.utils import (
|
12 |
-
load_checkpoint,
|
13 |
-
get_tokenizer,
|
14 |
-
convert_char_to_pinyin,
|
15 |
-
save_spectrogram,
|
16 |
-
)
|
17 |
from transformers import pipeline
|
18 |
-
|
19 |
-
|
20 |
-
import
|
21 |
-
import
|
22 |
-
|
23 |
-
import codecs
|
24 |
|
25 |
parser = argparse.ArgumentParser(
|
26 |
prog="python3 inference-cli.py",
|
@@ -73,6 +71,11 @@ parser.add_argument(
|
|
73 |
"--remove_silence",
|
74 |
help="Remove silence.",
|
75 |
)
|
|
|
|
|
|
|
|
|
|
|
76 |
args = parser.parse_args()
|
77 |
|
78 |
config = tomli.load(open(args.config, "rb"))
|
@@ -88,6 +91,7 @@ model = args.model if args.model else config["model"]
|
|
88 |
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
89 |
wave_path = Path(output_dir)/"out.wav"
|
90 |
spectrogram_path = Path(output_dir)/"out.png"
|
|
|
91 |
|
92 |
SPLIT_WORDS = [
|
93 |
"but", "however", "nevertheless", "yet", "still",
|
@@ -105,7 +109,16 @@ device = (
|
|
105 |
if torch.cuda.is_available()
|
106 |
else "mps" if torch.backends.mps.is_available() else "cpu"
|
107 |
)
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
print(f"Using {device} device")
|
111 |
|
@@ -124,8 +137,9 @@ speed = 1.0
|
|
124 |
fix_duration = None
|
125 |
|
126 |
def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
|
127 |
-
ckpt_path =
|
128 |
-
|
|
|
129 |
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
|
130 |
model = CFM(
|
131 |
transformer=model_cls(
|
@@ -385,4 +399,4 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
|
|
385 |
return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
|
386 |
|
387 |
|
388 |
-
infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
|
|
|
1 |
+
import argparse
|
2 |
+
import codecs
|
3 |
import re
|
4 |
+
import tempfile
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import soundfile as sf
|
9 |
+
import tomli
|
10 |
import torch
|
11 |
import torchaudio
|
12 |
+
import tqdm
|
13 |
+
from cached_path import cached_path
|
14 |
from einops import rearrange
|
|
|
15 |
from pydub import AudioSegment, silence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
from transformers import pipeline
|
17 |
+
from vocos import Vocos
|
18 |
+
|
19 |
+
from model import CFM, DiT, MMDiT, UNetT
|
20 |
+
from model.utils import (convert_char_to_pinyin, get_tokenizer,
|
21 |
+
load_checkpoint, save_spectrogram)
|
|
|
22 |
|
23 |
parser = argparse.ArgumentParser(
|
24 |
prog="python3 inference-cli.py",
|
|
|
71 |
"--remove_silence",
|
72 |
help="Remove silence.",
|
73 |
)
|
74 |
+
parser.add_argument(
|
75 |
+
"--load_vocoder_from_local",
|
76 |
+
action="store_true",
|
77 |
+
help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
|
78 |
+
)
|
79 |
args = parser.parse_args()
|
80 |
|
81 |
config = tomli.load(open(args.config, "rb"))
|
|
|
91 |
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
92 |
wave_path = Path(output_dir)/"out.wav"
|
93 |
spectrogram_path = Path(output_dir)/"out.png"
|
94 |
+
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
95 |
|
96 |
SPLIT_WORDS = [
|
97 |
"but", "however", "nevertheless", "yet", "still",
|
|
|
109 |
if torch.cuda.is_available()
|
110 |
else "mps" if torch.backends.mps.is_available() else "cpu"
|
111 |
)
|
112 |
+
|
113 |
+
if args.load_vocoder_from_local:
|
114 |
+
print(f"Load vocos from local path {vocos_local_path}")
|
115 |
+
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
116 |
+
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
|
117 |
+
vocos.load_state_dict(state_dict)
|
118 |
+
vocos.eval()
|
119 |
+
else:
|
120 |
+
print("Donwload Vocos from huggingface charactr/vocos-mel-24khz")
|
121 |
+
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
122 |
|
123 |
print(f"Using {device} device")
|
124 |
|
|
|
137 |
fix_duration = None
|
138 |
|
139 |
def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
|
140 |
+
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
|
141 |
+
if not Path(ckpt_path).exists():
|
142 |
+
ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
143 |
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
|
144 |
model = CFM(
|
145 |
transformer=model_cls(
|
|
|
399 |
return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
|
400 |
|
401 |
|
402 |
+
infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
|