Update app.py
Browse files
app.py
CHANGED
@@ -166,17 +166,13 @@ def redistribute_codes(code_list, snac_model):
|
|
166 |
return None
|
167 |
|
168 |
@spaces.GPU()
|
169 |
-
def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens
|
170 |
if not text.strip():
|
171 |
-
logger.warning("Empty text input. Skipping speech generation.")
|
172 |
return None
|
173 |
|
174 |
try:
|
175 |
-
progress(0.1, "Processing text...")
|
176 |
input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
|
177 |
-
logger.info(f"Input shape: {input_ids.shape}")
|
178 |
|
179 |
-
progress(0.3, "Generating speech tokens...")
|
180 |
with torch.no_grad():
|
181 |
generated_ids = model.generate(
|
182 |
input_ids=input_ids,
|
@@ -189,28 +185,10 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
|
|
189 |
num_return_sequences=1,
|
190 |
eos_token_id=128258,
|
191 |
)
|
192 |
-
logger.info(f"Generated shape: {generated_ids.shape}")
|
193 |
|
194 |
-
progress(0.6, "Processing speech tokens...")
|
195 |
code_list = parse_output(generated_ids)
|
196 |
-
logger.info(f"Code list length: {len(code_list)}")
|
197 |
-
|
198 |
-
if not code_list:
|
199 |
-
logger.warning("No valid code list generated. Skipping audio conversion.")
|
200 |
-
return None
|
201 |
-
|
202 |
-
progress(0.8, "Converting to audio...")
|
203 |
audio_samples = redistribute_codes(code_list, snac_model)
|
204 |
|
205 |
-
if audio_samples is None:
|
206 |
-
logger.warning("Audio samples is None.")
|
207 |
-
return None
|
208 |
-
|
209 |
-
if len(audio_samples) == 0:
|
210 |
-
logger.warning("Audio samples is empty.")
|
211 |
-
return None
|
212 |
-
|
213 |
-
logger.info(f"Audio samples shape: {audio_samples.shape}")
|
214 |
return (24000, audio_samples) # Return sample rate and audio
|
215 |
except Exception as e:
|
216 |
logger.error(f"Error generating speech: {e}", exc_info=True)
|
@@ -224,13 +202,10 @@ def render_podcast(api_key, script, voice1, voice2, num_hosts):
|
|
224 |
|
225 |
for i, line in enumerate(lines):
|
226 |
voice = voice1 if num_hosts == 1 or i % 2 == 0 else voice2
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
audio_segments.append(audio)
|
232 |
-
except Exception as e:
|
233 |
-
logger.error(f"Error processing audio segment: {str(e)}")
|
234 |
|
235 |
if not audio_segments:
|
236 |
logger.warning("No valid audio segments were generated.")
|
|
|
166 |
return None
|
167 |
|
168 |
@spaces.GPU()
|
169 |
+
def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens):
|
170 |
if not text.strip():
|
|
|
171 |
return None
|
172 |
|
173 |
try:
|
|
|
174 |
input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
|
|
|
175 |
|
|
|
176 |
with torch.no_grad():
|
177 |
generated_ids = model.generate(
|
178 |
input_ids=input_ids,
|
|
|
185 |
num_return_sequences=1,
|
186 |
eos_token_id=128258,
|
187 |
)
|
|
|
188 |
|
|
|
189 |
code_list = parse_output(generated_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
audio_samples = redistribute_codes(code_list, snac_model)
|
191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
return (24000, audio_samples) # Return sample rate and audio
|
193 |
except Exception as e:
|
194 |
logger.error(f"Error generating speech: {e}", exc_info=True)
|
|
|
202 |
|
203 |
for i, line in enumerate(lines):
|
204 |
voice = voice1 if num_hosts == 1 or i % 2 == 0 else voice2
|
205 |
+
result = generate_speech(line, voice, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200)
|
206 |
+
if result is not None:
|
207 |
+
sample_rate, audio = result
|
208 |
+
audio_segments.append(audio)
|
|
|
|
|
|
|
209 |
|
210 |
if not audio_segments:
|
211 |
logger.warning("No valid audio segments were generated.")
|