# Hugging Face Spaces page banner captured with the file (not part of the program): Space status "Sleeping".
import gradio as gr | |
import os | |
import traceback # For detailed error logging | |
import torch | |
from huggingface_hub import hf_hub_download | |
import shutil | |
# Deployment settings: prefer the values exported by config.py and fall back
# to the built-in defaults below when that module cannot be imported.
try:
    from config import LOCAL_MODEL_PATH, MODEL_FILES, MODEL_REPO_ID
except ImportError:
    # No usable config.py: default repository, target directory and file list.
    MODEL_FILES = ["s3gen.pt", "t3_cfg.pt", "ve.pt", "tokenizer.json"]
    LOCAL_MODEL_PATH = "./chatterbox_model_files"
    MODEL_REPO_ID = "ramimu/chatterbox-voice-cloning-model"
# Import ChatterboxTTS and record the outcome in `chatterbox_available`.
# The primary import path is tried first; if it raises ImportError the
# package-level import is attempted before giving up.
try:
    from chatterbox.tts import ChatterboxTTS
except ImportError as e:
    print(f"Failed to import ChatterboxTTS: {e}")
    print("Trying alternative import...")
    try:
        import chatterbox
        from chatterbox import ChatterboxTTS
    except ImportError as e2:
        print(f"Alternative import also failed: {e2}")
        chatterbox_available = False
    else:
        chatterbox_available = True
        print("Chatterbox TTS imported with alternative method")
else:
    chatterbox_available = True
    print("Chatterbox TTS imported successfully")

    # Log the public API and the signatures of the constructor / factory
    # methods so the Space logs show which loading paths this version offers.
    import inspect

    print(f"ChatterboxTTS methods: {[method for method in dir(ChatterboxTTS) if not method.startswith('_')]}")
    for attr_name in ("__init__", "from_local", "from_pretrained"):
        target = getattr(ChatterboxTTS, attr_name, None)
        if target is None:
            continue
        try:
            sig = inspect.signature(target)
        except (TypeError, ValueError):
            # inspect.signature raises these when the object cannot be
            # introspected; the log line is purely informational, so skip it.
            continue
        print(f"ChatterboxTTS.{attr_name} signature: {sig}")
# --- Global Model Variable ---
# Holds the loaded ChatterboxTTS instance. It stays None until loading
# succeeds (and is reset to None when loading fails); clone_voice() returns
# an error status while it is None.
model = None
def download_model_files():
    """Download any missing model files from the Hugging Face Hub.

    Every name in ``MODEL_FILES`` that is not already present in
    ``LOCAL_MODEL_PATH`` is fetched from ``MODEL_REPO_ID`` (using ``./cache``
    as the hub cache directory) and copied into ``LOCAL_MODEL_PATH``, which is
    created when it does not exist yet.

    Raises:
        Exception: whatever the download or copy raised for the first file
            that failed; the failure is logged before being re-raised.
    """
    print(f"Checking for model files in {LOCAL_MODEL_PATH}...")

    os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)

    for filename in MODEL_FILES:
        local_path = os.path.join(LOCAL_MODEL_PATH, filename)
        if os.path.exists(local_path):
            print(f"β {filename} already exists locally")
            continue

        print(f"Downloading {filename} from {MODEL_REPO_ID}...")
        # Stage the copy under a temporary name: an interrupted copy must not
        # leave a truncated file that the existence check above would accept
        # on the next start.
        staging_path = local_path + ".part"
        try:
            downloaded_path = hf_hub_download(
                repo_id=MODEL_REPO_ID,
                filename=filename,
                cache_dir="./cache",
                force_download=False,  # Use cache if available
            )
            shutil.copy2(downloaded_path, staging_path)
            os.replace(staging_path, local_path)
            print(f"β Downloaded and copied {filename}")
        except Exception as e:
            print(f"β Failed to download {filename}: {e}")
            try:
                os.remove(staging_path)
            except OSError:
                pass  # nothing was staged, or it is already gone
            raise

    print("All model files are ready!")
