# Hugging Face Space (status: Running)
import gradio as gr | |
import os | |
import traceback | |
import torch | |
import gc | |
from huggingface_hub import hf_hub_download | |
import shutil | |
import spaces | |
# Model location settings. The project's ``config`` module wins when it can be
# imported; otherwise the built-in defaults below are used.
try:
    from config import LOCAL_MODEL_PATH, MODEL_FILES, MODEL_REPO_ID
except ImportError:
    MODEL_FILES = [
        "s3gen.pt",
        "t3_cfg.pt",
        "ve.pt",
        "tokenizer.json",
    ]
    LOCAL_MODEL_PATH = "./chatterbox_model_files"
    MODEL_REPO_ID = "ramimu/chatterbox-voice-cloning-model"
# Module-wide handle to the loaded TTS model; None means "nothing loaded".
model = None

# The Chatterbox library is optional at import time: record whether it could
# be imported so the rest of the module can degrade gracefully.
try:
    from chatterbox.tts import ChatterboxTTS
except ImportError as import_error:
    chatterbox_available = False
    print(f"Failed to import ChatterboxTTS: {import_error}")
else:
    chatterbox_available = True
    print("Chatterbox TTS imported successfully")
def cleanup_gpu_memory():
    """Run Python garbage collection, then release cached CUDA memory.

    Safe to call on CPU-only machines: the CUDA calls are skipped when
    ``torch.cuda.is_available()`` is False.

    Returns:
        None
    """
    # Collect first: tensors that are only kept alive by reference cycles are
    # freed here, which hands their blocks back to PyTorch's caching allocator
    # so that empty_cache() below can actually release them.
    gc.collect()
    if torch.cuda.is_available():
        # Wait for queued kernels to finish before dropping cached blocks.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
def safe_load_model():
    """Load the Chatterbox model into the module-level ``model`` global.

    Loaders are tried in this order, moving on when one raises:

    1. ``ChatterboxTTS.from_local(LOCAL_MODEL_PATH, device)``
    2. ``ChatterboxTTS.from_pretrained(device)``
    3. ``load_model_manually(device)``

    The device is ``"cuda"`` when ``torch.cuda.is_available()`` and ``"cpu"``
    otherwise.

    Returns:
        bool: True only when a model object was obtained and stored in
        ``model``. False when the Chatterbox library is unavailable or
        loading failed; in the failure case ``model`` is reset to None.
    """
    global model
    if not chatterbox_available:
        print("ERROR: Chatterbox TTS library not available")
        return False

    try:
        # Clean up any existing GPU memory
        cleanup_gpu_memory()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading model on device: {device}")

        # Build into a local first so the global is only replaced by a
        # fully prepared model.
        try:
            loaded = ChatterboxTTS.from_local(LOCAL_MODEL_PATH, device)
            print("β Model loaded successfully using from_local method.")
        except Exception as e1:
            print(f"from_local failed: {e1}")
            try:
                loaded = ChatterboxTTS.from_pretrained(device)
                print("β Model loaded successfully with from_pretrained.")
            except Exception as e2:
                print(f"from_pretrained failed: {e2}")
                # Manual loading as last resort; its errors propagate to the
                # outer handler below.
                loaded = load_model_manually(device)

        if loaded is None:
            # Never report success without an actual model object.
            raise RuntimeError("model loaders returned no model")

        # Move to device / switch to eval mode when the object supports it.
        if hasattr(loaded, 'to'):
            moved = loaded.to(device)
            # Some ``to`` implementations work in place and return None;
            # keep the original object in that case.
            if moved is not None:
                loaded = moved
        if hasattr(loaded, 'eval'):
            loaded.eval()

        model = loaded
        # Clean up after loading
        cleanup_gpu_memory()
        return True
    except Exception as e:
        print(f"ERROR: Failed to load model: {e}")
        traceback.print_exc()
        model = None
        cleanup_gpu_memory()
        return False
