voice_cloning / app.py
ramimu's picture
Update app.py
d52290c verified
raw
history blame
21.4 kB
import gradio as gr
import os
import traceback # For detailed error logging
import torch
from huggingface_hub import hf_hub_download
import shutil
# Deployment settings: prefer the project's config module and fall back to
# the built-in defaults when it cannot supply all three names.
try:
    from config import (
        LOCAL_MODEL_PATH,
        MODEL_FILES,
        MODEL_REPO_ID,
    )
except ImportError:
    # config.py is missing (or does not define every name): use defaults.
    MODEL_FILES = [
        "s3gen.pt",
        "t3_cfg.pt",
        "ve.pt",
        "tokenizer.json",
    ]
    LOCAL_MODEL_PATH = "./chatterbox_model_files"
    MODEL_REPO_ID = "ramimu/chatterbox-voice-cloning-model"
# Try importing chatterbox with better error handling.
# Sets `chatterbox_available` so the rest of the module can degrade gracefully
# when the library is missing.
try:
    from chatterbox.tts import ChatterboxTTS
    chatterbox_available = True
    print("Chatterbox TTS imported successfully")

    # Log the public API of the installed ChatterboxTTS so the loading code
    # can be checked against it from the startup logs.
    import inspect
    print(f"ChatterboxTTS methods: {[method for method in dir(ChatterboxTTS) if not method.startswith('_')]}")

    for _entry_point in ("__init__", "from_local", "from_pretrained"):
        if not hasattr(ChatterboxTTS, _entry_point):
            continue
        try:
            sig = inspect.signature(getattr(ChatterboxTTS, _entry_point))
        except Exception as sig_error:
            # Diagnostics only: a signature that cannot be inspected must not
            # stop the app, but it is reported instead of silently dropped.
            print(f"Could not inspect ChatterboxTTS.{_entry_point}: {sig_error}")
        else:
            print(f"ChatterboxTTS.{_entry_point} signature: {sig}")
except ImportError as e:
    print(f"Failed to import ChatterboxTTS: {e}")
    print("Trying alternative import...")
    try:
        import chatterbox
        from chatterbox import ChatterboxTTS
        chatterbox_available = True
        print("Chatterbox TTS imported with alternative method")
    except ImportError as e2:
        print(f"Alternative import also failed: {e2}")
        chatterbox_available = False
# --- Global Model Variable ---
# Module-wide handle to the loaded TTS model. It starts as None and is only
# replaced when model loading succeeds; clone_voice() returns an error message
# while it is still None.
model = None
def download_model_files():
    """Download any missing model files from the Hugging Face Hub.

    Every name in ``MODEL_FILES`` is looked up under ``LOCAL_MODEL_PATH``
    (the directory is created if needed). A missing file is fetched from
    ``MODEL_REPO_ID`` through the ``./cache`` hub cache and copied into
    place; files that already exist locally are left untouched.

    Raises:
        Exception: whatever the download or the copy raised for the first
            file that could not be provided. Remaining files are not tried.
    """
    print(f"Checking for model files in {LOCAL_MODEL_PATH}...")

    # Create model directory if it doesn't exist
    os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)

    for filename in MODEL_FILES:
        local_path = os.path.join(LOCAL_MODEL_PATH, filename)
        if os.path.exists(local_path):
            print(f"✓ {filename} already exists locally")
            continue

        print(f"Downloading {filename} from {MODEL_REPO_ID}...")
        try:
            downloaded_path = hf_hub_download(
                repo_id=MODEL_REPO_ID,
                filename=filename,
                cache_dir="./cache",
                force_download=False,  # Use cache if available
            )
            # Copy under a temporary name and rename afterwards, so that an
            # interrupted copy never leaves a truncated file that the
            # existence check above would accept on the next start.
            partial_path = local_path + ".part"
            shutil.copy2(downloaded_path, partial_path)
            os.replace(partial_path, local_path)
            print(f"✓ Downloaded and copied {filename}")
        except Exception as e:
            print(f"✗ Failed to download {filename}: {e}")
            raise

    print("All model files are ready!")