# --- Load the Model ---
def _load_manual_components(model_dir):
    """Read the checkpoint objects and tokenizer used for manual construction.

    Args:
        model_dir: Directory holding s3gen.pt, ve.pt, t3_cfg.pt and
            tokenizer.json.

    Returns:
        Tuple ``(t3_cfg, s3gen, ve, tokenizer)``. ``tokenizer`` is an
        ``EnTokenizer`` when one could be created from the JSON data,
        otherwise the parsed JSON itself.
    """
    import json
    import pathlib

    model_path = pathlib.Path(model_dir)
    print("Manual loading with correct constructor signature...")

    s3gen_path = model_path / "s3gen.pt"
    ve_path = model_path / "ve.pt"
    tokenizer_path = model_path / "tokenizer.json"
    t3_cfg_path = model_path / "t3_cfg.pt"

    # NOTE(security): torch.load unpickles the checkpoint files, which were
    # downloaded from MODEL_REPO_ID. Only point this at a trusted repository.
    print(f" Loading s3gen from: {s3gen_path}")
    s3gen = torch.load(s3gen_path, map_location=torch.device('cpu'))
    print(f" Loading ve from: {ve_path}")
    ve = torch.load(ve_path, map_location=torch.device('cpu'))
    print(f" Loading t3_cfg from: {t3_cfg_path}")
    t3_cfg = torch.load(t3_cfg_path, map_location=torch.device('cpu'))

    print(f" Loading tokenizer from: {tokenizer_path}")
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(tokenizer_path, 'r', encoding='utf-8') as f:
        tokenizer_data = json.load(f)

    # Prefer a real tokenizer object; fall back to the raw JSON data.
    try:
        from chatterbox.models.tokenizers.tokenizer import EnTokenizer
        tokenizer = EnTokenizer.from_dict(tokenizer_data)
        print(" Created EnTokenizer from JSON data")
    except Exception as tok_error:
        print(f" Could not create EnTokenizer: {tok_error}")
        tokenizer = tokenizer_data  # Use raw data as fallback

    return t3_cfg, s3gen, ve, tokenizer


def _build_model_manually(model_dir, device):
    """Construct ChatterboxTTS directly from the files in ``model_dir``.

    The keyword form ``(t3, s3gen, ve, tokenizer, device)`` is tried first;
    if it raises, the positional order ``(s3gen, ve, tokenizer, t3_cfg,
    device)`` is tried as a last resort.

    Raises:
        Exception: the component-loading error, or the error of the keyword
            constructor call when both constructor attempts fail.
    """
    try:
        t3_cfg, s3gen, ve, tokenizer = _load_manual_components(model_dir)
    except Exception as load_error:
        print(f"Manual loading failed: {load_error}")
        print(f"Detailed error: {str(load_error)}")
        # Without the components there is nothing to retry the constructor with.
        raise

    try:
        print(" Creating ChatterboxTTS instance with correct signature...")
        # Constructor signature: (self, t3, s3gen, ve, tokenizer, device, conds=None)
        built = ChatterboxTTS(
            t3=t3_cfg,
            s3gen=s3gen,
            ve=ve,
            tokenizer=tokenizer,
            device=device,
        )
    except Exception as e3:
        print(f"Manual loading failed: {e3}")
        print(f"Detailed error: {str(e3)}")
        try:
            print("Trying alternative parameter order...")
            built = ChatterboxTTS(s3gen, ve, tokenizer, t3_cfg, device)
        except Exception as e4:
            print(f"Alternative parameter order failed: {e4}")
            raise e3
        print("Chatterbox model loaded with alternative parameter order.")
        return built

    print("Chatterbox model loaded successfully with manual constructor.")
    return built


def _load_chatterbox_model(model_dir, device):
    """Return a ChatterboxTTS instance, trying each loading strategy in turn.

    Order: ``from_local(model_dir, device)``, then ``from_pretrained(device)``,
    then manual construction from the individual files. Failures of the first
    two strategies are logged and the next one is tried; a failure of the last
    strategy propagates to the caller.
    """
    # Method 1: from_local with signature (ckpt_dir, device)
    try:
        loaded = ChatterboxTTS.from_local(model_dir, device)
    except Exception as e1:
        print(f"from_local attempt failed: {e1}")
    else:
        print("Chatterbox model loaded successfully using from_local method.")
        return loaded

    # Method 2: from_pretrained with device only
    try:
        loaded = ChatterboxTTS.from_pretrained(device)
    except Exception as e2:
        print(f"from_pretrained failed: {e2}")
    else:
        print("Chatterbox model loaded successfully with from_pretrained.")
        return loaded

    # Method 3: manual construction from the downloaded files
    return _build_model_manually(model_dir, device)


