# voice_cloning / app.py
# Author: ramimu
# Commit 91d6893 (verified): "Update app.py"
# File size: 11.3 kB
import gradio as gr
import os
import traceback
import torch
import gc
from huggingface_hub import hf_hub_download
import shutil
import spaces
# Settings come from the optional ``config`` module when it can be imported;
# otherwise the built-in defaults below are used.
try:
    from config import MODEL_REPO_ID, MODEL_FILES, LOCAL_MODEL_PATH
except ImportError:
    MODEL_REPO_ID, LOCAL_MODEL_PATH, MODEL_FILES = (
        "ramimu/chatterbox-voice-cloning-model",
        "./chatterbox_model_files",
        ["s3gen.pt", "t3_cfg.pt", "ve.pt", "tokenizer.json"],
    )
# ChatterboxTTS is treated as optional: the outcome of the import is recorded
# in ``chatterbox_available`` instead of letting an ImportError propagate.
try:
    from chatterbox.tts import ChatterboxTTS
except ImportError as import_error:
    print(f"Failed to import ChatterboxTTS: {import_error}")
    chatterbox_available = False
else:
    chatterbox_available = True
    print("Chatterbox TTS imported successfully")

# Module-level handle for the loaded model; stays None until a load succeeds.
model = None
def cleanup_gpu_memory():
    """Release unreferenced Python objects and cached CUDA memory.

    The garbage collector runs first so that tensors kept alive only by
    reference cycles are destroyed *before* the CUDA caching allocator is
    asked to give back its unused blocks. Collecting afterwards would leave
    the memory of those tensors cached until the next call.

    The CUDA steps are skipped when no CUDA device is available, so this is
    safe to call on CPU-only machines.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
def safe_load_model():
    """Load the Chatterbox model into the module-level ``model`` variable.

    Loading strategies are tried in order, each one only if the previous one
    raised: ``ChatterboxTTS.from_local`` on ``LOCAL_MODEL_PATH``, then
    ``ChatterboxTTS.from_pretrained``, then ``load_model_manually``.

    Returns:
        bool: ``True`` only when a model object was actually obtained.
        ``False`` when the Chatterbox library is unavailable, when every
        strategy failed, or when a strategy returned ``None``. On failure the
        global ``model`` is reset to ``None``.
    """
    global model

    if not chatterbox_available:
        print("ERROR: Chatterbox TTS library not available")
        return False

    try:
        # Clean up any existing GPU memory
        cleanup_gpu_memory()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading model on device: {device}")

        # Try different loading methods
        try:
            loaded = ChatterboxTTS.from_local(LOCAL_MODEL_PATH, device)
            print("βœ“ Model loaded successfully using from_local method.")
        except Exception as e1:
            print(f"from_local failed: {e1}")
            try:
                loaded = ChatterboxTTS.from_pretrained(device)
                print("βœ“ Model loaded successfully with from_pretrained.")
            except Exception as e2:
                print(f"from_pretrained failed: {e2}")
                # Manual loading as fallback; its errors reach the outer handler.
                loaded = load_model_manually(device)

        # A loader that hands back nothing is a failure, not a success.
        if loaded is None:
            raise RuntimeError("model loader returned None")

        # Move model to device and set to eval mode when the object supports it.
        if hasattr(loaded, 'to'):
            moved = loaded.to(device)
            # Keep the original object if .to() worked in place and returned None.
            if moved is not None:
                loaded = moved
        if hasattr(loaded, 'eval'):
            loaded.eval()

        model = loaded

        # Clean up after loading
        cleanup_gpu_memory()
        return True
    except Exception as e:
        print(f"ERROR: Failed to load model: {e}")
        traceback.print_exc()
        model = None
        cleanup_gpu_memory()
        return False
def load_model_manually(device):
    """Build a ``ChatterboxTTS`` instance directly from files in ``LOCAL_MODEL_PATH``.

    Reads ``s3gen.pt``, ``ve.pt`` and ``t3_cfg.pt`` with ``torch.load`` (mapped
    to CPU) and ``tokenizer.json`` with ``json``, then passes the results to
    the ``ChatterboxTTS`` constructor.

    Args:
        device: Device identifier forwarded to the ``ChatterboxTTS`` constructor.

    Returns:
        The constructed ``ChatterboxTTS`` object.

    Raises:
        Exception: Any error from reading the files or from the constructor
        propagates to the caller.
    """
    import json
    import pathlib

    model_dir = pathlib.Path(LOCAL_MODEL_PATH)
    print("Manual loading with correct constructor signature...")

    # SECURITY NOTE: torch.load unpickles the file contents, which can execute
    # arbitrary code. Only point LOCAL_MODEL_PATH at files from a trusted source.
    # Load components to CPU first
    s3gen = torch.load(model_dir / "s3gen.pt", map_location='cpu')
    ve = torch.load(model_dir / "ve.pt", map_location='cpu')
    t3_cfg = torch.load(model_dir / "t3_cfg.pt", map_location='cpu')

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(model_dir / "tokenizer.json", 'r', encoding="utf-8") as f:
        tokenizer_data = json.load(f)

    try:
        from chatterbox.models.tokenizers.tokenizer import EnTokenizer
        tokenizer = EnTokenizer.from_dict(tokenizer_data)
    except Exception as tokenizer_error:
        # Best-effort fallback kept from the original, but no longer silent.
        print(f"EnTokenizer could not be built ({tokenizer_error}); using raw tokenizer data.")
        tokenizer = tokenizer_data

    # NOTE(review): the objects returned by torch.load are passed straight to
    # the constructor. If the .pt files hold state dicts rather than module
    # objects, they would need to be loaded into instantiated modules first;
    # confirm against the ChatterboxTTS constructor.
    tts = ChatterboxTTS(
        t3=t3_cfg,
        s3gen=s3gen,
        ve=ve,
        tokenizer=tokenizer,
        device=device,
    )
    print("βœ“ Model loaded successfully with manual constructor.")
    return tts
def download_model_files():
    """Make sure every file in ``MODEL_FILES`` exists under ``LOCAL_MODEL_PATH``.

    Files already on disk are left alone. Missing ones are fetched from
    ``MODEL_REPO_ID`` with ``hf_hub_download`` (using ``./cache`` as the cache
    directory) and copied into ``LOCAL_MODEL_PATH``.

    Raises:
        Exception: Whatever the download or the copy raised, re-raised after
        the failure has been printed.
    """
    print(f"Checking for model files in {LOCAL_MODEL_PATH}...")
    os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)

    for model_file in MODEL_FILES:
        target = os.path.join(LOCAL_MODEL_PATH, model_file)

        # Nothing to do for files that are already in place.
        if os.path.exists(target):
            print(f"βœ“ {model_file} already exists locally")
            continue

        print(f"Downloading {model_file} from {MODEL_REPO_ID}...")
        try:
            cached_copy = hf_hub_download(
                repo_id=MODEL_REPO_ID,
                filename=model_file,
                cache_dir="./cache",
                force_download=False,
            )
            shutil.copy2(cached_copy, target)
            print(f"βœ“ Downloaded and copied {model_file}")
        except Exception as download_error:
            print(f"βœ— Failed to download {model_file}: {download_error}")
            raise download_error

    print("All model files are ready!")
# Start-up work, executed at import time: fetch the model files and load the
# model. Exceptions are printed here rather than propagated.
if chatterbox_available:
    try:
        download_model_files()
        safe_load_model()
    except Exception as init_error:
        print(f"ERROR during initialization: {init_error}")
def _run_generation(text_to_speak, reference_audio_path, exaggeration, cfg_pace, temperature):
    """Call ``model.generate`` once with gradient tracking disabled."""
    with torch.no_grad():
        return model.generate(
            text=text_to_speak,
            audio_prompt_path=reference_audio_path,
            exaggeration=exaggeration,
            cfg_weight=cfg_pace,
            temperature=temperature,
        )


def _describe_generation_error(error_msg):
    """Translate an exception message into the user-facing status string."""
    # Out-of-memory is tested first: CUDA OOM messages also contain "CUDA",
    # so testing for "CUDA" first would make this branch unreachable for them.
    if "out of memory" in error_msg:
        return f"GPU memory error: {error_msg}. Please try with shorter text or try again later."
    if "CUDA" in error_msg or "device-side assert" in error_msg:
        return f"CUDA error: {error_msg}. This is usually a temporary GPU issue. Please try again in a moment."
    return f"Error during audio generation: {error_msg}. Check logs for more details."


@spaces.GPU
def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """Generate speech for ``text_to_speak`` using ``reference_audio_path`` as the voice prompt.

    Args:
        text_to_speak: Text to synthesise; must contain non-whitespace characters.
        reference_audio_path: Path of the reference audio, forwarded to
            ``model.generate`` as ``audio_prompt_path``.
        exaggeration: Forwarded to ``model.generate`` as ``exaggeration``.
        cfg_pace: Forwarded to ``model.generate`` as ``cfg_weight``.
        random_seed: When greater than zero, seeds torch (and CUDA if
            available) before generation. ``0``, negative values and ``None``
            leave the random state untouched.
        temperature: Forwarded to ``model.generate`` as ``temperature``.

    Returns:
        tuple: ``(result, status)``. On success ``result`` is either the string
        returned by the model or a ``(sample_rate, samples)`` pair, and
        ``status`` starts with ``"Success"``. On failure ``result`` is ``None``
        and ``status`` describes the error. This function does not raise for
        generation failures.
    """
    # Input validation
    if not chatterbox_available:
        return None, "Error: Chatterbox TTS library not available. Please check installation."
    if model is None:
        return None, "Error: Model not loaded. Please check the logs for details."
    if not text_to_speak or text_to_speak.strip() == "":
        return None, "Error: Please enter some text to speak."
    if reference_audio_path is None:
        return None, "Error: Please upload a reference audio file (.wav or .mp3)."

    try:
        print(f"Processing request:")
        print(f" Text length: {len(text_to_speak)} characters")
        print(f" Audio: '{reference_audio_path}'")
        print(f" Parameters: exag={exaggeration}, cfg={cfg_pace}, seed={random_seed}, temp={temperature}")

        # Clean GPU memory before generation
        cleanup_gpu_memory()

        # Set random seed if specified (a missing seed means "do not seed").
        if random_seed is not None and random_seed > 0:
            torch.manual_seed(random_seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(random_seed)

        if torch.cuda.is_available():
            print(f"CUDA memory before generation: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")

        generation_args = (text_to_speak, reference_audio_path, exaggeration, cfg_pace, temperature)
        try:
            output_wav_data = _run_generation(*generation_args)
        except RuntimeError as e:
            # Only CUDA / memory problems are worth a retry; anything else goes
            # to the generic handler below with its traceback intact.
            if "CUDA" not in str(e) and "out of memory" not in str(e):
                raise
            print(f"CUDA error during generation: {e}")
            # Try to recover by cleaning memory and retrying once
            cleanup_gpu_memory()
            try:
                output_wav_data = _run_generation(*generation_args)
                print("βœ“ Recovery successful after memory cleanup")
            except Exception as retry_error:
                print(f"βœ— Recovery failed: {retry_error}")
                cleanup_gpu_memory()
                return None, f"CUDA error: {str(e)}. GPU memory issue - please try again in a moment."

        # Fall back to 24000 Hz when the model exposes no ``sr`` attribute.
        sample_rate = getattr(model, "sr", 24000)

        # Process output
        if isinstance(output_wav_data, str):
            result = output_wav_data
        else:
            if hasattr(output_wav_data, 'cpu'):
                output_wav_data = output_wav_data.cpu().numpy()
            if output_wav_data.ndim > 1:
                output_wav_data = output_wav_data.squeeze()
            result = (sample_rate, output_wav_data)

        # Clean up GPU memory after generation
        cleanup_gpu_memory()
        if torch.cuda.is_available():
            print(f"CUDA memory after generation: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")

        print("βœ“ Audio generated successfully")
        return result, "Success: Audio generated successfully!"
    except Exception as e:
        print(f"ERROR during audio generation: {e}")
        traceback.print_exc()
        # Clean up on error
        cleanup_gpu_memory()
        return None, _describe_generation_error(str(e))
def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """API wrapper around ``clone_voice`` that accepts several reference-audio forms.

    Args:
        text_to_speak: Text to synthesise, forwarded to ``clone_voice``.
        reference_audio_url: One of
            * a ``data:audio...`` URI whose payload is base64-decoded,
            * an ``http://`` / ``https://`` URL that is downloaded, or
            * any other string, which is passed through as a file path.
        exaggeration, cfg_pace, random_seed, temperature: Forwarded unchanged
            to ``clone_voice``.

    Returns:
        tuple: ``(audio_output, status)`` as returned by ``clone_voice``, or
        ``(None, "API Error: ...")`` when the reference is invalid or anything
        raises. Temporary files created here are removed before returning.
    """
    import base64
    import os
    import tempfile
    from urllib.parse import urlparse

    # Reject unusable references up front with a clear message.
    if not isinstance(reference_audio_url, str) or not reference_audio_url.strip():
        return None, (
            "API Error: reference_audio_url must be a non-empty string "
            "(data URI, http(s) URL, or local file path)."
        )

    temp_audio_path = None
    owns_temp_file = False
    try:
        lowered = reference_audio_url.lower()

        # Handle different audio input formats
        if lowered.startswith('data:audio'):
            header, encoded = reference_audio_url.split(',', 1)
            audio_data = base64.b64decode(encoded)
            # MP3 data is normally labelled "audio/mpeg", so accept both spellings.
            header_lower = header.lower()
            ext = '.mp3' if ('mp3' in header_lower or 'mpeg' in header_lower) else '.wav'
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
                # Record the path first so ``finally`` cleans up even if the write fails.
                temp_audio_path = temp_file.name
                owns_temp_file = True
                temp_file.write(audio_data)
        elif lowered.startswith(('http://', 'https://')):
            # SECURITY NOTE(review): this fetches a caller-supplied URL from the
            # server (SSRF risk). Restrict allowed hosts if this is exposed publicly.
            import requests

            response = requests.get(reference_audio_url, timeout=30)
            response.raise_for_status()
            # Look at the URL path only, so "?query" / "#fragment" parts and
            # upper-case extensions do not hide a ".mp3" suffix.
            url_path = urlparse(reference_audio_url).path.lower()
            ext = '.mp3' if url_path.endswith('.mp3') else '.wav'
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
                temp_audio_path = temp_file.name
                owns_temp_file = True
                temp_file.write(response.content)
        else:
            # SECURITY NOTE(review): treated as a local path and passed through
            # unchanged; validate it if callers are untrusted.
            temp_audio_path = reference_audio_url

        # Generate audio
        audio_output, status = clone_voice(
            text_to_speak, temp_audio_path, exaggeration, cfg_pace, random_seed, temperature
        )
        return audio_output, status
    except Exception as e:
        print(f"API Error: {e}")
        return None, f"API Error: {str(e)}"
    finally:
        # Remove only files created here, never a caller-supplied path.
        if owns_temp_file and temp_audio_path:
            try:
                os.unlink(temp_audio_path)
            except OSError as cleanup_error:
                # Best effort: a leftover temp file must not mask the real result.
                print(f"Warning: could not remove temporary file '{temp_audio_path}': {cleanup_error}")
# Rest of your Gradio interface code remains the same...
def main():
    """Entry point: print the start-up banner for the Gradio interface.

    The interface construction itself is not present in this part of the file.
    """
    # Your existing Gradio interface code here
    print("Starting Advanced Gradio interface...")


if __name__ == "__main__":
    main()