Staticaliza commited on
Commit
30f8a20
·
verified ·
1 Parent(s): 9a7a13d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -1,19 +1,19 @@
1
  # app.py
2
- import os, gradio as gr
3
  from huggingface_hub import snapshot_download
4
  from indextts.infer import IndexTTS
5
 
6
  model_dir = snapshot_download("IndexTeam/IndexTTS-1.5", local_dir="checkpoints", local_dir_use_symlinks=False)
7
- cfg_path = os.path.join(model_dir, "config.yaml")
8
 
9
  tts = None
10
- def _ensure_loaded():
11
  global tts
12
- if tts is None:
13
- tts = IndexTTS(model_dir=model_dir, cfg_path=cfg_path)
14
 
 
15
  def synth(ref_wav, prompt):
16
- _ensure_loaded()
17
  out = "out.wav"
18
  tts.infer(ref_wav, prompt, out)
19
  return out
@@ -22,8 +22,8 @@ with gr.Blocks() as demo:
22
  gr.Markdown("# index-tts 1.5 zerogpu")
23
  txt = gr.Textbox(label="text prompt")
24
  ref = gr.Audio(label="reference voice", type="filepath")
25
- out = gr.Audio(label="generated speech", type="filepath")
26
- gr.Button("generate").click(synth, [ref, txt], out)
27
 
28
  demo.queue()
29
- demo.launch(show_api=False, ssr_mode=False, max_file_size=50*1024*1024)
 
1
  # app.py
2
+ import os, gradio as gr, spaces
3
  from huggingface_hub import snapshot_download
4
  from indextts.infer import IndexTTS
5
 
6
  model_dir = snapshot_download("IndexTeam/IndexTTS-1.5", local_dir="checkpoints", local_dir_use_symlinks=False)
7
+ cfg_path = os.path.join(model_dir, "config.yaml")
8
 
9
  tts = None
10
+ def load():
11
  global tts
12
+ if tts is None: tts = IndexTTS(model_dir=model_dir, cfg_path=cfg_path)
 
13
 
14
+ @spaces.GPU
15
  def synth(ref_wav, prompt):
16
+ load()
17
  out = "out.wav"
18
  tts.infer(ref_wav, prompt, out)
19
  return out
 
22
  gr.Markdown("# index-tts 1.5 zerogpu")
23
  txt = gr.Textbox(label="text prompt")
24
  ref = gr.Audio(label="reference voice", type="filepath")
25
+ gen = gr.Audio(label="generated speech", type="filepath")
26
+ gr.Button("generate").click(synth, [ref, txt], gen)
27
 
28
  demo.queue()
29
+ demo.launch(show_api=False, ssr_mode=False)