if chatterbox_available:
    print("Downloading model files from Hugging Face Hub...")
    try:
        download_model_files()
    except Exception as e:
        # Keep going: files from an earlier run may still be usable.
        print(f"ERROR: Failed to download model files: {e}")
        print("Model loading will fail without these files.")

    print(f"Attempting to load Chatterbox model from local directory: {LOCAL_MODEL_PATH}")
    if not os.path.exists(LOCAL_MODEL_PATH):
        print(f"ERROR: Local model directory not found at {LOCAL_MODEL_PATH}")
        print("Please ensure the model files were downloaded successfully.")
    else:
        print(f"Contents of {LOCAL_MODEL_PATH}: {os.listdir(LOCAL_MODEL_PATH)}")
        try:
            # Use CUDA when available, otherwise CPU.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device}")
            model = _load_chatterbox_model(LOCAL_MODEL_PATH, device)
        except Exception as e:
            print(f"ERROR: Failed to load Chatterbox model from local directory: {e}")
            print("Detailed error trace:")
            traceback.print_exc()  # Prints the full traceback to the Hugging Face Space logs
            model = None  # Ensure model is None if loading fails
else:
    print("ERROR: Chatterbox TTS library not available")
def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """Synthesize ``text_to_speak`` in the voice of a reference recording.

    Args:
        text_to_speak: Text to synthesize; empty or whitespace-only text is rejected.
        reference_audio_path: Path of the reference audio file; ``None`` is rejected.
        exaggeration: Forwarded to ``model.generate`` as ``exaggeration``.
        cfg_pace: Forwarded to ``model.generate`` as ``cfg_weight``.
        random_seed: Values greater than 0 seed torch's RNG (and CUDA's when
            available) before generation; 0, negative values and ``None``
            leave the RNG unseeded.
        temperature: Forwarded to ``model.generate`` as ``temperature``.

    Returns:
        ``(audio, status)``. ``audio`` is ``None`` on any failure, the string
        returned by the model when it returns one (treated as a file path),
        and otherwise ``(sample_rate, samples)``. ``status`` is a
        human-readable success or error message.
    """
    if not chatterbox_available:
        return None, "Error: Chatterbox TTS library not available. Please check installation."
    if model is None:
        return None, "Error: Model not loaded. Please check the logs for details."
    if not text_to_speak or not text_to_speak.strip():
        return None, "Error: Please enter some text to speak."
    if reference_audio_path is None:
        return None, "Error: Please upload a reference audio file (.wav or .mp3)."

    try:
        print("Received request:")
        print(f" Text: '{text_to_speak}'")
        print(f" Audio: '{reference_audio_path}'")
        print(f" Exaggeration: {exaggeration}")
        print(f" CFG/Pace: {cfg_pace}")
        print(f" Random Seed: {random_seed}")
        print(f" Temperature: {temperature}")

        # random_seed may be None (e.g. an emptied number field); comparing
        # None with 0 would raise, so treat it like "no seed requested".
        if random_seed is not None and random_seed > 0:
            import torch
            seed = int(random_seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)

        output_wav_data = model.generate(
            text=text_to_speak,
            audio_prompt_path=reference_audio_path,
            exaggeration=exaggeration,
            cfg_weight=cfg_pace,
            temperature=temperature,
        )

        # Use the model's `sr` attribute when it has one, else 24000.
        sample_rate = getattr(model, "sr", 24000)
        print(f"Audio generated successfully. Output data type: {type(output_wav_data)}, Sample rate: {sample_rate}")

        if isinstance(output_wav_data, str):
            # A string result is passed through unchanged as a file path.
            return output_wav_data, "Success: Audio generated successfully!"

        if hasattr(output_wav_data, 'cpu'):
            # Tensor-like result: detach first so tensors that track gradients
            # can be converted to numpy as well.
            output_wav_data = output_wav_data.detach().cpu().numpy()
        if output_wav_data.ndim > 1:
            # Drop singleton dimensions so a (1, N) result becomes 1-D.
            output_wav_data = output_wav_data.squeeze()
        return (sample_rate, output_wav_data), "Success: Audio generated successfully!"
    except Exception as e:
        print(f"ERROR: Failed during audio generation: {e}")
        print("Detailed error trace for audio generation:")
        traceback.print_exc()  # Prints the full traceback
        return None, f"Error during audio generation: {str(e)}. Check logs for more details."
