ramimu committed
Commit ee39abc · verified · 1 Parent(s): 74dbc75

Update app.py

Files changed (1): app.py +71 -38
app.py CHANGED
@@ -26,6 +26,10 @@ except ImportError as e:
 model = None
 model_loaded = False
 
+# Text length limits for the model
+MAX_CHARS_PER_GENERATION = 1000  # Safe limit for single generation
+MAX_CHARS_TOTAL = 5000  # Maximum we'll accept via API
+
 def download_model_files():
     """Download model files with error handling."""
     print(f"Checking for model files in {LOCAL_MODEL_PATH}...")
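The two new limits act at different points: text over MAX_CHARS_TOTAL is rejected outright, while text between the two bounds is truncated to roughly MAX_CHARS_PER_GENERATION before generation (see the clone_voice changes below). A plain-Python summary of the implied behavior (illustration only, not part of the commit):

```python
# Behavior implied by the checks added later in this diff.
for n in (800, 3000, 6000):
    if n > 5000:        # MAX_CHARS_TOTAL
        print(n, "-> rejected with an error message")
    elif n > 1000:      # MAX_CHARS_PER_GENERATION
        print(n, "-> truncated near 1,000 chars, then generated")
    else:
        print(n, "-> generated as-is")
```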
@@ -65,7 +69,6 @@ def load_model_on_gpu():
     try:
         print("Loading model inside GPU context...")
 
-        # Now we can safely use CUDA operations
         device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f"Loading model on device: {device}")
 
@@ -80,10 +83,8 @@ def load_model_on_gpu():
         print("✓ Model loaded successfully with from_pretrained.")
     except Exception as e2:
         print(f"from_pretrained failed: {e2}")
-        # Manual loading as fallback
         model = load_model_manually(device)
 
-    # Move model to device and set to eval mode
     if model and hasattr(model, 'to'):
         model = model.to(device)
     if model and hasattr(model, 'eval'):
@@ -108,7 +109,6 @@ def load_model_manually(device):
     model_path = pathlib.Path(LOCAL_MODEL_PATH)
     print("Manual loading with correct constructor signature...")
 
-    # Load components to CPU first, then move to device
    s3gen_path = model_path / "s3gen.pt"
     ve_path = model_path / "ve.pt"
     tokenizer_path = model_path / "tokenizer.json"
@@ -127,7 +127,6 @@ def load_model_manually(device):
     except Exception:
         tokenizer = tokenizer_data
 
-    # Create model instance
     model = ChatterboxTTS(
         t3=t3_cfg,
         s3gen=s3gen,
@@ -141,10 +140,35 @@ def load_model_manually(device):
 
 def cleanup_gpu_memory():
     """Clean up GPU memory - only call within GPU context."""
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
-    gc.collect()
+    try:
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+        gc.collect()
+    except Exception as e:
+        print(f"Warning: GPU cleanup failed: {e}")
+
+def truncate_text_safely(text, max_chars=MAX_CHARS_PER_GENERATION):
+    """Truncate text to safe length while preserving sentence boundaries."""
+    if len(text) <= max_chars:
+        return text, False
+
+    # Find the last sentence ending before the limit
+    truncated = text[:max_chars]
+
+    # Look for sentence endings
+    for ending in ['. ', '! ', '? ']:
+        last_sentence = truncated.rfind(ending)
+        if last_sentence > max_chars * 0.7:  # Don't truncate too aggressively
+            return text[:last_sentence + 1].strip(), True
+
+    # Fallback to word boundary
+    last_space = truncated.rfind(' ')
+    if last_space > max_chars * 0.8:
+        return text[:last_space].strip(), True
+
+    # Last resort: hard truncate
+    return truncated.strip(), True
 
 # Download model files during startup (CPU only)
 if chatterbox_available:
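For reference, the new truncate_text_safely helper returns a (text, was_truncated) pair, preferring a sentence boundary past 70% of the limit, then a word boundary past 80%, before hard-truncating. A quick sanity check with hypothetical inputs (run outside the app):

```python
# Assumes truncate_text_safely is importable from this module.
text = "One. Two. Three. Four."

# 22 chars against a limit of 11: the last ". " inside the limit sits past
# 70% of it, so the cut lands on a sentence boundary.
print(truncate_text_safely(text, max_chars=11))
# -> ('One. Two.', True)

# Under the limit: returned unchanged and flagged as not truncated.
print(truncate_text_safely("short text", max_chars=1000))
# -> ('short text', False)
```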
@@ -169,8 +193,16 @@ def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=
     if reference_audio_path is None:
         return None, "Error: Please upload a reference audio file (.wav or .mp3)."
 
+    # Check text length and truncate if necessary
+    original_length = len(text_to_speak)
+    if original_length > MAX_CHARS_TOTAL:
+        return None, f"Error: Text is too long ({original_length:,} characters). Maximum allowed is {MAX_CHARS_TOTAL:,} characters. Please use the chunked generation API for longer texts."
+
+    # Truncate to safe generation length
+    text_to_use, was_truncated = truncate_text_safely(text_to_speak, MAX_CHARS_PER_GENERATION)
+
     try:
-        # Load model if not already loaded (inside GPU context)
+        # Load model if not already loaded
         if not model_loaded:
             print("Loading model for the first time...")
             if not load_model_on_gpu():
@@ -180,7 +212,9 @@ def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=
             return None, "Error: Model not loaded. Please check the logs for details."
 
         print(f"Processing request:")
-        print(f"  Text length: {len(text_to_speak)} characters")
+        print(f"  Original text length: {original_length:,} characters")
+        print(f"  Processing length: {len(text_to_use):,} characters")
+        print(f"  Truncated: {was_truncated}")
         print(f"  Audio: '{reference_audio_path}'")
         print(f"  Parameters: exag={exaggeration}, cfg={cfg_pace}, seed={random_seed}, temp={temperature}")
 
@@ -199,33 +233,19 @@ def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=
 
         # Generate audio with error handling
         try:
-            with torch.no_grad():  # Disable gradient computation to save memory
+            with torch.no_grad():
                 output_wav_data = model.generate(
-                    text=text_to_speak,
+                    text=text_to_use,
                     audio_prompt_path=reference_audio_path,
                     exaggeration=exaggeration,
                     cfg_weight=cfg_pace,
                     temperature=temperature
                 )
         except RuntimeError as e:
-            if "CUDA" in str(e) or "out of memory" in str(e):
+            if "CUDA" in str(e) or "out of memory" in str(e) or "device-side assert" in str(e):
                 print(f"CUDA error during generation: {e}")
-                # Try to recover by cleaning memory and retrying
                 cleanup_gpu_memory()
-                try:
-                    with torch.no_grad():
-                        output_wav_data = model.generate(
-                            text=text_to_speak,
-                            audio_prompt_path=reference_audio_path,
-                            exaggeration=exaggeration,
-                            cfg_weight=cfg_pace,
-                            temperature=temperature
-                        )
-                    print("✓ Recovery successful after memory cleanup")
-                except Exception as retry_error:
-                    print(f"✗ Recovery failed: {retry_error}")
-                    cleanup_gpu_memory()
-                    return None, f"CUDA error: {str(e)}. GPU memory issue - please try again in a moment."
+                return None, f"CUDA error: Text may be too long for single generation. Try shorter text (under {MAX_CHARS_PER_GENERATION} characters) or use the chunked generation API for longer content."
             else:
                 raise e
 
@@ -253,7 +273,13 @@ def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=
             print(f"CUDA memory after generation: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")
 
         print("✓ Audio generated successfully")
-        return result, "Success: Audio generated successfully!"
+
+        # Prepare success message
+        success_msg = "Success: Audio generated successfully!"
+        if was_truncated:
+            success_msg += f" Note: Text was truncated from {original_length:,} to {len(text_to_use):,} characters for optimal generation. Use the chunked generation API for longer texts."
+
+        return result, success_msg
 
     except Exception as e:
         print(f"ERROR during audio generation: {e}")
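Several of the new messages point callers at a "chunked generation API" that is not defined in this file. A client-side sketch of what such chunking could look like, assuming clone_voice returns a (sample_rate, ndarray) tuple on success (the helper name and that return shape are assumptions, not part of the commit):

```python
import numpy as np

def clone_voice_chunked(long_text, reference_audio_path, **kwargs):
    """Hypothetical helper: split long_text into safe pieces and stitch the audio."""
    pieces, remaining = [], long_text
    while remaining:
        piece, was_cut = truncate_text_safely(remaining, MAX_CHARS_PER_GENERATION)
        pieces.append(piece)
        remaining = remaining[len(piece):].lstrip() if was_cut else ""
    waves, sample_rate = [], None
    for piece in pieces:
        result, status = clone_voice(piece, reference_audio_path, **kwargs)
        if result is None:           # propagate the first failure
            return None, status
        sample_rate, wav = result    # assumed (sr, ndarray) shape
        waves.append(wav)
    return (sample_rate, np.concatenate(waves)), "Success: all chunks generated."
```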
@@ -268,14 +294,14 @@ def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=
         # Provide specific error messages
         error_msg = str(e)
         if "CUDA" in error_msg or "device-side assert" in error_msg:
-            return None, f"CUDA error: {error_msg}. This is usually a temporary GPU issue. Please try again in a moment."
+            return None, f"CUDA error: {error_msg}. Try shorter text (under {MAX_CHARS_PER_GENERATION} characters) or use the chunked generation API."
         elif "out of memory" in error_msg:
-            return None, f"GPU memory error: {error_msg}. Please try with shorter text or try again later."
+            return None, f"GPU memory error: {error_msg}. Please try with shorter text."
         else:
             return None, f"Error during audio generation: {error_msg}. Check logs for more details."
 
 def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
-    """API wrapper function - this will call the GPU function."""
+    """API wrapper function."""
     import requests
     import tempfile
     import os
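Like clone_voice, the wrapper returns a (result, status) pair after downloading the reference audio to a temp file. A hedged usage sketch (URL and parameter values are placeholders):

```python
# Illustrative call only; the URL is not a real endpoint.
result, status = clone_voice_api(
    text_to_speak="Hello from a cloned voice.",
    reference_audio_url="https://example.com/reference.wav",
    exaggeration=0.6,
    cfg_pace=0.3,
    random_seed=0,
    temperature=0.6,
)
print(status)  # e.g. "Success: Audio generated successfully!"
```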
@@ -317,21 +343,28 @@ def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pa
         except:
             pass
 
-# Your existing Gradio interface code goes here...
 def main():
     print("Starting Advanced Gradio interface...")
 
-    # Your existing Gradio interface code
     with gr.Blocks(title="🎙️ Advanced Chatterbox Voice Cloning") as demo:
         gr.Markdown("# 🎙️ Advanced Chatterbox Voice Cloning")
         gr.Markdown("Clone any voice using advanced AI technology with fine-tuned controls.")
+
+        # Add warning about text length
+        gr.Markdown(f"""
+        **⚠️ Text Length Limits:**
+        - **Single Generation**: Up to {MAX_CHARS_PER_GENERATION:,} characters (optimal quality)
+        - **API Maximum**: Up to {MAX_CHARS_TOTAL:,} characters (may be truncated)
+        - **For longer texts**: Use the chunked generation API in your application
+        """)
 
         with gr.Row():
             with gr.Column(scale=2):
                 text_input = gr.Textbox(
-                    label="Text to Speak",
+                    label=f"Text to Speak (max {MAX_CHARS_TOTAL:,} characters)",
                     placeholder="Enter the text you want the cloned voice to say...",
-                    lines=3
+                    lines=5,
+                    max_lines=10
                 )
                 audio_input = gr.Audio(
                     type="filepath",
@@ -362,7 +395,7 @@ def main():
 
             with gr.Column(scale=1):
                 audio_output = gr.Audio(label="Generated Audio", type="numpy")
-                status_output = gr.Textbox(label="Status", lines=2)
+                status_output = gr.Textbox(label="Status", lines=3)
 
         # Connect the interface
         generate_btn.click(
 