def load_model_manually(device):
    """Construct a ``ChatterboxTTS`` directly from the files in ``LOCAL_MODEL_PATH``.

    Reads ``s3gen.pt``, ``ve.pt`` and ``t3_cfg.pt`` with ``torch.load`` (mapped
    to CPU) and ``tokenizer.json`` with ``json.load``, then passes the results
    to the ``ChatterboxTTS`` constructor.

    Args:
        device: Device identifier forwarded to the ``ChatterboxTTS``
            constructor as ``device``.

    Returns:
        The constructed ``ChatterboxTTS`` instance.

    Raises:
        FileNotFoundError: If any of the four model files is missing.
        Exception: Whatever ``torch.load``, ``json.load`` or the
            ``ChatterboxTTS`` constructor raise.
    """
    import json
    import pathlib

    model_path = pathlib.Path(LOCAL_MODEL_PATH)
    print("Manual loading with correct constructor signature...")

    s3gen_path = model_path / "s3gen.pt"
    ve_path = model_path / "ve.pt"
    tokenizer_path = model_path / "tokenizer.json"
    t3_cfg_path = model_path / "t3_cfg.pt"

    # Fail fast with one clear message instead of part-way through loading
    # large checkpoints.
    missing = [
        str(path)
        for path in (s3gen_path, ve_path, t3_cfg_path, tokenizer_path)
        if not path.is_file()
    ]
    if missing:
        raise FileNotFoundError(f"Missing model files: {', '.join(missing)}")

    # SECURITY NOTE(review): torch.load unpickles the checkpoint, which can
    # execute code embedded in the file. Only load checkpoints from a trusted
    # repository.
    # Load components to CPU first
    s3gen = torch.load(s3gen_path, map_location='cpu')
    ve = torch.load(ve_path, map_location='cpu')
    t3_cfg = torch.load(t3_cfg_path, map_location='cpu')

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(tokenizer_path, 'r', encoding='utf-8') as f:
        tokenizer_data = json.load(f)

    try:
        from chatterbox.models.tokenizers.tokenizer import EnTokenizer
        tokenizer = EnTokenizer.from_dict(tokenizer_data)
    except Exception as tokenizer_error:
        # Best-effort fallback kept from the original behavior, but no longer
        # silent so the cause shows up in the logs.
        print(f"EnTokenizer could not be built ({tokenizer_error}); using raw tokenizer data")
        tokenizer = tokenizer_data

    # NOTE(review): the objects returned by torch.load are handed to the
    # constructor unchanged. If the .pt files contain state dicts rather than
    # module instances, they presumably need to be loaded into model objects
    # first -- confirm against the ChatterboxTTS constructor.
    tts_model = ChatterboxTTS(
        t3=t3_cfg,
        s3gen=s3gen,
        ve=ve,
        tokenizer=tokenizer,
        device=device,
    )
    print("β Model loaded successfully with manual constructor.")
    return tts_model
def download_model_files():
    """Ensure every file in ``MODEL_FILES`` exists under ``LOCAL_MODEL_PATH``.

    Files that are already present are left untouched. Missing files are
    fetched from ``MODEL_REPO_ID`` with ``hf_hub_download`` (cached under
    ``./cache``) and copied into ``LOCAL_MODEL_PATH``.

    Returns:
        None

    Raises:
        Exception: Re-raises whatever the download or copy raised for the
            first file that fails, after logging it.
    """
    print(f"Checking for model files in {LOCAL_MODEL_PATH}...")
    os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)

    for filename in MODEL_FILES:
        local_path = os.path.join(LOCAL_MODEL_PATH, filename)
        if os.path.exists(local_path):
            print(f"β {filename} already exists locally")
            continue

        print(f"Downloading {filename} from {MODEL_REPO_ID}...")
        # Copy to a sibling ".part" file and rename it into place, so an
        # interrupted copy can never leave a truncated file that a later run
        # would mistake for a complete one.
        partial_path = local_path + ".part"
        try:
            downloaded_path = hf_hub_download(
                repo_id=MODEL_REPO_ID,
                filename=filename,
                cache_dir="./cache",
                force_download=False,
            )
            # Supports entries that contain sub-directories.
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            shutil.copy2(downloaded_path, partial_path)
            os.replace(partial_path, local_path)
            print(f"β Downloaded and copied {filename}")
        except Exception as e:
            print(f"β Failed to download {filename}: {e}")
            if os.path.exists(partial_path):
                try:
                    os.remove(partial_path)
                except OSError as cleanup_error:
                    print(f"Warning: could not remove {partial_path}: {cleanup_error}")
            # Bare raise keeps the original traceback intact.
            raise

    print("All model files are ready!")