# --- API Endpoint Function ---
def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """API variant of ``clone_voice`` taking the reference audio as a string.

    ``reference_audio_url`` may be:

    * a ``data:audio...`` URI with base64 payload, which is decoded into a
      temporary file,
    * an ``http://`` / ``https://`` URL, which is downloaded into a temporary
      file,
    * anything else, which is passed on unchanged as a local file path.

    Temporary files created here are removed before returning; a path
    supplied by the caller is never deleted.

    Args:
        text_to_speak: Text to synthesize.
        reference_audio_url: Reference audio in one of the forms above.
        exaggeration, cfg_pace, random_seed, temperature: Passed through to
            ``clone_voice`` unchanged.

    Returns:
        The ``(audio, status)`` pair from ``clone_voice``, or
        ``(None, "API Error: ...")`` when preparing the audio fails.
    """
    import base64
    import tempfile

    def _write_temp(payload, suffix):
        """Store ``payload`` in a new temporary file and return its path."""
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file.write(payload)
            return temp_file.name

    if not isinstance(reference_audio_url, str) or not reference_audio_url.strip():
        return None, "API Error: reference audio must be a URL, a base64 data URI, or a file path."

    # NOTE(security): the URL and local-path branches act on caller-supplied
    # input (server-side fetch / local file read). Restrict them before
    # exposing this function to untrusted clients.
    created_path = None  # set only for files this function created itself
    try:
        if reference_audio_url.startswith('data:audio'):
            header, encoded = reference_audio_url.split(',', 1)
            audio_data = base64.b64decode(encoded)
            # audio/mpeg is the registered MIME type for MP3 data.
            ext = '.mp3' if ('mp3' in header or 'mpeg' in header) else '.wav'
            created_path = _write_temp(audio_data, ext)
            audio_path = created_path
        elif reference_audio_url.startswith(('http://', 'https://')):
            # Imported lazily so the other input forms work without requests.
            import requests
            from urllib.parse import urlparse

            # A timeout keeps an unresponsive host from blocking the worker.
            response = requests.get(reference_audio_url, timeout=30)
            response.raise_for_status()
            # Look at the URL path only, so a query string does not hide the
            # file extension.
            url_path = urlparse(reference_audio_url).path.lower()
            ext = '.mp3' if url_path.endswith('.mp3') else '.wav'
            created_path = _write_temp(response.content, ext)
            audio_path = created_path
        else:
            # Assume it's a local file path
            audio_path = reference_audio_url

        return clone_voice(text_to_speak, audio_path, exaggeration, cfg_pace, random_seed, temperature)
    except Exception as e:
        return None, f"API Error: {str(e)}"
    finally:
        if created_path is not None:
            try:
                os.unlink(created_path)
            except OSError:
                pass  # best-effort cleanup; the file may already be gone
# --- Define Gradio Interface ---
# NOTE(review): the nesting below is reconstructed; the source's indentation
# was lost. In particular, confirm that the generate button belongs to the
# left column and that the API / Examples accordions sit at page level.

# Markdown shown inside the "API Usage" accordion.
_API_USAGE_MARKDOWN = """
### Using this as an API endpoint
You can use this Hugging Face Space as an API endpoint in your applications:
**Endpoint URL:** `https://your-username-voice-cloning.hf.space/api/predict`
**Example Python code:**
```python
import requests
import base64
# Encode your audio file
with open("reference_audio.wav", "rb") as f:
audio_data = base64.b64encode(f.read()).decode()
audio_url = f"data:audio/wav;base64,{audio_data}"
# API request
response = requests.post(
"https://your-username-voice-cloning.hf.space/api/predict",
json={
"data": [
"Hello, this is my cloned voice!", # text
audio_url, # reference audio (base64 or URL)
0.6, # exaggeration
0.3, # cfg_pace
0, # random_seed
0.6 # temperature
]
}
)
```
**Parameters:**
- `text_to_speak`: Text to synthesize
- `reference_audio`: Base64 encoded audio or URL
- `exaggeration`: Voice emphasis (0.25-1.0, default: 0.6)
- `cfg_pace`: Generation guidance (0.2-1.0, default: 0.3)
- `random_seed`: Reproducibility seed (0 for random, default: 0)
- `temperature`: Generation randomness (0.05-2.0, default: 0.6)
"""

