ollieollie committed on
Commit
22418d0
·
verified ·
1 Parent(s): 5aefb03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -17
app.py CHANGED
@@ -1,30 +1,65 @@
 
 
 
1
  from chatterbox.src.chatterbox.tts import ChatterboxTTS
2
  import gradio as gr
3
 
 
4
 
5
- model = ChatterboxTTS.from_pretrained("cuda")  # old revision: the device string is hard-coded to "cuda"
6
 
7
- def generate(text, audio_prompt_path, exaggeration, pace, temperature):  # old revision: no seed parameter
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  wav = model.generate(
9
- text, audio_prompt_path=audio_prompt_path,
 
10
  exaggeration=exaggeration,
11
  pace=pace,
12
  temperature=temperature,
13
  )
14
- return 24000, wav.squeeze(0).numpy()  # old revision: first tuple element is the literal 24000 rather than model.sr
15
-
16
-
17
- demo = gr.Interface(  # old revision: one gr.Interface call wiring generate() to five inputs and an "audio" output
18
- generate,
19
- [
20
- gr.Textbox(value="", label="Text to synthesize"),
21
- gr.Audio(sources="upload", type="filepath", label="Reference Audio File", value=None),
22
- gr.Slider(-5, 5, step=.05, label="exaggeration", value=.5),
23
- gr.Slider(0.8, 1.2, step=.01, label="pace", value=1),
24
- gr.Slider(0.05, 5, step=.05, label="temperature", value=.8),
25
- ],
26
- "audio",
27
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  if __name__ == "__main__":
30
  demo.launch()  # launch the Gradio app only when this file is executed as a script
 
1
+ import random
2
+ import numpy as np
3
+ import torch
4
  from chatterbox.src.chatterbox.tts import ChatterboxTTS
5
  import gradio as gr
6
 
7
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # use CUDA when torch reports it as available, otherwise fall back to the CPU
8
 
 
9
 
10
def set_seed(seed: int):
    """Seed every random number generator this app touches.

    Seeds PyTorch (the default generator plus the CUDA generators), Python's
    ``random`` module and NumPy's global generator, so that two runs started
    from the same seed draw the same random numbers.

    Args:
        seed: Any integer. NumPy only accepts seeds in ``[0, 2**32 - 1]``, so
            the value is reduced modulo ``2**32`` before it is used; seeds
            already inside that range are used unchanged.
    """
    # np.random.seed raises ValueError for negative or >32-bit seeds. Folding
    # the value into range first keeps a UI input such as -1 from crashing
    # half-way through (after torch was already seeded).
    seed = int(seed) % (2**32)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
16
+
17
+
18
+ model = ChatterboxTTS.from_pretrained(DEVICE)  # module-level load on the selected device; generate() reuses this single instance
19
+
20
def generate(text, audio_prompt_path, exaggeration, pace, temperature, seed_num):
    """Synthesize speech for *text* and return it as a Gradio audio tuple.

    Args:
        text: Text to synthesize.
        audio_prompt_path: File path of the reference audio, forwarded to
            ``model.generate`` (presumably ``None`` when nothing was
            uploaded; confirm against the model API).
        exaggeration: Forwarded to ``model.generate`` unchanged.
        pace: Forwarded to ``model.generate`` unchanged.
        temperature: Forwarded to ``model.generate`` unchanged.
        seed_num: Seed requested in the UI. ``0`` or an empty field leaves
            the random state untouched; any other value is truncated to an
            integer and passed to ``set_seed``.

    Returns:
        A ``(model.sr, samples)`` tuple, where ``samples`` is the generated
        waveform with its leading dimension squeezed, as a NumPy array.
    """
    # The number field can deliver a float, and None when it is cleared.
    # Normalize before comparing so None does not crash int() and a value
    # like 0.5 is not silently turned into an explicit seed of 0.
    seed = 0 if seed_num is None else int(seed_num)
    if seed != 0:
        set_seed(seed)

    wav = model.generate(
        text,
        audio_prompt_path=audio_prompt_path,
        exaggeration=exaggeration,
        pace=pace,
        temperature=temperature,
    )
    # detach()/cpu() are no-ops for a plain CPU tensor, but keep .numpy()
    # from raising if the model returns a CUDA or grad-tracking tensor.
    return model.sr, wav.squeeze(0).detach().cpu().numpy()
32
+
33
+
34
# Two-column layout: inputs and the Generate button on the left, the
# synthesized audio on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(value="What does the fox say?", label="Text to synthesize")
            ref_wav = gr.Audio(sources="upload", type="filepath", label="Reference Audio File", value=None)
            exaggeration = gr.Slider(-5, 5, step=.05, label="exaggeration", value=.5)

            # Less frequently used knobs stay collapsed by default.
            with gr.Accordion("More options", open=False):
                # precision=0 makes the field deliver whole numbers, so a
                # fractional seed cannot be entered.
                seed_num = gr.Number(value=0, label="Random seed (0 for random)", precision=0)
                temp = gr.Slider(0.05, 5, step=.05, label="temperature", value=.8)
                pace = gr.Slider(0.8, 1.2, step=.01, label="pace", value=1)

            run_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    # The input order must match generate()'s positional parameters:
    # (text, audio_prompt_path, exaggeration, pace, temperature, seed_num).
    run_btn.click(
        fn=generate,
        inputs=[
            text,
            ref_wav,
            exaggeration,
            pace,
            temp,
            seed_num,
        ],
        outputs=audio_output,
    )
63
 
64
  if __name__ == "__main__":
65
  demo.launch()  # launch the Gradio app only when this file is executed as a script