refactor: del global params and set vocos as default vocoder, add dtype check
Files changed:
- README.md +1 -1
- src/f5_tts/eval/eval_infer_batch.py +5 -4
- src/f5_tts/infer/infer_cli.py +2 -2
- src/f5_tts/infer/speech_edit.py +3 -3
- src/f5_tts/infer/utils_infer.py +17 -7
README.md
CHANGED
@@ -45,7 +45,7 @@ git clone https://github.com/SWivid/F5-TTS.git
 cd F5-TTS
 pip install -e .
 
-# Init submodule(optional, if you want to change the vocoder from vocos to bigvgan)
+# Init submodule (optional, if you want to change the vocoder from vocos to bigvgan)
 # git submodule update --init --recursive
 # pip install -e .
 ```
src/f5_tts/eval/eval_infer_batch.py
CHANGED
@@ -32,7 +32,6 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
 target_rms = 0.1
 
 
@@ -49,6 +48,7 @@ def main():
     parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
     parser.add_argument("-n", "--expname", required=True)
     parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
+    parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
 
     parser.add_argument("-nfe", "--nfestep", default=32, type=int)
     parser.add_argument("-o", "--odemethod", default="euler")
@@ -63,6 +63,7 @@ def main():
     exp_name = args.expname
     ckpt_step = args.ckptstep
     ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
+    mel_spec_type = args.mel_spec_type
 
     nfe_step = args.nfestep
     ode_method = args.odemethod
@@ -101,7 +102,7 @@ def main():
     output_dir = (
         f"{rel_path}/"
         f"results/{exp_name}_{ckpt_step}/{testset}/"
-        f"seed{seed}_{ode_method}_nfe{nfe_step}"
+        f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
         f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
         f"_cfg{cfg_strength}_speed{speed}"
         f"{'_gt-dur' if use_truth_duration else ''}"
@@ -155,10 +156,10 @@ def main():
     supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
     if supports_fp16 and mel_spec_type == "vocos":
         dtype = torch.float16
-    else:
+    elif mel_spec_type == "bigvgan":
         dtype = torch.float32
 
-    model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
+    model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
     if not os.path.exists(output_dir) and accelerator.is_main_process:
         os.makedirs(output_dir)
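In short: the eval script now takes the vocoder as a CLI flag (-m/--mel_spec_type, default vocos), records it in the results directory name, and picks precision per vocoder: fp16 only for vocos on a CUDA device with compute capability 6 or newer, fp32 for bigvgan. A minimal sketch of that selection rule, assuming only PyTorch; the helper name pick_dtype is ours, not the repo's, and unlike the if/elif above it falls back to fp32 for unknown vocoder names:

import torch

def pick_dtype(device: str, mel_spec_type: str) -> torch.dtype:
    # Hypothetical helper mirroring the hunk above: fp16 requires a CUDA device
    # with compute capability >= 6 and the vocos vocoder; bigvgan stays in fp32.
    supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
    if supports_fp16 and mel_spec_type == "vocos":
        return torch.float16
    return torch.float32  # total fallback; the original if/elif binds nothing for other names

device = "cuda" if torch.cuda.is_available() else "cpu"
print(pick_dtype(device, "vocos"), pick_dtype(device, "bigvgan"))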
src/f5_tts/infer/infer_cli.py
CHANGED
@@ -154,7 +154,7 @@ elif model == "E2-TTS":
 
 
 print(f"Using {model}...")
-ema_model = load_model(model_cls, model_cfg, ckpt_file, args.vocoder_name, vocab_file)
+ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=args.vocoder_name, vocab_file=vocab_file)
 
 
 def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
@@ -192,7 +192,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
         ref_text = voices[voice]["ref_text"]
         print(f"Voice: {voice}")
         audio, final_sample_rate, spectragram = infer_process(
-            ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type, speed=speed
+            ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed
        )
         generated_audio_segments.append(audio)
 
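The switch from positional to keyword arguments makes these growing signatures safer to call. A minimal usage sketch in the new style, assuming the F5-TTS Base DiT config defined in infer_cli.py and a pretrained vocos vocoder; the checkpoint path, reference audio, and texts are placeholders, and the reference pair is assumed to be preprocessed first as infer_cli does:

from vocos import Vocos
from f5_tts.model import DiT
from f5_tts.infer.utils_infer import load_model, infer_process

# F5-TTS Base transformer config, as defined in infer_cli.py
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
ema_model = load_model(DiT, model_cfg, "ckpts/model_1200000.pt", mel_spec_type="vocos", vocab_file="")

vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")  # assumed vocoder checkpoint
audio, sample_rate, _ = infer_process(
    "ref.wav", "reference transcript", "text to synthesize",
    ema_model, vocoder, mel_spec_type="vocos", speed=1.0,
)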
src/f5_tts/infer/speech_edit.py
CHANGED
@@ -18,7 +18,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
+mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
 target_rms = 0.1
 
 tokenizer = "pinyin"
@@ -114,10 +114,10 @@ model = CFM(
 supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
 if supports_fp16 and mel_spec_type == "vocos":
     dtype = torch.float16
-else:
+elif mel_spec_type == "bigvgan":
     dtype = torch.float32
 
-model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
+model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
 # Audio
 audio, sr = torchaudio.load(audio_to_edit)
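One consequence of the new elif over a bare else: with an unsupported vocoder name, dtype is never bound, so the subsequent load_checkpoint call raises a NameError instead of silently choosing a precision. A tiny self-contained illustration; the vocoder name here is a deliberately invalid placeholder:

import torch

mel_spec_type = "hifigan"  # hypothetical unsupported value
supports_fp16 = torch.cuda.is_available() and torch.cuda.get_device_properties("cuda").major >= 6

if supports_fp16 and mel_spec_type == "vocos":
    dtype = torch.float16
elif mel_spec_type == "bigvgan":
    dtype = torch.float32

print(dtype)  # NameError: name 'dtype' is not defined -- the typo fails loudly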
src/f5_tts/infer/utils_infer.py
CHANGED
@@ -40,7 +40,6 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
 target_rms = 0.1
 cross_fade_duration = 0.15
 ode_method = "euler"
@@ -133,6 +132,10 @@ def initialize_asr_pipeline(device=device):
 
 
 def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
+    if dtype is None:
+        dtype = (
+            torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
+        )
     model = model.to(dtype)
 
     ckpt_type = ckpt_path.split(".")[-1]
@@ -169,7 +172,14 @@ def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
 
 
 def load_model(
-    model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device
+    model_cls,
+    model_cfg,
+    ckpt_path,
+    mel_spec_type="vocos",
+    vocab_file="",
+    ode_method=ode_method,
+    use_ema=True,
+    device=device,
 ):
     if vocab_file == "":
         vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
@@ -199,10 +209,10 @@ def load_model(
     supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
     if supports_fp16 and mel_spec_type == "vocos":
         dtype = torch.float16
-    else:
+    elif mel_spec_type == "bigvgan":
         dtype = torch.float32
 
-    model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
+    model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
     return model
 
@@ -297,7 +307,7 @@ def infer_process(
     gen_text,
     model_obj,
     vocoder,
-    mel_spec_type,
+    mel_spec_type="vocos",
     show_info=print,
     progress=tqdm,
     target_rms=target_rms,
@@ -323,7 +333,7 @@ def infer_process(
         gen_text_batches,
         model_obj,
         vocoder,
-        mel_spec_type,
+        mel_spec_type=mel_spec_type,
         progress=progress,
         target_rms=target_rms,
         cross_fade_duration=cross_fade_duration,
@@ -345,7 +355,7 @@ def infer_batch_process(
     gen_text_batches,
     model_obj,
     vocoder,
-    mel_spec_type,
+    mel_spec_type="vocos",
     progress=tqdm,
     target_rms=0.1,
     cross_fade_duration=0.15,
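The dtype is None guard added to load_checkpoint makes the precision argument effectively optional: None resolves to fp16 on CUDA hardware with compute capability 6 or newer and to fp32 otherwise, while an explicit dtype passes through untouched. A standalone sketch of just that fallback; the helper name resolve_dtype is ours, not the repo's:

import torch

def resolve_dtype(dtype, device: str) -> torch.dtype:
    # Mirrors the new check at the top of load_checkpoint: None means "choose for me".
    if dtype is None:
        dtype = (
            torch.float16
            if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
            else torch.float32
        )
    return dtype

device = "cuda" if torch.cuda.is_available() else "cpu"
print(resolve_dtype(None, device))            # auto-selected precision
print(resolve_dtype(torch.bfloat16, device))  # explicit dtype is passed through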