# Rows for the "Examples" accordion, in the order:
# text, reference audio, exaggeration, cfg_pace, random_seed, temperature.
_EXAMPLE_ROWS = [
    ["Hello, this is a test of the voice cloning system.", None, 0.5, 0.5, 0, 0.8],
    ["The quick brown fox jumps over the lazy dog.", None, 0.7, 0.3, 42, 0.6],
    ["Welcome to our AI voice cloning service. We hope you enjoy the experience!", None, 0.4, 0.7, 123, 1.0],
]

with gr.Blocks(title="Advanced Chatterbox Voice Cloning", theme=gr.themes.Soft()) as iface:
    gr.Markdown("# ποΈ Advanced Chatterbox Voice Cloning")
    gr.Markdown("Clone any voice using advanced AI technology with fine-tuned controls.")

    with gr.Row():
        # Left column: text, reference audio, tuning controls, trigger button.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                lines=3,
                label="Text to Speak",
                placeholder="Enter the text you want the cloned voice to say...",
            )
            audio_input = gr.Audio(
                label="Reference Audio (Upload a short .wav or .mp3 clip)",
                sources=["upload", "microphone"],
                type="filepath",
            )

            # Tuning controls, collapsed by default.
            with gr.Accordion("π§ Advanced Settings", open=False):
                with gr.Row():
                    exaggeration = gr.Slider(
                        label="Exaggeration",
                        info="Controls voice characteristic emphasis (0.5 = neutral, higher = more exaggerated)",
                        minimum=0.25,
                        maximum=1.0,
                        step=0.05,
                        value=0.6,
                    )
                    cfg_pace = gr.Slider(
                        label="CFG/Pace",
                        info="Classifier-free guidance weight (affects generation quality and pace)",
                        minimum=0.2,
                        maximum=1.0,
                        step=0.05,
                        value=0.3,
                    )
                with gr.Row():
                    random_seed = gr.Number(
                        label="Random Seed",
                        info="Set to 0 for random results, or use a specific number for reproducible outputs",
                        value=0,
                        precision=0,
                    )
                    temperature = gr.Slider(
                        label="Temperature",
                        info="Controls randomness in generation (lower = more consistent, higher = more varied)",
                        minimum=0.05,
                        maximum=2.0,
                        step=0.05,
                        value=0.6,
                    )

            generate_btn = gr.Button("π΅ Generate Voice Clone", variant="primary", size="lg")

        # Right column: generated audio and the status message.
        with gr.Column(scale=1):
            audio_output = gr.Audio(
                type="numpy",
                label="Generated Audio",
                interactive=False,
            )
            status_output = gr.Textbox(
                lines=2,
                label="Status",
                interactive=False,
            )

    # The same component lists feed both the examples and the button handler.
    _clone_inputs = [text_input, audio_input, exaggeration, cfg_pace, random_seed, temperature]
    _clone_outputs = [audio_output, status_output]

    with gr.Accordion("π API Usage", open=False):
        gr.Markdown(_API_USAGE_MARKDOWN)

    with gr.Accordion("π Examples", open=False):
        gr.Examples(
            examples=_EXAMPLE_ROWS,
            inputs=list(_clone_inputs),
            outputs=list(_clone_outputs),
            fn=clone_voice,
            cache_examples=False,
        )

    # Wire the button to clone_voice and register it under a named API route.
    generate_btn.click(
        fn=clone_voice,
        inputs=list(_clone_inputs),
        outputs=list(_clone_outputs),
        api_name="clone_voice",
    )
# --- Launch the Gradio App ---
def main():
    """Start the Gradio server for the module-level ``iface``.

    The server binds to all interfaces on port 7860 with error display
    enabled, no share link, no authentication, and the /docs and /redoc
    routes passed through ``app_kwargs``.
    """
    print("Starting Advanced Gradio interface...")

    # Options forwarded to the underlying web app via `app_kwargs`.
    docs_routes = {"docs_url": "/docs", "redoc_url": "/redoc"}

    launch_options = dict(
        server_name="0.0.0.0",  # accept connections from outside the host
        server_port=7860,
        show_error=True,
        quiet=False,
        favicon_path=None,
        share=False,
        auth=None,  # e.g. ("username", "password") to require a login
        app_kwargs=docs_routes,
    )
    iface.launch(**launch_options)


if __name__ == "__main__":
    main()