# --- Load the Model ---
def _read_checkpoint_components(ckpt_dir):
    """Read the raw checkpoint objects and the tokenizer stored in *ckpt_dir*.

    Returns:
        tuple: ``(t3_cfg, s3gen, ve, tokenizer)``. The first three are
        whatever ``torch.load`` returns for ``t3_cfg.pt``, ``s3gen.pt`` and
        ``ve.pt``. ``tokenizer`` is an ``EnTokenizer`` when one can be built
        from ``tokenizer.json``, otherwise the parsed JSON data.
    """
    import json
    import pathlib

    model_path = pathlib.Path(ckpt_dir)
    s3gen_path = model_path / "s3gen.pt"
    ve_path = model_path / "ve.pt"
    tokenizer_path = model_path / "tokenizer.json"
    t3_cfg_path = model_path / "t3_cfg.pt"

    # NOTE(security): depending on the torch version, torch.load may unpickle
    # arbitrary objects. Only point this at checkpoints from a trusted repo.
    print(f" Loading s3gen from: {s3gen_path}")
    s3gen = torch.load(s3gen_path, map_location=torch.device('cpu'))
    print(f" Loading ve from: {ve_path}")
    ve = torch.load(ve_path, map_location=torch.device('cpu'))
    print(f" Loading t3_cfg from: {t3_cfg_path}")
    t3_cfg = torch.load(t3_cfg_path, map_location=torch.device('cpu'))

    print(f" Loading tokenizer from: {tokenizer_path}")
    # Explicit encoding: do not depend on the platform's default encoding.
    with open(tokenizer_path, 'r', encoding='utf-8') as f:
        tokenizer_data = json.load(f)

    # The tokenizer might need to be a proper object; try the library's own
    # tokenizer class and keep the raw JSON data as a fallback.
    try:
        from chatterbox.models.tokenizers.tokenizer import EnTokenizer
        tokenizer = EnTokenizer.from_dict(tokenizer_data)
        print(" Created EnTokenizer from JSON data")
    except Exception as tok_error:
        print(f" Could not create EnTokenizer: {tok_error}")
        tokenizer = tokenizer_data  # Use raw data as fallback

    return t3_cfg, s3gen, ve, tokenizer


def _load_model_manually(ckpt_dir, device):
    """Build a ChatterboxTTS instance by calling its constructor directly.

    Constructor signature (from API inspection):
    ``(self, t3, s3gen, ve, tokenizer, device, conds=None)``.

    Raises:
        Exception: the error of the keyword-argument attempt (or of reading
            the checkpoint files) when no attempt succeeds.
    """
    print(f"Manual loading with correct constructor signature...")
    components = None
    try:
        components = _read_checkpoint_components(ckpt_dir)
        t3_cfg, s3gen, ve, tokenizer = components
        print(" Creating ChatterboxTTS instance with correct signature...")
        loaded = ChatterboxTTS(
            t3=t3_cfg,
            s3gen=s3gen,
            ve=ve,
            tokenizer=tokenizer,
            device=device
        )
    except Exception as manual_error:
        print(f"Manual loading failed: {manual_error}")
        print(f"Detailed error: {str(manual_error)}")
        if components is None:
            # The checkpoint files could not be read, so there is nothing to
            # retry the constructor with.
            raise
        # Last resort: try with a different parameter order
        try:
            print("Trying alternative parameter order...")
            loaded = ChatterboxTTS(
                s3gen, ve, tokenizer, t3_cfg, device
            )
        except Exception as alt_error:
            print(f"Alternative parameter order failed: {alt_error}")
            raise manual_error
        print("Chatterbox model loaded with alternative parameter order.")
        return loaded

    print("Chatterbox model loaded successfully with manual constructor.")
    return loaded


def _load_chatterbox_model(ckpt_dir, device):
    """Load ChatterboxTTS, trying each known loading strategy in turn.

    Order: ``from_local(ckpt_dir, device)``, then ``from_pretrained(device)``,
    then manual construction from the files in *ckpt_dir*. The exception of
    the last strategy propagates when all of them fail.
    """
    # Method 1: Use from_local with signature (ckpt_dir, device)
    try:
        loaded = ChatterboxTTS.from_local(ckpt_dir, device)
    except Exception as local_error:
        print(f"from_local attempt failed: {local_error}")
    else:
        print("Chatterbox model loaded successfully using from_local method.")
        return loaded

    # Method 2: Use from_pretrained with device only
    try:
        loaded = ChatterboxTTS.from_pretrained(device)
    except Exception as pretrained_error:
        print(f"from_pretrained failed: {pretrained_error}")
    else:
        print("Chatterbox model loaded successfully with from_pretrained.")
        return loaded

    # Method 3: Manual loading through the constructor
    return _load_model_manually(ckpt_dir, device)


