Athspi committed · verified
Commit 5eaee65 · Parent: ec99653

Update app.py

Files changed (1): app.py (+75 −74)
app.py CHANGED
@@ -1,22 +1,36 @@
-import spaces
-from snac import SNAC
+import os
 import torch
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from huggingface_hub import snapshot_download
 from dotenv import load_dotenv
+
+# Load environment variables
 load_dotenv()
 
-# Check if CUDA is available
+# Device and torch dtype selection
 device = "cuda" if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
+
+# Define a no-op decorator for CPU if needed
+def gpu_decorator(func):
+    return func
+
+# If you are on GPU and have the spaces module, you could replace gpu_decorator with spaces.GPU
+# For CPU usage we simply use a no-op
+# Example: import spaces; gpu_decorator = spaces.GPU()
+
+# Import SNAC after setting device
+from snac import SNAC
 
 print("Loading SNAC model...")
 snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
 snac_model = snac_model.to(device)
+snac_model.eval()  # set SNAC to eval mode
 
 model_name = "canopylabs/orpheus-3b-0.1-ft"
 
-# Download only model config and safetensors
+# Download only model config and safetensors files
 snapshot_download(
     repo_id=model_name,
     allow_patterns=[
@@ -38,36 +52,34 @@ snapshot_download(
     ]
 )
 
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+print("Loading Orpheus model...")
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype)
 model.to(device)
+model.eval()  # set Orpheus to eval mode
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 print(f"Orpheus model loaded to {device}")
 
-# Process text prompt
+# Process text prompt into tokens with start/end markers
 def process_prompt(prompt, voice, tokenizer, device):
     prompt = f"{voice}: {prompt}"
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-
-    start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
-    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
-
-    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)  # SOH SOT Text EOT EOH
-
-    # No padding needed for single input
+
+    start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start token
+    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End tokens
+
+    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
     attention_mask = torch.ones_like(modified_input_ids)
-
     return modified_input_ids.to(device), attention_mask.to(device)
 
-# Parse output tokens to audio
+# Parse output tokens to extract audio codes
 def parse_output(generated_ids):
     token_to_find = 128257
     token_to_remove = 128258
-
-    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
 
+    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
     if len(token_indices[1]) > 0:
         last_occurrence_idx = token_indices[1][-1].item()
-        cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
+        cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
     else:
         cropped_tensor = generated_ids
 
@@ -83,47 +95,43 @@ def parse_output(generated_ids):
         trimmed_row = row[:new_length]
         trimmed_row = [t - 128266 for t in trimmed_row]
         code_lists.append(trimmed_row)
-
-    return code_lists[0]  # Return just the first one for single sample
 
-# Redistribute codes for audio generation
+    return code_lists[0]  # Return first sample
+
+# Redistribute codes for audio generation using SNAC
 def redistribute_codes(code_list, snac_model):
-    device = next(snac_model.parameters()).device  # Get the device of SNAC model
-
-    layer_1 = []
-    layer_2 = []
-    layer_3 = []
-    for i in range((len(code_list)+1)//7):
-        layer_1.append(code_list[7*i])
-        layer_2.append(code_list[7*i+1]-4096)
-        layer_3.append(code_list[7*i+2]-(2*4096))
-        layer_3.append(code_list[7*i+3]-(3*4096))
-        layer_2.append(code_list[7*i+4]-(4*4096))
-        layer_3.append(code_list[7*i+5]-(5*4096))
-        layer_3.append(code_list[7*i+6]-(6*4096))
-
-    # Move tensors to the same device as the SNAC model
+    snac_device = next(snac_model.parameters()).device
+    layer_1, layer_2, layer_3 = [], [], []
+    for i in range((len(code_list) + 1) // 7):
+        layer_1.append(code_list[7 * i])
+        layer_2.append(code_list[7 * i + 1] - 4096)
+        layer_3.append(code_list[7 * i + 2] - (2 * 4096))
+        layer_3.append(code_list[7 * i + 3] - (3 * 4096))
+        layer_2.append(code_list[7 * i + 4] - (4 * 4096))
+        layer_3.append(code_list[7 * i + 5] - (5 * 4096))
+        layer_3.append(code_list[7 * i + 6] - (6 * 4096))
+
     codes = [
-        torch.tensor(layer_1, device=device).unsqueeze(0),
-        torch.tensor(layer_2, device=device).unsqueeze(0),
-        torch.tensor(layer_3, device=device).unsqueeze(0)
+        torch.tensor(layer_1, device=snac_device).unsqueeze(0),
+        torch.tensor(layer_2, device=snac_device).unsqueeze(0),
+        torch.tensor(layer_3, device=snac_device).unsqueeze(0)
     ]
-
+
    audio_hat = snac_model.decode(codes)
-    return audio_hat.detach().squeeze().cpu().numpy()  # Always return CPU numpy array
+    return audio_hat.detach().squeeze().cpu().numpy()
 
-# Main generation function
-@spaces.GPU()
+# Main generation function with CPU optimizations
+@gpu_decorator
 def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
     if not text.strip():
         return None
-
+
     try:
         progress(0.1, "Processing text...")
         input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
-
+
         progress(0.3, "Generating speech tokens...")
-        with torch.no_grad():
+        with torch.inference_mode():
             generated_ids = model.generate(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
@@ -135,39 +143,38 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
                 num_return_sequences=1,
                 eos_token_id=128258,
             )
-
+
         progress(0.6, "Processing speech tokens...")
         code_list = parse_output(generated_ids)
-
-        progress(0.8, "Converting to audio...")
+
+        progress(0.8, "Converting tokens to audio...")
         audio_samples = redistribute_codes(code_list, snac_model)
-
-        return (24000, audio_samples)  # Return sample rate and audio
+
+        return (24000, audio_samples)  # Return sample rate and numpy array audio
     except Exception as e:
         print(f"Error generating speech: {e}")
         return None
 
-# Examples for the UI
+# Example inputs for the Gradio UI
 examples = [
     ["Hey there my name is Tara, <chuckle> and I'm a speech generation model that can sound like a person.", "tara", 0.6, 0.95, 1.1, 1200],
     ["I've also been taught to understand and produce paralinguistic things like sighing, or chuckling, or yawning!", "dan", 0.7, 0.95, 1.1, 1200],
-    ["I live in San Francisco, and have, uhm let's see, 3 billion 7 hundred ... well, lets just say a lot of parameters.", "emma", 0.6, 0.9, 1.2, 1200]
+    ["I live in San Francisco, and have, uhm let's see, 3 billion 7 hundred ... well, let's just say a lot of parameters.", "emma", 0.6, 0.9, 1.2, 1200]
 ]
 
-# Available voices
 VOICES = ["tara", "dan", "josh", "emma"]
 
 # Create Gradio interface
 with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
     gr.Markdown("""
-    # 🎵 [Orpheus Text-to-Speech](https://github.com/canopyai/Orpheus-TTS)
-    Enter your text below and hear it converted to natural-sounding speech with the Orpheus TTS model.
+    # 🎵 Orpheus Text-to-Speech
+    Enter your text below and hear it converted to natural-sounding speech.
 
-    ## Tips for better prompts:
-    - Add paralinguistic elements like `<chuckle>`, `<sigh>`, or `uhm` for more human-like speech.
-    - Longer text prompts generally work better than very short phrases
-    - Adjust the temperature slider for more varied (higher) or consistent (lower) speech patterns
-    """)
+    **Tips for better prompts:**
+    - Include paralinguistic elements like `<chuckle>`, `<sigh>`, or `uhm` for more human-like speech.
+    - Longer prompts often produce more natural results.
+    - Adjust the temperature slider to control variation in speech patterns.
+    """)
     with gr.Row():
         with gr.Column(scale=3):
             text_input = gr.Textbox(
@@ -180,37 +187,33 @@ with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
                 value="tara",
                 label="Voice"
             )
-
             with gr.Accordion("Advanced Settings", open=False):
                 temperature = gr.Slider(
                     minimum=0.1, maximum=1.5, value=0.6, step=0.05,
-                    label="Temperature",
+                    label="Temperature",
                     info="Higher values (0.7-1.0) create more expressive but less stable speech"
                 )
                 top_p = gr.Slider(
                     minimum=0.1, maximum=1.0, value=0.95, step=0.05,
-                    label="Top P",
+                    label="Top P",
                     info="Nucleus sampling threshold"
                 )
                 repetition_penalty = gr.Slider(
                     minimum=1.0, maximum=2.0, value=1.1, step=0.05,
-                    label="Repetition Penalty",
+                    label="Repetition Penalty",
                     info="Higher values discourage repetitive patterns"
                 )
                 max_new_tokens = gr.Slider(
                     minimum=100, maximum=2000, value=1200, step=100,
-                    label="Max Length",
+                    label="Max Length",
                     info="Maximum length of generated audio (in tokens)"
                 )
-
             with gr.Row():
                 submit_btn = gr.Button("Generate Speech", variant="primary")
                 clear_btn = gr.Button("Clear")
-
         with gr.Column(scale=2):
             audio_output = gr.Audio(label="Generated Speech", type="numpy")
-
-    # Set up examples
+
     gr.Examples(
         examples=examples,
         inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
@@ -219,19 +222,17 @@ with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
        cache_examples=True,
    )
 
-    # Set up event handlers
    submit_btn.click(
        fn=generate_speech,
        inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
        outputs=audio_output
    )
-
    clear_btn.click(
        fn=lambda: (None, None),
        inputs=[],
        outputs=[text_input, audio_output]
    )
 
-# Launch the app
+# Launch the Gradio app
 if __name__ == "__main__":
-    demo.queue().launch(share=False, ssr_mode=False)
+    demo.queue().launch(share=False, ssr_mode=False)
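Note: the new no-op gpu_decorator keeps the app runnable on CPU-only machines. On a Hugging Face Space the real decorator comes from the standalone spaces package, as in the removed import spaces / @spaces.GPU() lines. A minimal sketch of a fallback that covers both environments (assuming the spaces package is installed on the Space):

try:
    import spaces  # Hugging Face Spaces helper that provides the GPU decorator
    gpu_decorator = spaces.GPU()  # mirrors the removed @spaces.GPU() usage
except ImportError:
    def gpu_decorator(func):
        # No-op fallback for CPU-only machines
        return func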
 
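For quick reference, the special token IDs that process_prompt and parse_output rely on, collected from the code above (slot names follow the removed "Start of human" / "End of text, End of human" comments):

# Prompt framing in process_prompt:
#   128259            start-of-human marker, prepended
#   <text tokens>     tokenized "{voice}: {prompt}"
#   128009, 128260    end-of-text / end-of-human markers, appended
#
# Output handling in parse_output and generate():
#   128257            last occurrence marks where the audio codes begin
#   128258            end of speech; passed as eos_token_id and masked out (token_to_remove)
#   t - 128266        shifts each remaining token into SNAC code space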
 
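redistribute_codes de-interleaves each frame of 7 generated codes into SNAC's three codebook layers, subtracting the k * 4096 offset that tags slot k. A small worked example with made-up code values:

# One 7-code frame where slot k carries code value (10 + k), offset by k * 4096:
frame = [10 + k + 4096 * k for k in range(7)]
layer_1 = [frame[0]]                                  # 1 coarse code per frame
layer_2 = [frame[1] - 4096, frame[4] - 4 * 4096]      # 2 mid codes per frame
layer_3 = [frame[2] - 2 * 4096, frame[3] - 3 * 4096,
           frame[5] - 5 * 4096, frame[6] - 6 * 4096]  # 4 fine codes per frame
assert layer_1 == [10] and layer_2 == [11, 14] and layer_3 == [12, 13, 15, 16]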