Sync from GitHub repo
Browse files
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- api.py +1 -0
- model/utils_infer.py +7 -23
api.py
CHANGED
|
@@ -105,6 +105,7 @@ class F5TTS:
|
|
| 105 |
sway_sampling_coef=sway_sampling_coef,
|
| 106 |
speed=speed,
|
| 107 |
fix_duration=fix_duration,
|
|
|
|
| 108 |
)
|
| 109 |
|
| 110 |
if file_wave is not None:
|
|
|
|
| 105 |
sway_sampling_coef=sway_sampling_coef,
|
| 106 |
speed=speed,
|
| 107 |
fix_duration=fix_duration,
|
| 108 |
+
device=self.device,
|
| 109 |
)
|
| 110 |
|
| 111 |
if file_wave is not None:
|
model/utils_infer.py
CHANGED
|
@@ -19,13 +19,8 @@ from model.utils import (
|
|
| 19 |
convert_char_to_pinyin,
|
| 20 |
)
|
| 21 |
|
| 22 |
-
# get device
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def get_device():
|
| 26 |
-
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
| 27 |
-
return device
|
| 28 |
|
|
|
|
| 29 |
|
| 30 |
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
| 31 |
|
|
@@ -81,9 +76,7 @@ def chunk_text(text, max_chars=135):
|
|
| 81 |
|
| 82 |
|
| 83 |
# load vocoder
|
| 84 |
-
def load_vocoder(is_local=False, local_path="", device=
|
| 85 |
-
if device is None:
|
| 86 |
-
device = get_device()
|
| 87 |
if is_local:
|
| 88 |
print(f"Load vocos from local path {local_path}")
|
| 89 |
vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
|
|
@@ -101,11 +94,8 @@ def load_vocoder(is_local=False, local_path="", device=None):
|
|
| 101 |
asr_pipe = None
|
| 102 |
|
| 103 |
|
| 104 |
-
def initialize_asr_pipeline(device=
|
| 105 |
global asr_pipe
|
| 106 |
-
if device is None:
|
| 107 |
-
device = get_device()
|
| 108 |
-
|
| 109 |
asr_pipe = pipeline(
|
| 110 |
"automatic-speech-recognition",
|
| 111 |
model="openai/whisper-large-v3-turbo",
|
|
@@ -117,9 +107,7 @@ def initialize_asr_pipeline(device=None):
|
|
| 117 |
# load model for inference
|
| 118 |
|
| 119 |
|
| 120 |
-
def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=
|
| 121 |
-
if device is None:
|
| 122 |
-
device = get_device()
|
| 123 |
if vocab_file == "":
|
| 124 |
vocab_file = "Emilia_ZH_EN"
|
| 125 |
tokenizer = "pinyin"
|
|
@@ -152,10 +140,7 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_me
|
|
| 152 |
# preprocess reference audio and text
|
| 153 |
|
| 154 |
|
| 155 |
-
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=
|
| 156 |
-
if device is None:
|
| 157 |
-
device = get_device()
|
| 158 |
-
|
| 159 |
show_info("Converting audio...")
|
| 160 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
| 161 |
aseg = AudioSegment.from_file(ref_audio_orig)
|
|
@@ -216,6 +201,7 @@ def infer_process(
|
|
| 216 |
sway_sampling_coef=sway_sampling_coef,
|
| 217 |
speed=speed,
|
| 218 |
fix_duration=fix_duration,
|
|
|
|
| 219 |
):
|
| 220 |
# Split the input text into batches
|
| 221 |
audio, sr = torchaudio.load(ref_audio)
|
|
@@ -238,6 +224,7 @@ def infer_process(
|
|
| 238 |
sway_sampling_coef=sway_sampling_coef,
|
| 239 |
speed=speed,
|
| 240 |
fix_duration=fix_duration,
|
|
|
|
| 241 |
)
|
| 242 |
|
| 243 |
|
|
@@ -259,9 +246,6 @@ def infer_batch_process(
|
|
| 259 |
fix_duration=None,
|
| 260 |
device=None,
|
| 261 |
):
|
| 262 |
-
if device is None:
|
| 263 |
-
device = get_device()
|
| 264 |
-
|
| 265 |
audio, sr = ref_audio
|
| 266 |
if audio.shape[0] > 1:
|
| 267 |
audio = torch.mean(audio, dim=0, keepdim=True)
|
|
|
|
| 19 |
convert_char_to_pinyin,
|
| 20 |
)
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
| 24 |
|
| 25 |
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
| 26 |
|
|
|
|
| 76 |
|
| 77 |
|
| 78 |
# load vocoder
|
| 79 |
+
def load_vocoder(is_local=False, local_path="", device=device):
|
|
|
|
|
|
|
| 80 |
if is_local:
|
| 81 |
print(f"Load vocos from local path {local_path}")
|
| 82 |
vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
|
|
|
|
| 94 |
asr_pipe = None
|
| 95 |
|
| 96 |
|
| 97 |
+
def initialize_asr_pipeline(device=device):
|
| 98 |
global asr_pipe
|
|
|
|
|
|
|
|
|
|
| 99 |
asr_pipe = pipeline(
|
| 100 |
"automatic-speech-recognition",
|
| 101 |
model="openai/whisper-large-v3-turbo",
|
|
|
|
| 107 |
# load model for inference
|
| 108 |
|
| 109 |
|
| 110 |
+
def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
|
|
|
|
|
|
|
| 111 |
if vocab_file == "":
|
| 112 |
vocab_file = "Emilia_ZH_EN"
|
| 113 |
tokenizer = "pinyin"
|
|
|
|
| 140 |
# preprocess reference audio and text
|
| 141 |
|
| 142 |
|
| 143 |
+
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
|
|
|
|
|
|
|
|
|
|
| 144 |
show_info("Converting audio...")
|
| 145 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
| 146 |
aseg = AudioSegment.from_file(ref_audio_orig)
|
|
|
|
| 201 |
sway_sampling_coef=sway_sampling_coef,
|
| 202 |
speed=speed,
|
| 203 |
fix_duration=fix_duration,
|
| 204 |
+
device=device,
|
| 205 |
):
|
| 206 |
# Split the input text into batches
|
| 207 |
audio, sr = torchaudio.load(ref_audio)
|
|
|
|
| 224 |
sway_sampling_coef=sway_sampling_coef,
|
| 225 |
speed=speed,
|
| 226 |
fix_duration=fix_duration,
|
| 227 |
+
device=device,
|
| 228 |
)
|
| 229 |
|
| 230 |
|
|
|
|
| 246 |
fix_duration=None,
|
| 247 |
device=None,
|
| 248 |
):
|
|
|
|
|
|
|
|
|
|
| 249 |
audio, sr = ref_audio
|
| 250 |
if audio.shape[0] > 1:
|
| 251 |
audio = torch.mean(audio, dim=0, keepdim=True)
|