if chatterbox_available:
    print("Downloading model files from Hugging Face Hub...")
    try:
        download_model_files()
    except Exception as e:
        # Keep going: loading is still attempted and reports its own error.
        print(f"ERROR: Failed to download model files: {e}")
        print("Model loading will fail without these files.")

    print(f"Attempting to load Chatterbox model from local directory: {LOCAL_MODEL_PATH}")
    if not os.path.exists(LOCAL_MODEL_PATH):
        print(f"ERROR: Local model directory not found at {LOCAL_MODEL_PATH}")
        print("Please ensure the model files were downloaded successfully.")
    else:
        print(f"Contents of {LOCAL_MODEL_PATH}: {os.listdir(LOCAL_MODEL_PATH)}")
        try:
            # Set device to CUDA if available, otherwise CPU
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device}")
            model = _load_chatterbox_model(LOCAL_MODEL_PATH, device)
        except Exception as e:
            print(f"ERROR: Failed to load Chatterbox model from local directory: {e}")
            print("Detailed error trace:")
            traceback.print_exc()  # Prints the full traceback to the logs
            model = None  # Ensure model is None if loading fails
else:
    print("ERROR: Chatterbox TTS library not available")
def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """Synthesize *text_to_speak* in the voice of a reference recording.

    Args:
        text_to_speak: Text to synthesize; must contain non-blank characters.
        reference_audio_path: Path of the reference audio file.
        exaggeration: Forwarded to ``model.generate`` as ``exaggeration``.
        cfg_pace: Forwarded to ``model.generate`` as ``cfg_weight``.
        random_seed: Positive values seed torch's RNG before generation;
            ``0`` or ``None`` leaves the RNG untouched.
        temperature: Forwarded to ``model.generate`` as ``temperature``.

    Returns:
        tuple: ``(audio, status)``. ``audio`` is ``None`` on any error, the
        returned string when the model yields a string, and otherwise
        ``(sample_rate, samples)``. ``status`` is a human-readable message.
        Errors are reported through ``status``; nothing is raised.
    """
    if not chatterbox_available:
        return None, "Error: Chatterbox TTS library not available. Please check installation."
    if model is None:
        return None, "Error: Model not loaded. Please check the logs for details."
    if not text_to_speak or text_to_speak.strip() == "":
        return None, "Error: Please enter some text to speak."
    # Reject an empty path as well as None: neither names an uploaded file.
    if not reference_audio_path:
        return None, "Error: Please upload a reference audio file (.wav or .mp3)."

    try:
        print(f"Received request:")
        print(f" Text: '{text_to_speak}'")
        print(f" Audio: '{reference_audio_path}'")
        print(f" Exaggeration: {exaggeration}")
        print(f" CFG/Pace: {cfg_pace}")
        print(f" Random Seed: {random_seed}")
        print(f" Temperature: {temperature}")

        # A cleared seed field can arrive as None and API callers may send a
        # float; normalise to an int, with 0 meaning "do not seed".
        seed = int(random_seed) if random_seed else 0
        if seed > 0:
            import torch
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)

        output_wav_data = model.generate(
            text=text_to_speak,
            audio_prompt_path=reference_audio_path,
            exaggeration=exaggeration,
            cfg_weight=cfg_pace,
            temperature=temperature
        )

        # Use the model's 'sr' attribute, falling back to 24000 if absent.
        sample_rate = getattr(model, "sr", 24000)
        print(f"Audio generated successfully. Output data type: {type(output_wav_data)}, Sample rate: {sample_rate}")

        # A string result is passed through unchanged.
        if isinstance(output_wav_data, str):
            return output_wav_data, "Success: Audio generated successfully!"

        # Tensor-like results are converted to a numpy array on the CPU;
        # detach() first so a tensor that tracks gradients converts too.
        if hasattr(output_wav_data, 'cpu'):
            output_wav_data = output_wav_data.detach().cpu().numpy()
        # Drop singleton dimensions so mono audio becomes a 1D array.
        if output_wav_data.ndim > 1:
            output_wav_data = output_wav_data.squeeze()
        return (sample_rate, output_wav_data), "Success: Audio generated successfully!"
    except Exception as e:
        print(f"ERROR: Failed during audio generation: {e}")
        print("Detailed error trace for audio generation:")
        traceback.print_exc()  # Prints the full traceback
        return None, f"Error during audio generation: {str(e)}. Check logs for more details."