# Initialize model: when the Chatterbox library imported successfully,
# download any missing model files and load the model at module import time.
# Failures are logged rather than raised so the module still imports.
if chatterbox_available:
    try:
        download_model_files()
        if not safe_load_model():
            # safe_load_model reports failure through its return value.
            print("ERROR during initialization: model could not be loaded")
    except Exception as e:
        print(f"ERROR during initialization: {e}")
        # Same diagnostics as the other error paths in this file.
        traceback.print_exc()
def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """Generate speech for ``text_to_speak`` using a reference recording.

    Args:
        text_to_speak: Text to synthesize; must contain non-whitespace
            characters.
        reference_audio_path: Path of the reference audio, forwarded to
            ``model.generate`` as ``audio_prompt_path``. Must not be None.
        exaggeration: Forwarded to ``model.generate`` unchanged.
        cfg_pace: Forwarded to ``model.generate`` as ``cfg_weight``.
        random_seed: Values greater than 0 seed torch's CPU (and, when
            available, CUDA) random generators. 0, negative or None leaves
            them unseeded.
        temperature: Forwarded to ``model.generate`` unchanged.

    Returns:
        tuple: ``(audio, status)``. On success ``audio`` is either the string
        returned by the model, or ``(sample_rate, samples)`` where
        ``sample_rate`` is ``model.sr`` (24000 when the model has no ``sr``
        attribute) and ``samples`` is the generated waveform moved to CPU and
        squeezed when it has more than one dimension. On failure ``audio`` is
        None and ``status`` describes the problem. Errors raised during
        generation are reported through ``status`` rather than raised.
    """
    # Input validation
    if not chatterbox_available:
        return None, "Error: Chatterbox TTS library not available. Please check installation."
    if model is None:
        return None, "Error: Model not loaded. Please check the logs for details."
    if not text_to_speak or text_to_speak.strip() == "":
        return None, "Error: Please enter some text to speak."
    if reference_audio_path is None:
        return None, "Error: Please upload a reference audio file (.wav or .mp3)."

    def _generate():
        """Run one generation pass with gradient tracking disabled."""
        with torch.no_grad():
            return model.generate(
                text=text_to_speak,
                audio_prompt_path=reference_audio_path,
                exaggeration=exaggeration,
                cfg_weight=cfg_pace,
                temperature=temperature,
            )

    try:
        print("Processing request:")
        print(f" Text length: {len(text_to_speak)} characters")
        print(f" Audio: '{reference_audio_path}'")
        print(f" Parameters: exag={exaggeration}, cfg={cfg_pace}, seed={random_seed}, temp={temperature}")

        # Clean GPU memory before generation
        cleanup_gpu_memory()

        # Seed only when a positive seed was given; None means "no seed".
        if random_seed is not None and random_seed > 0:
            seed = int(random_seed)  # UI sliders may deliver floats
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)

        if torch.cuda.is_available():
            print(f"CUDA memory before generation: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")

        try:
            output_wav_data = _generate()
        except RuntimeError as e:
            first_error = str(e)
            if "CUDA" not in first_error and "out of memory" not in first_error:
                # Not a GPU problem: let the outer handler report it.
                raise
            print(f"CUDA error during generation: {e}")
            # Try to recover once by cleaning memory and retrying.
            cleanup_gpu_memory()
            try:
                output_wav_data = _generate()
                print("β Recovery successful after memory cleanup")
            except Exception as retry_error:
                print(f"β Recovery failed: {retry_error}")
                # Leave the GPU in a clean state on this exit path as well.
                cleanup_gpu_memory()
                return None, f"CUDA error: {str(e)}. GPU memory issue - please try again in a moment."

        # Fall back to 24 kHz when the model does not expose a sample rate.
        sample_rate = getattr(model, 'sr', 24000)

        # Process output
        if isinstance(output_wav_data, str):
            result = output_wav_data
        else:
            if hasattr(output_wav_data, 'cpu'):
                output_wav_data = output_wav_data.cpu().numpy()
            if output_wav_data.ndim > 1:
                output_wav_data = output_wav_data.squeeze()
            result = (sample_rate, output_wav_data)

        # Clean up GPU memory after generation
        cleanup_gpu_memory()
        if torch.cuda.is_available():
            print(f"CUDA memory after generation: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")
        print("β Audio generated successfully")
        return result, "Success: Audio generated successfully!"
    except Exception as e:
        print(f"ERROR during audio generation: {e}")
        traceback.print_exc()
        # Clean up on error
        cleanup_gpu_memory()

        error_msg = str(e)
        # Test for out-of-memory first: CUDA OOM messages also contain
        # "CUDA", which would otherwise hide this more specific advice.
        if "out of memory" in error_msg:
            return None, f"GPU memory error: {error_msg}. Please try with shorter text or try again later."
        if "CUDA" in error_msg or "device-side assert" in error_msg:
            return None, f"CUDA error: {error_msg}. This is usually a temporary GPU issue. Please try again in a moment."
        return None, f"Error during audio generation: {error_msg}. Check logs for more details."
