Update app.py
Browse files
app.py
CHANGED
@@ -136,35 +136,41 @@ def parse_output(generated_ids):
|
|
136 |
return code_lists[0]
|
137 |
|
138 |
def redistribute_codes(code_list, snac_model):
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
|
|
|
|
|
|
|
|
159 |
|
160 |
@spaces.GPU()
|
161 |
def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
|
162 |
if not text.strip():
|
|
|
163 |
return None
|
164 |
|
165 |
try:
|
166 |
progress(0.1, "Processing text...")
|
167 |
input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
|
|
|
168 |
|
169 |
progress(0.3, "Generating speech tokens...")
|
170 |
with torch.no_grad():
|
@@ -179,16 +185,31 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
|
|
179 |
num_return_sequences=1,
|
180 |
eos_token_id=128258,
|
181 |
)
|
|
|
182 |
|
183 |
progress(0.6, "Processing speech tokens...")
|
184 |
code_list = parse_output(generated_ids)
|
|
|
|
|
|
|
|
|
|
|
185 |
|
186 |
progress(0.8, "Converting to audio...")
|
187 |
audio_samples = redistribute_codes(code_list, snac_model)
|
188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
return (24000, audio_samples) # Return sample rate and audio
|
190 |
except Exception as e:
|
191 |
-
|
192 |
return None
|
193 |
|
194 |
@spaces.GPU()
|
|
|
136 |
return code_lists[0]
|
137 |
|
138 |
def redistribute_codes(code_list, snac_model):
    """Convert a flat stream of speech tokens into audio via the SNAC vocoder.

    The token stream carries 7 codes per audio frame, interleaved across
    SNAC's three codebook layers (1 + 2 + 4 codes per frame).  Position k
    within a frame is offset by k*4096 in the flat stream; those offsets
    are removed here before decoding.

    Args:
        code_list: Flat list of integer token ids, 7 per audio frame.
        snac_model: SNAC model exposing ``.parameters()`` and ``.decode()``.

    Returns:
        1-D numpy array of audio samples, or None if decoding fails.
    """
    try:
        # Build the code tensors on whatever device the model lives on.
        device = next(snac_model.parameters()).device

        layer_1, layer_2, layer_3 = [], [], []
        # Floor division drops a trailing partial frame instead of reading
        # past the end: the previous (len+1)//7 count raised IndexError
        # (and hence returned None) whenever len(code_list) % 7 == 6.
        for i in range(len(code_list) // 7):
            layer_1.append(code_list[7 * i])
            layer_2.append(code_list[7 * i + 1] - 4096)
            layer_3.append(code_list[7 * i + 2] - (2 * 4096))
            layer_3.append(code_list[7 * i + 3] - (3 * 4096))
            layer_2.append(code_list[7 * i + 4] - (4 * 4096))
            layer_3.append(code_list[7 * i + 5] - (5 * 4096))
            layer_3.append(code_list[7 * i + 6] - (6 * 4096))

        codes = [
            torch.tensor(layer_1, device=device).unsqueeze(0),
            torch.tensor(layer_2, device=device).unsqueeze(0),
            torch.tensor(layer_3, device=device).unsqueeze(0),
        ]

        audio_hat = snac_model.decode(codes)
        # Detach from the graph, drop batch/channel dims, move to CPU numpy.
        return audio_hat.detach().squeeze().cpu().numpy()
    except Exception as e:
        logger.error(f"Error in redistribute_codes: {e}", exc_info=True)
        return None
|
163 |
|
164 |
@spaces.GPU()
|
165 |
def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
|
166 |
if not text.strip():
|
167 |
+
logger.warning("Empty text input. Skipping speech generation.")
|
168 |
return None
|
169 |
|
170 |
try:
|
171 |
progress(0.1, "Processing text...")
|
172 |
input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
|
173 |
+
logger.info(f"Input shape: {input_ids.shape}")
|
174 |
|
175 |
progress(0.3, "Generating speech tokens...")
|
176 |
with torch.no_grad():
|
|
|
185 |
num_return_sequences=1,
|
186 |
eos_token_id=128258,
|
187 |
)
|
188 |
+
logger.info(f"Generated shape: {generated_ids.shape}")
|
189 |
|
190 |
progress(0.6, "Processing speech tokens...")
|
191 |
code_list = parse_output(generated_ids)
|
192 |
+
logger.info(f"Code list length: {len(code_list)}")
|
193 |
+
|
194 |
+
if not code_list:
|
195 |
+
logger.warning("No valid code list generated. Skipping audio conversion.")
|
196 |
+
return None
|
197 |
|
198 |
progress(0.8, "Converting to audio...")
|
199 |
audio_samples = redistribute_codes(code_list, snac_model)
|
200 |
|
201 |
+
if audio_samples is None:
|
202 |
+
logger.warning("Audio samples is None.")
|
203 |
+
return None
|
204 |
+
|
205 |
+
if len(audio_samples) == 0:
|
206 |
+
logger.warning("Audio samples is empty.")
|
207 |
+
return None
|
208 |
+
|
209 |
+
logger.info(f"Audio samples shape: {audio_samples.shape}")
|
210 |
return (24000, audio_samples) # Return sample rate and audio
|
211 |
except Exception as e:
|
212 |
+
logger.error(f"Error generating speech: {e}", exc_info=True)
|
213 |
return None
|
214 |
|
215 |
@spaces.GPU()
|