# --- API Endpoint Function ---
def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """API variant of clone_voice that also accepts remote or inline audio.

    Args:
        text_to_speak: Text to synthesize.
        reference_audio_url: One of
            * a ``data:audio...;base64,<payload>`` URI,
            * an ``http://`` or ``https://`` URL to download,
            * anything else, which is passed on as a local file path.
        exaggeration, cfg_pace, random_seed, temperature: Forwarded
            unchanged to ``clone_voice``.

    Returns:
        tuple: The ``(audio, status)`` pair from ``clone_voice``, or
        ``(None, "API Error: ...")`` when preparing the audio fails.
        Nothing is raised.
    """
    import base64
    import os
    import tempfile
    from urllib.parse import urlparse

    if not isinstance(reference_audio_url, str) or not reference_audio_url.strip():
        return None, "API Error: reference audio must be a data URI, an http(s) URL or a file path."

    temp_audio_path = None  # Set only when this call creates a temp file.
    try:
        if reference_audio_url.startswith('data:audio'):
            # Inline base64 audio: "<header>,<payload>"
            header, encoded = reference_audio_url.split(',', 1)
            audio_bytes = base64.b64decode(encoded)
            mime = header.lower()
            # "audio/mpeg" is the usual MIME type for MP3 data.
            ext = '.mp3' if ('mp3' in mime or 'mpeg' in mime) else '.wav'
        elif reference_audio_url.startswith(('http://', 'https://')):
            # Imported lazily so non-URL inputs do not need `requests`.
            import requests
            # A timeout keeps a stalled server from blocking the worker.
            response = requests.get(reference_audio_url, timeout=30)
            response.raise_for_status()
            audio_bytes = response.content
            # Look at the URL path only, so query strings are ignored.
            url_path = urlparse(reference_audio_url).path.lower()
            ext = '.mp3' if url_path.endswith('.mp3') else '.wav'
        else:
            # NOTE(security): the value is used as a server-side file path.
            # Restrict or remove this branch if callers are untrusted.
            audio_bytes = None

        if audio_bytes is None:
            audio_path = reference_audio_url
        else:
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
                # Record the name first so cleanup runs if the write fails.
                temp_audio_path = temp_file.name
                temp_file.write(audio_bytes)
            audio_path = temp_audio_path

        # Call the main clone_voice function
        return clone_voice(text_to_speak, audio_path, exaggeration, cfg_pace, random_seed, temperature)
    except Exception as e:
        return None, f"API Error: {str(e)}"
    finally:
        # Remove the temp file we created, on success and failure alike.
        if temp_audio_path is not None:
            try:
                os.unlink(temp_audio_path)
            except OSError:
                pass  # Best effort: the file may already be gone.
