Add cache feature: retrieve previously generated segments; default cache size is 100
Browse files- src/f5_tts/infer/infer_gradio.py +12 -12
src/f5_tts/infer/infer_gradio.py
CHANGED
@@ -6,6 +6,7 @@ import json
|
|
6 |
import re
|
7 |
import tempfile
|
8 |
from collections import OrderedDict
|
|
|
9 |
from importlib.resources import files
|
10 |
|
11 |
import click
|
@@ -122,6 +123,7 @@ def load_text_from_file(file):
|
|
122 |
return gr.update(value=text)
|
123 |
|
124 |
|
|
|
125 |
@gpu_decorator
|
126 |
def infer(
|
127 |
ref_audio_orig,
|
@@ -129,7 +131,7 @@ def infer(
|
|
129 |
gen_text,
|
130 |
model,
|
131 |
remove_silence,
|
132 |
-
seed
|
133 |
cross_fade_duration=0.15,
|
134 |
nfe_step=32,
|
135 |
speed=1,
|
@@ -140,9 +142,7 @@ def infer(
|
|
140 |
return gr.update(), gr.update(), ref_text
|
141 |
|
142 |
# Set inference seed
|
143 |
-
if seed
|
144 |
-
seed = np.random.randint(0, 2**31 - 1)
|
145 |
-
elif seed < 0 or seed > 2**31 - 1:
|
146 |
gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
|
147 |
seed = np.random.randint(0, 2**31 - 1)
|
148 |
torch.manual_seed(seed)
|
@@ -284,7 +284,7 @@ with gr.Blocks() as app_tts:
|
|
284 |
speed_slider,
|
285 |
):
|
286 |
if randomize_seed:
|
287 |
-
seed_input =
|
288 |
|
289 |
audio_out, spectrogram_path, ref_text_out, used_seed = infer(
|
290 |
ref_audio_input,
|
@@ -620,7 +620,7 @@ with gr.Blocks() as app_multistyle:
|
|
620 |
|
621 |
for segment in segments:
|
622 |
name = segment["name"]
|
623 |
-
|
624 |
speed = segment["speed"]
|
625 |
text = segment["text"]
|
626 |
|
@@ -637,10 +637,10 @@ with gr.Blocks() as app_multistyle:
|
|
637 |
return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
|
638 |
ref_text = speech_types[current_type_name].get("ref_text", "")
|
639 |
|
640 |
-
if
|
641 |
-
seed_input =
|
642 |
|
643 |
-
# Generate speech for this segment
|
644 |
audio_out, _, ref_text_out, used_seed = infer(
|
645 |
ref_audio,
|
646 |
ref_text,
|
@@ -650,8 +650,8 @@ with gr.Blocks() as app_multistyle:
|
|
650 |
seed=seed_input,
|
651 |
cross_fade_duration=0,
|
652 |
speed=speed,
|
653 |
-
show_info=print,
|
654 |
-
)
|
655 |
sr, audio_data = audio_out
|
656 |
|
657 |
generated_audio_segments.append(audio_data)
|
@@ -863,7 +863,7 @@ Have a conversation with an AI using your reference voice!
|
|
863 |
return None, ref_text, seed_input
|
864 |
|
865 |
if randomize_seed:
|
866 |
-
seed_input =
|
867 |
|
868 |
audio_result, _, ref_text_out, used_seed = infer(
|
869 |
ref_audio,
|
|
|
6 |
import re
|
7 |
import tempfile
|
8 |
from collections import OrderedDict
|
9 |
+
from functools import lru_cache
|
10 |
from importlib.resources import files
|
11 |
|
12 |
import click
|
|
|
123 |
return gr.update(value=text)
|
124 |
|
125 |
|
126 |
+
@lru_cache(maxsize=100)
|
127 |
@gpu_decorator
|
128 |
def infer(
|
129 |
ref_audio_orig,
|
|
|
131 |
gen_text,
|
132 |
model,
|
133 |
remove_silence,
|
134 |
+
seed,
|
135 |
cross_fade_duration=0.15,
|
136 |
nfe_step=32,
|
137 |
speed=1,
|
|
|
142 |
return gr.update(), gr.update(), ref_text
|
143 |
|
144 |
# Set inference seed
|
145 |
+
if seed < 0 or seed > 2**31 - 1:
|
|
|
|
|
146 |
gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
|
147 |
seed = np.random.randint(0, 2**31 - 1)
|
148 |
torch.manual_seed(seed)
|
|
|
284 |
speed_slider,
|
285 |
):
|
286 |
if randomize_seed:
|
287 |
+
seed_input = np.random.randint(0, 2**31 - 1)
|
288 |
|
289 |
audio_out, spectrogram_path, ref_text_out, used_seed = infer(
|
290 |
ref_audio_input,
|
|
|
620 |
|
621 |
for segment in segments:
|
622 |
name = segment["name"]
|
623 |
+
seed_input = segment["seed"]
|
624 |
speed = segment["speed"]
|
625 |
text = segment["text"]
|
626 |
|
|
|
637 |
return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
|
638 |
ref_text = speech_types[current_type_name].get("ref_text", "")
|
639 |
|
640 |
+
if seed_input == -1:
|
641 |
+
seed_input = np.random.randint(0, 2**31 - 1)
|
642 |
|
643 |
+
# Generate or retrieve speech for this segment
|
644 |
audio_out, _, ref_text_out, used_seed = infer(
|
645 |
ref_audio,
|
646 |
ref_text,
|
|
|
650 |
seed=seed_input,
|
651 |
cross_fade_duration=0,
|
652 |
speed=speed,
|
653 |
+
show_info=print, # log to console instead of the UI so the page does not scroll to the top while generating
|
654 |
+
)
|
655 |
sr, audio_data = audio_out
|
656 |
|
657 |
generated_audio_segments.append(audio_data)
|
|
|
863 |
return None, ref_text, seed_input
|
864 |
|
865 |
if randomize_seed:
|
866 |
+
seed_input = np.random.randint(0, 2**31 - 1)
|
867 |
|
868 |
audio_result, _, ref_text_out, used_seed = infer(
|
869 |
ref_audio,
|