def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """API wrapper around ``clone_voice`` that accepts the audio by reference.

    Args:
        text_to_speak: Text to synthesize, forwarded to ``clone_voice``.
        reference_audio_url: One of
            * a ``data:audio...`` URI with base64 payload,
            * an ``http://`` or ``https://`` URL, downloaded with a 30 second
              timeout,
            * anything else, which is passed to ``clone_voice`` unchanged as
              a file path.
        exaggeration: Forwarded to ``clone_voice``.
        cfg_pace: Forwarded to ``clone_voice``.
        random_seed: Forwarded to ``clone_voice``.
        temperature: Forwarded to ``clone_voice``.

    Returns:
        tuple: ``(audio, status)`` as returned by ``clone_voice``, or
        ``(None, "API Error: ...")`` when the reference audio could not be
        obtained or any other exception occurred. This function does not
        raise.
    """
    import base64
    import os
    import tempfile
    from urllib.parse import urlparse

    if not isinstance(reference_audio_url, str) or not reference_audio_url.strip():
        return None, "API Error: reference audio must be a non-empty data URI, URL, or file path."

    # SECURITY NOTE(review): this fetches caller-supplied URLs and passes
    # caller-supplied local paths through to the model. If the endpoint is
    # exposed to untrusted callers, restrict the allowed hosts/paths and cap
    # the download size.
    temp_audio_path = None  # set only for files created here
    try:
        audio_bytes = None
        ext = '.wav'

        if reference_audio_url.startswith('data:audio'):
            header, separator, encoded = reference_audio_url.partition(',')
            if not separator:
                raise ValueError("malformed data URI: missing ',' separator")
            audio_bytes = base64.b64decode(encoded)
            # "audio/mpeg" is the registered MIME type for MP3 data.
            if 'mp3' in header or 'mpeg' in header:
                ext = '.mp3'
        elif reference_audio_url.startswith(('http://', 'https://')):
            # Imported lazily so the other input forms work without requests.
            import requests
            response = requests.get(reference_audio_url, timeout=30)
            response.raise_for_status()
            audio_bytes = response.content
            # Look at the URL path only, so query strings do not hide the
            # file extension.
            if urlparse(reference_audio_url).path.lower().endswith('.mp3'):
                ext = '.mp3'

        if audio_bytes is None:
            audio_path = reference_audio_url
        else:
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
                # Record the path before writing so the ``finally`` clause
                # removes the file even if the write fails.
                temp_audio_path = temp_file.name
                temp_file.write(audio_bytes)
            audio_path = temp_audio_path

        # Generate audio
        audio_output, status = clone_voice(text_to_speak, audio_path, exaggeration, cfg_pace, random_seed, temperature)
        return audio_output, status
    except Exception as e:
        print(f"API Error: {e}")
        return None, f"API Error: {str(e)}"
    finally:
        # Clean up the temporary file; a failure here must not mask the result.
        if temp_audio_path is not None:
            try:
                os.unlink(temp_audio_path)
            except FileNotFoundError:
                pass  # already gone, nothing to clean up
            except OSError as cleanup_error:
                print(f"Warning: could not remove temporary file {temp_audio_path}: {cleanup_error}")
# The rest of the Gradio interface code is unchanged and is omitted here.
def main():
    """Print the startup banner; interface construction is still a placeholder."""
    print("Starting Advanced Gradio interface...")
    # Placeholder: the existing Gradio interface code goes here.
    return None


if __name__ == "__main__":
    main()