# --- Define Gradio Interface ---
# Layout: a wide left column with the inputs and the generate button, a narrow
# right column with the outputs, then collapsible API notes and examples.
with gr.Blocks(title="Advanced Chatterbox Voice Cloning", theme=gr.themes.Soft()) as iface:
    gr.Markdown("# 🎙️ Advanced Chatterbox Voice Cloning")
    gr.Markdown("Clone any voice using advanced AI technology with fine-tuned controls.")

    with gr.Row():
        with gr.Column(scale=2):
            # Main inputs
            text_input = gr.Textbox(
                label="Text to Speak",
                placeholder="Enter the text you want the cloned voice to say...",
                lines=3
            )
            audio_input = gr.Audio(
                type="filepath",
                label="Reference Audio (Upload a short .wav or .mp3 clip)",
                sources=["upload", "microphone"]
            )

            # Advanced controls in an accordion
            with gr.Accordion("🔧 Advanced Settings", open=False):
                with gr.Row():
                    exaggeration = gr.Slider(
                        minimum=0.25,
                        maximum=1.0,
                        value=0.6,
                        step=0.05,
                        label="Exaggeration",
                        info="Controls voice characteristic emphasis (0.5 = neutral, higher = more exaggerated)"
                    )
                    cfg_pace = gr.Slider(
                        minimum=0.2,
                        maximum=1.0,
                        value=0.3,
                        step=0.05,
                        label="CFG/Pace",
                        info="Classifier-free guidance weight (affects generation quality and pace)"
                    )
                with gr.Row():
                    random_seed = gr.Number(
                        value=0,
                        label="Random Seed",
                        info="Set to 0 for random results, or use a specific number for reproducible outputs",
                        precision=0
                    )
                    temperature = gr.Slider(
                        minimum=0.05,
                        maximum=2.0,
                        value=0.6,
                        step=0.05,
                        label="Temperature",
                        info="Controls randomness in generation (lower = more consistent, higher = more varied)"
                    )

            # Generate button
            generate_btn = gr.Button("🎵 Generate Voice Clone", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Outputs
            audio_output = gr.Audio(
                label="Generated Audio",
                type="numpy",
                interactive=False
            )
            status_output = gr.Textbox(
                label="Status",
                interactive=False,
                lines=2
            )

    # The examples and the button drive the same function with the same
    # components, so the wiring lists are defined once.
    clone_inputs = [text_input, audio_input, exaggeration, cfg_pace, random_seed, temperature]
    clone_outputs = [audio_output, status_output]

    # API Information
    with gr.Accordion("🔌 API Usage", open=False):
        gr.Markdown("""
### Using this as an API endpoint

You can use this Hugging Face Space as an API endpoint in your applications:

**Endpoint URL:** `https://your-username-voice-cloning.hf.space/api/predict`

**Example Python code:**

```python
import requests
import base64

# Encode your audio file
with open("reference_audio.wav", "rb") as f:
    audio_data = base64.b64encode(f.read()).decode()
audio_url = f"data:audio/wav;base64,{audio_data}"

# API request
response = requests.post(
    "https://your-username-voice-cloning.hf.space/api/predict",
    json={
        "data": [
            "Hello, this is my cloned voice!",  # text
            audio_url,  # reference audio (base64 or URL)
            0.6,  # exaggeration
            0.3,  # cfg_pace
            0,  # random_seed
            0.6  # temperature
        ]
    }
)
```

**Parameters:**

- `text_to_speak`: Text to synthesize
- `reference_audio`: Base64 encoded audio or URL
- `exaggeration`: Voice emphasis (0.25-1.0, default: 0.6)
- `cfg_pace`: Generation guidance (0.2-1.0, default: 0.3)
- `random_seed`: Reproducibility seed (0 for random, default: 0)
- `temperature`: Generation randomness (0.05-2.0, default: 0.6)
""")

    # Examples (no reference audio is bundled, hence None for that input)
    with gr.Accordion("📝 Examples", open=False):
        gr.Examples(
            examples=[
                ["Hello, this is a test of the voice cloning system.", None, 0.5, 0.5, 0, 0.8],
                ["The quick brown fox jumps over the lazy dog.", None, 0.7, 0.3, 42, 0.6],
                ["Welcome to our AI voice cloning service. We hope you enjoy the experience!", None, 0.4, 0.7, 123, 1.0]
            ],
            inputs=clone_inputs,
            outputs=clone_outputs,
            fn=clone_voice,
            cache_examples=False
        )

    # Connect the generate button
    generate_btn.click(
        fn=clone_voice,
        inputs=clone_inputs,
        outputs=clone_outputs,
        api_name="clone_voice"  # This enables API access
    )
# --- Launch the Gradio App ---
def main():
    """Start the Gradio server for the voice-cloning interface."""
    print("Starting Advanced Gradio interface...")

    # Launch settings, chosen for API access and to avoid manifest issues.
    launch_options = {
        "server_name": "0.0.0.0",  # Allow external connections
        "server_port": 7860,  # Explicit port
        "show_error": True,  # Show detailed errors
        "quiet": False,  # Show startup logs
        "favicon_path": None,  # Disable favicon to avoid 404
        "share": False,  # Set to True if you want a public link
        "auth": None,  # Add authentication if needed: ("username", "password")
        "app_kwargs": {
            "docs_url": "/docs",  # Enable API docs at /docs
            "redoc_url": "/redoc",  # Enable alternative docs at /redoc
        },
    }
    iface.launch(**launch_options)


if __name__ == "__main__":
    main()