import gradio as gr
import os
import traceback
import torch
import gc
from huggingface_hub import hf_hub_download
import shutil
import spaces

try:
    from config import MODEL_REPO_ID, MODEL_FILES, LOCAL_MODEL_PATH
except ImportError:
    MODEL_REPO_ID = "ramimu/chatterbox-voice-cloning-model"
    LOCAL_MODEL_PATH = "./chatterbox_model_files"
    MODEL_FILES = ["s3gen.pt", "t3_cfg.pt", "ve.pt", "tokenizer.json"]

try:
    from chatterbox.tts import ChatterboxTTS
    chatterbox_available = True
    print("Chatterbox TTS imported successfully")
except ImportError as e:
    print(f"Failed to import ChatterboxTTS: {e}")
    chatterbox_available = False

model = None


def cleanup_gpu_memory():
    """Clean up GPU memory to prevent CUDA errors."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()


def safe_load_model():
    """Safely load the model with proper error handling."""
    global model

    if not chatterbox_available:
        print("ERROR: Chatterbox TTS library not available")
        return False

    try:
        # Clean up any existing GPU memory
        cleanup_gpu_memory()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading model on device: {device}")

        # Try different loading methods
        try:
            model = ChatterboxTTS.from_local(LOCAL_MODEL_PATH, device)
            print("✓ Model loaded successfully using from_local method.")
        except Exception as e1:
            print(f"from_local failed: {e1}")
            try:
                model = ChatterboxTTS.from_pretrained(device)
                print("✓ Model loaded successfully with from_pretrained.")
            except Exception as e2:
                print(f"from_pretrained failed: {e2}")
                # Manual loading as fallback
                model = load_model_manually(device)

        # Move model to device and set to eval mode
        if model and hasattr(model, 'to'):
            model = model.to(device)
        if model and hasattr(model, 'eval'):
            model.eval()

        # Clean up after loading
        cleanup_gpu_memory()
        return True

    except Exception as e:
        print(f"ERROR: Failed to load model: {e}")
        traceback.print_exc()
        model = None
        cleanup_gpu_memory()
        return False


def load_model_manually(device):
    """Manual model loading with proper error handling."""
    import pathlib
    import json

    model_path = pathlib.Path(LOCAL_MODEL_PATH)
    print("Manual loading with correct constructor signature...")

    # Load components to CPU first
    s3gen_path = model_path / "s3gen.pt"
    ve_path = model_path / "ve.pt"
    tokenizer_path = model_path / "tokenizer.json"
    t3_cfg_path = model_path / "t3_cfg.pt"

    s3gen = torch.load(s3gen_path, map_location='cpu')
    ve = torch.load(ve_path, map_location='cpu')
    t3_cfg = torch.load(t3_cfg_path, map_location='cpu')

    with open(tokenizer_path, 'r') as f:
        tokenizer_data = json.load(f)

    try:
        from chatterbox.models.tokenizers.tokenizer import EnTokenizer
        tokenizer = EnTokenizer.from_dict(tokenizer_data)
    except Exception:
        tokenizer = tokenizer_data

    # Create model instance
    model = ChatterboxTTS(
        t3=t3_cfg,
        s3gen=s3gen,
        ve=ve,
        tokenizer=tokenizer,
        device=device
    )
    print("✓ Model loaded successfully with manual constructor.")
    return model


def download_model_files():
    """Download model files with error handling."""
    print(f"Checking for model files in {LOCAL_MODEL_PATH}...")
    os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)

    for filename in MODEL_FILES:
        local_path = os.path.join(LOCAL_MODEL_PATH, filename)
        if not os.path.exists(local_path):
            print(f"Downloading {filename} from {MODEL_REPO_ID}...")
            try:
                downloaded_path = hf_hub_download(
                    repo_id=MODEL_REPO_ID,
                    filename=filename,
                    cache_dir="./cache",
                    force_download=False
                )
                shutil.copy2(downloaded_path, local_path)
                print(f"✓ Downloaded and copied {filename}")
            except Exception as e:
                print(f"✗ Failed to download {filename}: {e}")
                raise
        else:
            print(f"✓ {filename} already exists locally")

    print("All model files are ready!")


# Initialize model
if chatterbox_available:
    try:
        download_model_files()
        safe_load_model()
    except Exception as e:
        print(f"ERROR during initialization: {e}")


@spaces.GPU
def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """Main voice cloning function with improved error handling."""
    # Input validation
    if not chatterbox_available:
        return None, "Error: Chatterbox TTS library not available. Please check installation."
    if model is None:
        return None, "Error: Model not loaded. Please check the logs for details."
    if not text_to_speak or text_to_speak.strip() == "":
        return None, "Error: Please enter some text to speak."
    if reference_audio_path is None:
        return None, "Error: Please upload a reference audio file (.wav or .mp3)."

    try:
        print("Processing request:")
        print(f"  Text length: {len(text_to_speak)} characters")
        print(f"  Audio: '{reference_audio_path}'")
        print(f"  Parameters: exag={exaggeration}, cfg={cfg_pace}, seed={random_seed}, temp={temperature}")

        # Clean GPU memory before generation
        cleanup_gpu_memory()

        # Set random seed if specified
        if random_seed > 0:
            torch.manual_seed(random_seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(random_seed)

        # Check CUDA availability before generation
        if torch.cuda.is_available():
            print(f"CUDA memory before generation: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")

        # Generate audio with error handling
        try:
            with torch.no_grad():  # Disable gradient computation
                output_wav_data = model.generate(
                    text=text_to_speak,
                    audio_prompt_path=reference_audio_path,
                    exaggeration=exaggeration,
                    cfg_weight=cfg_pace,
                    temperature=temperature
                )
        except RuntimeError as e:
            if "CUDA" in str(e) or "out of memory" in str(e):
                print(f"CUDA error during generation: {e}")
                # Try to recover by cleaning memory and retrying
                cleanup_gpu_memory()
                try:
                    with torch.no_grad():
                        output_wav_data = model.generate(
                            text=text_to_speak,
                            audio_prompt_path=reference_audio_path,
                            exaggeration=exaggeration,
                            cfg_weight=cfg_pace,
                            temperature=temperature
                        )
                    print("✓ Recovery successful after memory cleanup")
                except Exception as retry_error:
                    print(f"✗ Recovery failed: {retry_error}")
                    return None, f"CUDA error: {str(e)}. GPU memory issue - please try again in a moment."
            else:
                raise

        # Get sample rate
        try:
            sample_rate = model.sr
        except Exception:
            sample_rate = 24000

        # Process output
        if isinstance(output_wav_data, str):
            result = output_wav_data
        else:
            if hasattr(output_wav_data, 'cpu'):
                output_wav_data = output_wav_data.cpu().numpy()
            if output_wav_data.ndim > 1:
                output_wav_data = output_wav_data.squeeze()
            result = (sample_rate, output_wav_data)

        # Clean up GPU memory after generation
        cleanup_gpu_memory()
        if torch.cuda.is_available():
            print(f"CUDA memory after generation: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")

        print("✓ Audio generated successfully")
        return result, "Success: Audio generated successfully!"

    except Exception as e:
        print(f"ERROR during audio generation: {e}")
        traceback.print_exc()

        # Clean up on error
        cleanup_gpu_memory()

        # Provide specific error messages
        error_msg = str(e)
        if "CUDA" in error_msg or "device-side assert" in error_msg:
            return None, f"CUDA error: {error_msg}. This is usually a temporary GPU issue. Please try again in a moment."
        elif "out of memory" in error_msg:
            return None, f"GPU memory error: {error_msg}. Please try with shorter text or try again later."
else: return None, f"Error during audio generation: {error_msg}. Check logs for more details." def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6): """API wrapper with improved error handling.""" import requests import tempfile import os import base64 temp_audio_path = None try: # Handle different audio input formats if reference_audio_url.startswith('data:audio'): header, encoded = reference_audio_url.split(',', 1) audio_data = base64.b64decode(encoded) ext = '.mp3' if 'mp3' in header else '.wav' with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file: temp_file.write(audio_data) temp_audio_path = temp_file.name elif reference_audio_url.startswith('http'): response = requests.get(reference_audio_url, timeout=30) response.raise_for_status() ext = '.mp3' if reference_audio_url.endswith('.mp3') else '.wav' with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file: temp_file.write(response.content) temp_audio_path = temp_file.name else: temp_audio_path = reference_audio_url # Generate audio audio_output, status = clone_voice(text_to_speak, temp_audio_path, exaggeration, cfg_pace, random_seed, temperature) return audio_output, status except Exception as e: print(f"API Error: {e}") return None, f"API Error: {str(e)}" finally: # Clean up temporary file if temp_audio_path and temp_audio_path != reference_audio_url: try: os.unlink(temp_audio_path) except: pass # Rest of your Gradio interface code remains the same... def main(): print("Starting Advanced Gradio interface...") # Your existing Gradio interface code here pass if __name__ == "__main__": main()
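

# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original app): one way main() could wire
# clone_voice() into a Gradio UI. The widget labels, slider ranges, and the
# build_demo() name below are assumptions for illustration, not the original
# interface; adjust them to match your Space before using this.
def build_demo():
    """Build a minimal Gradio Interface around clone_voice()."""
    return gr.Interface(
        fn=clone_voice,
        inputs=[
            gr.Textbox(label="Text to speak", lines=4),
            gr.Audio(label="Reference audio (.wav or .mp3)", type="filepath"),
            gr.Slider(0.0, 1.0, value=0.6, step=0.05, label="Exaggeration"),
            gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="CFG / pace"),
            gr.Number(value=0, precision=0, label="Random seed (0 = unseeded)"),
            gr.Slider(0.1, 1.5, value=0.6, step=0.05, label="Temperature"),
        ],
        outputs=[
            gr.Audio(label="Cloned voice"),
            gr.Textbox(label="Status"),
        ],
        title="Chatterbox Voice Cloning",
    )

# Example usage from main(): build_demo().launch()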