bluenevus commited on
Commit
3f87519
·
verified ·
1 Parent(s): 62246c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -21
app.py CHANGED
@@ -136,35 +136,41 @@ def parse_output(generated_ids):
136
  return code_lists[0]
137
 
138
  def redistribute_codes(code_list, snac_model):
139
- device = next(snac_model.parameters()).device
140
-
141
- layer_1, layer_2, layer_3 = [], [], []
142
- for i in range((len(code_list)+1)//7):
143
- layer_1.append(code_list[7*i])
144
- layer_2.append(code_list[7*i+1]-4096)
145
- layer_3.append(code_list[7*i+2]-(2*4096))
146
- layer_3.append(code_list[7*i+3]-(3*4096))
147
- layer_2.append(code_list[7*i+4]-(4*4096))
148
- layer_3.append(code_list[7*i+5]-(5*4096))
149
- layer_3.append(code_list[7*i+6]-(6*4096))
150
-
151
- codes = [
152
- torch.tensor(layer_1, device=device).unsqueeze(0),
153
- torch.tensor(layer_2, device=device).unsqueeze(0),
154
- torch.tensor(layer_3, device=device).unsqueeze(0)
155
- ]
156
-
157
- audio_hat = snac_model.decode(codes)
158
- return audio_hat.detach().squeeze().cpu().numpy()
 
 
 
 
159
 
160
  @spaces.GPU()
161
  def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
162
  if not text.strip():
 
163
  return None
164
 
165
  try:
166
  progress(0.1, "Processing text...")
167
  input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
 
168
 
169
  progress(0.3, "Generating speech tokens...")
170
  with torch.no_grad():
@@ -179,16 +185,31 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
179
  num_return_sequences=1,
180
  eos_token_id=128258,
181
  )
 
182
 
183
  progress(0.6, "Processing speech tokens...")
184
  code_list = parse_output(generated_ids)
 
 
 
 
 
185
 
186
  progress(0.8, "Converting to audio...")
187
  audio_samples = redistribute_codes(code_list, snac_model)
188
 
 
 
 
 
 
 
 
 
 
189
  return (24000, audio_samples) # Return sample rate and audio
190
  except Exception as e:
191
- print(f"Error generating speech: {e}")
192
  return None
193
 
194
  @spaces.GPU()
 
136
  return code_lists[0]
137
 
138
  def redistribute_codes(code_list, snac_model):
139
+ try:
140
+ device = next(snac_model.parameters()).device
141
+
142
+ layer_1, layer_2, layer_3 = [], [], []
143
+ for i in range((len(code_list)+1)//7):
144
+ layer_1.append(code_list[7*i])
145
+ layer_2.append(code_list[7*i+1]-4096)
146
+ layer_3.append(code_list[7*i+2]-(2*4096))
147
+ layer_3.append(code_list[7*i+3]-(3*4096))
148
+ layer_2.append(code_list[7*i+4]-(4*4096))
149
+ layer_3.append(code_list[7*i+5]-(5*4096))
150
+ layer_3.append(code_list[7*i+6]-(6*4096))
151
+
152
+ codes = [
153
+ torch.tensor(layer_1, device=device).unsqueeze(0),
154
+ torch.tensor(layer_2, device=device).unsqueeze(0),
155
+ torch.tensor(layer_3, device=device).unsqueeze(0)
156
+ ]
157
+
158
+ audio_hat = snac_model.decode(codes)
159
+ return audio_hat.detach().squeeze().cpu().numpy()
160
+ except Exception as e:
161
+ logger.error(f"Error in redistribute_codes: {e}", exc_info=True)
162
+ return None
163
 
164
  @spaces.GPU()
165
  def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
166
  if not text.strip():
167
+ logger.warning("Empty text input. Skipping speech generation.")
168
  return None
169
 
170
  try:
171
  progress(0.1, "Processing text...")
172
  input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
173
+ logger.info(f"Input shape: {input_ids.shape}")
174
 
175
  progress(0.3, "Generating speech tokens...")
176
  with torch.no_grad():
 
185
  num_return_sequences=1,
186
  eos_token_id=128258,
187
  )
188
+ logger.info(f"Generated shape: {generated_ids.shape}")
189
 
190
  progress(0.6, "Processing speech tokens...")
191
  code_list = parse_output(generated_ids)
192
+ logger.info(f"Code list length: {len(code_list)}")
193
+
194
+ if not code_list:
195
+ logger.warning("No valid code list generated. Skipping audio conversion.")
196
+ return None
197
 
198
  progress(0.8, "Converting to audio...")
199
  audio_samples = redistribute_codes(code_list, snac_model)
200
 
201
+ if audio_samples is None:
202
+ logger.warning("Audio samples is None.")
203
+ return None
204
+
205
+ if len(audio_samples) == 0:
206
+ logger.warning("Audio samples is empty.")
207
+ return None
208
+
209
+ logger.info(f"Audio samples shape: {audio_samples.shape}")
210
  return (24000, audio_samples) # Return sample rate and audio
211
  except Exception as e:
212
+ logger.error(f"Error generating speech: {e}", exc_info=True)
213
  return None
214
 
215
  @spaces.GPU()