SWivid committed on
Commit
6bb2043
·
1 Parent(s): c34cc32

Add cache feature. Retrieve previously generated segments, default cache size 100

Browse files
Files changed (1) hide show
  1. src/f5_tts/infer/infer_gradio.py +12 -12
src/f5_tts/infer/infer_gradio.py CHANGED
@@ -6,6 +6,7 @@ import json
6
  import re
7
  import tempfile
8
  from collections import OrderedDict
 
9
  from importlib.resources import files
10
 
11
  import click
@@ -122,6 +123,7 @@ def load_text_from_file(file):
122
  return gr.update(value=text)
123
 
124
 
 
125
  @gpu_decorator
126
  def infer(
127
  ref_audio_orig,
@@ -129,7 +131,7 @@ def infer(
129
  gen_text,
130
  model,
131
  remove_silence,
132
- seed=None,
133
  cross_fade_duration=0.15,
134
  nfe_step=32,
135
  speed=1,
@@ -140,9 +142,7 @@ def infer(
140
  return gr.update(), gr.update(), ref_text
141
 
142
  # Set inference seed
143
- if seed is None:
144
- seed = np.random.randint(0, 2**31 - 1)
145
- elif seed < 0 or seed > 2**31 - 1:
146
  gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
147
  seed = np.random.randint(0, 2**31 - 1)
148
  torch.manual_seed(seed)
@@ -284,7 +284,7 @@ with gr.Blocks() as app_tts:
284
  speed_slider,
285
  ):
286
  if randomize_seed:
287
- seed_input = None
288
 
289
  audio_out, spectrogram_path, ref_text_out, used_seed = infer(
290
  ref_audio_input,
@@ -620,7 +620,7 @@ with gr.Blocks() as app_multistyle:
620
 
621
  for segment in segments:
622
  name = segment["name"]
623
- seed = segment["seed"]
624
  speed = segment["speed"]
625
  text = segment["text"]
626
 
@@ -637,10 +637,10 @@ with gr.Blocks() as app_multistyle:
637
  return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
638
  ref_text = speech_types[current_type_name].get("ref_text", "")
639
 
640
- if seed == -1:
641
- seed_input = None
642
 
643
- # Generate speech for this segment
644
  audio_out, _, ref_text_out, used_seed = infer(
645
  ref_audio,
646
  ref_text,
@@ -650,8 +650,8 @@ with gr.Blocks() as app_multistyle:
650
  seed=seed_input,
651
  cross_fade_duration=0,
652
  speed=speed,
653
- show_info=print,
654
- ) # show_info=print no pull to top when generating
655
  sr, audio_data = audio_out
656
 
657
  generated_audio_segments.append(audio_data)
@@ -863,7 +863,7 @@ Have a conversation with an AI using your reference voice!
863
  return None, ref_text, seed_input
864
 
865
  if randomize_seed:
866
- seed_input = None
867
 
868
  audio_result, _, ref_text_out, used_seed = infer(
869
  ref_audio,
 
6
  import re
7
  import tempfile
8
  from collections import OrderedDict
9
+ from functools import lru_cache
10
  from importlib.resources import files
11
 
12
  import click
 
123
  return gr.update(value=text)
124
 
125
 
126
+ @lru_cache(maxsize=100)
127
  @gpu_decorator
128
  def infer(
129
  ref_audio_orig,
 
131
  gen_text,
132
  model,
133
  remove_silence,
134
+ seed,
135
  cross_fade_duration=0.15,
136
  nfe_step=32,
137
  speed=1,
 
142
  return gr.update(), gr.update(), ref_text
143
 
144
  # Set inference seed
145
+ if seed < 0 or seed > 2**31 - 1:
 
 
146
  gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
147
  seed = np.random.randint(0, 2**31 - 1)
148
  torch.manual_seed(seed)
 
284
  speed_slider,
285
  ):
286
  if randomize_seed:
287
+ seed_input = np.random.randint(0, 2**31 - 1)
288
 
289
  audio_out, spectrogram_path, ref_text_out, used_seed = infer(
290
  ref_audio_input,
 
620
 
621
  for segment in segments:
622
  name = segment["name"]
623
+ seed_input = segment["seed"]
624
  speed = segment["speed"]
625
  text = segment["text"]
626
 
 
637
  return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
638
  ref_text = speech_types[current_type_name].get("ref_text", "")
639
 
640
+ if seed_input == -1:
641
+ seed_input = np.random.randint(0, 2**31 - 1)
642
 
643
+ # Generate or retrieve speech for this segment
644
  audio_out, _, ref_text_out, used_seed = infer(
645
  ref_audio,
646
  ref_text,
 
650
  seed=seed_input,
651
  cross_fade_duration=0,
652
  speed=speed,
653
+ show_info=print, # no pull to top when generating
654
+ )
655
  sr, audio_data = audio_out
656
 
657
  generated_audio_segments.append(audio_data)
 
863
  return None, ref_text, seed_input
864
 
865
  if randomize_seed:
866
+ seed_input = np.random.randint(0, 2**31 - 1)
867
 
868
  audio_result, _, ref_text_out, used_seed = infer(
869
  ref_audio,