Spaces:
Running
Running
import gradio as gr | |
import os | |
import traceback | |
import torch | |
from huggingface_hub import hf_hub_download | |
import shutil | |
import spaces | |
# Model download settings. A project-local `config` module may override all
# three values; when it cannot be imported, the defaults below are used.
try:
    from config import LOCAL_MODEL_PATH, MODEL_FILES, MODEL_REPO_ID
except ImportError:
    # Hugging Face Hub repository that hosts the checkpoint files.
    MODEL_REPO_ID = "ramimu/chatterbox-voice-cloning-model"
    # Files that must be present locally before the model can be loaded.
    MODEL_FILES = [
        "s3gen.pt",
        "t3_cfg.pt",
        "ve.pt",
        "tokenizer.json",
    ]
    # Local directory the files are copied into.
    LOCAL_MODEL_PATH = "./chatterbox_model_files"
def _log_chatterbox_api(cls):
    """Print the public attribute names of ``cls`` and, where they exist and
    can be introspected, the signatures of ``__init__``, ``from_local`` and
    ``from_pretrained``. Purely diagnostic output; never raises for a
    missing or un-introspectable member."""
    import inspect

    public_names = [name for name in dir(cls) if not name.startswith('_')]
    print(f"ChatterboxTTS methods: {public_names}")
    for attr in ("__init__", "from_local", "from_pretrained"):
        member = getattr(cls, attr, None)
        if member is None:
            continue
        try:
            sig = inspect.signature(member)
        except (TypeError, ValueError):
            # inspect.signature raises these when no signature can be
            # determined; this is debug output, so just skip the member.
            continue
        print(f"ChatterboxTTS.{attr} signature: {sig}")


# Import the TTS class: try the `chatterbox.tts` submodule first, then the
# package root. `chatterbox_available` records whether either import worked.
try:
    from chatterbox.tts import ChatterboxTTS
    chatterbox_available = True
    print("Chatterbox TTS imported successfully")
    _log_chatterbox_api(ChatterboxTTS)
except ImportError as e:
    print(f"Failed to import ChatterboxTTS: {e}")
    print("Trying alternative import...")
    try:
        import chatterbox
        from chatterbox import ChatterboxTTS
        chatterbox_available = True
        print("Chatterbox TTS imported with alternative method")
    except ImportError as e2:
        print(f"Alternative import also failed: {e2}")
        chatterbox_available = False

# Global TTS model instance. It stays None unless the module-level loading
# code assigns a successfully constructed model.
model = None
def download_model_files():
    """Make sure every file in ``MODEL_FILES`` is present in ``LOCAL_MODEL_PATH``.

    Each missing file is fetched from the Hugging Face Hub repo
    ``MODEL_REPO_ID`` with ``hf_hub_download`` (cached under ``./cache``) and
    copied into ``LOCAL_MODEL_PATH``. Files that already exist are skipped.

    Raises:
        Exception: whatever ``hf_hub_download`` or the copy raised for the
            first file that could not be obtained; the remaining files are
            not attempted.
    """
    print(f"Checking for model files in {LOCAL_MODEL_PATH}...")
    os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)
    for filename in MODEL_FILES:
        local_path = os.path.join(LOCAL_MODEL_PATH, filename)
        if os.path.exists(local_path):
            print(f"✅ {filename} already exists locally")
            continue

        print(f"Downloading {filename} from {MODEL_REPO_ID}...")
        partial_path = local_path + ".part"
        try:
            downloaded_path = hf_hub_download(
                repo_id=MODEL_REPO_ID,
                filename=filename,
                cache_dir="./cache",
                force_download=False,
            )
            # Copy under a temporary name and rename into place. os.replace is
            # atomic on the same filesystem, so an interrupted copy can never
            # leave a truncated file that the existence check above would
            # accept on the next start-up.
            shutil.copy2(downloaded_path, partial_path)
            os.replace(partial_path, local_path)
            print(f"✅ Downloaded and copied {filename}")
        except Exception as e:
            print(f"❌ Failed to download {filename}: {e}")
            try:
                os.remove(partial_path)
            except OSError:
                # Best effort: the partial file may not exist. The download
                # error re-raised below is the one that matters.
                pass
            raise  # re-raise the active exception unchanged
    print("All model files are ready!")
# Import-time model bootstrap. Runs once when this module is loaded.
#   1. download_model_files() is called to fetch the checkpoint files; a
#      failure is logged but does not stop the loading attempts below.
#   2. A chain of loading strategies is tried. Each runs only if the
#      previous one raised:
#        a. ChatterboxTTS.from_local(LOCAL_MODEL_PATH, device)
#        b. ChatterboxTTS.from_pretrained(device)
#        c. torch.load each checkpoint file manually, then call the
#           constructor with keyword arguments
#        d. the same objects passed positionally in a different order
# If every strategy fails, the traceback is printed and the global `model` is
# set to None. clone_voice() checks for None and reports it to the user.
if chatterbox_available:
    print("Downloading model files from Hugging Face Hub...")
    try:
        download_model_files()
    except Exception as e:
        # Not fatal: loading is still attempted below.
        print(f"ERROR: Failed to download model files: {e}")
        print("Model loading will fail without these files.")
    print(f"Attempting to load Chatterbox model from local directory: {LOCAL_MODEL_PATH}")
    if not os.path.exists(LOCAL_MODEL_PATH):
        print(f"ERROR: Local model directory not found at {LOCAL_MODEL_PATH}")
        print("Please ensure the model files were downloaded successfully.")
    else:
        # Diagnostic listing of what is actually on disk.
        print(f"Contents of {LOCAL_MODEL_PATH}: {os.listdir(LOCAL_MODEL_PATH)}")
    try:
        # Use the GPU when PyTorch can see one.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        try:
            # Strategy a: load from the local checkpoint directory.
            model = ChatterboxTTS.from_local(LOCAL_MODEL_PATH, device)
            print("Chatterbox model loaded successfully using from_local method.")
        except Exception as e1:
            print(f"from_local attempt failed: {e1}")
            try:
                # Strategy b: from_pretrained, called with only the device
                # (no local path is passed).
                model = ChatterboxTTS.from_pretrained(device)
                print("Chatterbox model loaded successfully with from_pretrained.")
            except Exception as e2:
                print(f"from_pretrained failed: {e2}")
                try:
                    # Strategy c: deserialize each checkpoint file and build
                    # the model through its constructor.
                    import pathlib
                    import json
                    model_path = pathlib.Path(LOCAL_MODEL_PATH)
                    print(f"Manual loading with correct constructor signature...")
                    s3gen_path = model_path / "s3gen.pt"
                    ve_path = model_path / "ve.pt"
                    tokenizer_path = model_path / "tokenizer.json"
                    t3_cfg_path = model_path / "t3_cfg.pt"
                    # Every checkpoint is loaded onto the CPU
                    # (map_location), regardless of `device`.
                    # NOTE(review): torch.load unpickles arbitrary Python
                    # objects, which is only safe for trusted files such as
                    # those from MODEL_REPO_ID. If these are plain state
                    # dicts, weights_only=True would be safer; confirm.
                    print(f" Loading s3gen from: {s3gen_path}")
                    s3gen = torch.load(s3gen_path, map_location=torch.device('cpu'))
                    print(f" Loading ve from: {ve_path}")
                    ve = torch.load(ve_path, map_location=torch.device('cpu'))
                    print(f" Loading t3_cfg from: {t3_cfg_path}")
                    t3_cfg = torch.load(t3_cfg_path, map_location=torch.device('cpu'))
                    print(f" Loading tokenizer from: {tokenizer_path}")
                    with open(tokenizer_path, 'r') as f:
                        tokenizer_data = json.load(f)
                    try:
                        from chatterbox.models.tokenizers.tokenizer import EnTokenizer
                        tokenizer = EnTokenizer.from_dict(tokenizer_data)
                        print(" Created EnTokenizer from JSON data")
                    except Exception as tok_error:
                        # Fall back to passing the raw parsed JSON as the
                        # tokenizer.
                        # NOTE(review): the constructor presumably expects a
                        # tokenizer object, so this may only defer the
                        # failure; confirm.
                        print(f" Could not create EnTokenizer: {tok_error}")
                        tokenizer = tokenizer_data
                    print(" Creating ChatterboxTTS instance with correct signature...")
                    model = ChatterboxTTS(
                        t3=t3_cfg,
                        s3gen=s3gen,
                        ve=ve,
                        tokenizer=tokenizer,
                        device=device
                    )
                    print("Chatterbox model loaded successfully with manual constructor.")
                except Exception as e3:
                    print(f"Manual loading failed: {e3}")
                    print(f"Detailed error: {str(e3)}")
                    try:
                        # Strategy d: retry positionally with the objects
                        # from strategy c. If strategy c failed before these
                        # names were assigned, this call raises (e.g. a
                        # NameError) and control falls through to `raise e3`.
                        # NOTE(review): this order (s3gen, ve, tokenizer,
                        # t3_cfg, device) differs from the keyword mapping
                        # above. If the constructor accepts five positional
                        # parameters in another order, this could build a
                        # mis-wired model instead of raising. Verify against
                        # ChatterboxTTS.__init__.
                        model = ChatterboxTTS(
                            s3gen, ve, tokenizer, t3_cfg, device
                        )
                        print("Chatterbox model loaded with alternative parameter order.")
                    except Exception as e4:
                        print(f"Alternative parameter order failed: {e4}")
                        # Surface the strategy-c error (not e4) to the outer
                        # handler.
                        raise e3
    except Exception as e:
        print(f"ERROR: Failed to load Chatterbox model from local directory: {e}")
        print("Detailed error trace:")
        traceback.print_exc()
        model = None
else:
    print("ERROR: Chatterbox TTS library not available")
def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """Generate speech for ``text_to_speak`` in the voice of a reference clip.

    Args:
        text_to_speak: Text to synthesize; must contain non-whitespace characters.
        reference_audio_path: Path to the reference recording.
        exaggeration: Forwarded to ``model.generate(exaggeration=...)``.
        cfg_pace: Forwarded to ``model.generate`` as ``cfg_weight``.
        random_seed: If positive, seeds torch's RNG before generation.
            0, negative or empty values (e.g. None) leave the RNG unseeded.
        temperature: Forwarded to ``model.generate(temperature=...)``.

    Returns:
        A tuple ``(audio, status)``:

        * ``audio`` is None on failure.
        * It is the returned value unchanged if ``model.generate`` returns a
          string.
        * Otherwise it is ``(sample_rate, samples)``, where ``samples`` is a
          NumPy array with singleton dimensions squeezed out.

        ``status`` is a human-readable message for the UI. Generation errors
        are caught and reported through ``status``, not raised.
    """
    if not chatterbox_available:
        return None, "Error: Chatterbox TTS library not available. Please check installation."
    if model is None:
        return None, "Error: Model not loaded. Please check the logs for details."
    if not text_to_speak or not text_to_speak.strip():
        return None, "Error: Please enter some text to speak."
    # An empty path is as unusable as None, so reject both before generating.
    if not reference_audio_path:
        return None, "Error: Please upload a reference audio file (.wav or .mp3)."
    try:
        import numpy as np

        print("clone_voice function called:")
        print(f" Text: '{text_to_speak}'")
        print(f" Audio Path: '{reference_audio_path}'")
        print(f" Exaggeration: {exaggeration}")
        print(f" CFG/Pace: {cfg_pace}")
        print(f" Random Seed: {random_seed}")
        print(f" Temperature: {temperature}")

        # The seed may arrive as a float, or as None when the field is left
        # empty via the API. A bare `> 0` comparison on None would raise.
        seed = int(random_seed) if random_seed else 0
        if seed > 0:
            # torch.manual_seed seeds the CPU and all CUDA generators, so no
            # separate torch.cuda seeding is needed.
            torch.manual_seed(seed)

        output_wav_data = model.generate(
            text=text_to_speak,
            audio_prompt_path=reference_audio_path,
            exaggeration=exaggeration,
            cfg_weight=cfg_pace,
            temperature=temperature,
        )
        # Fall back to 24 kHz if the model object has no `sr` attribute.
        sample_rate = getattr(model, "sr", 24000)
        print(f"Audio generated successfully by clone_voice. Output data type: {type(output_wav_data)}, Sample rate: {sample_rate}")

        # A string result is passed straight through (presumably a file path).
        if isinstance(output_wav_data, str):
            return output_wav_data, "Success: Audio generated successfully!"

        # Tensor-like output: drop autograd history (a tensor that requires
        # grad cannot be converted by .numpy()) and move to host memory.
        if hasattr(output_wav_data, "detach"):
            output_wav_data = output_wav_data.detach()
        if hasattr(output_wav_data, "cpu"):
            output_wav_data = output_wav_data.cpu().numpy()
        samples = np.asarray(output_wav_data)
        # Remove singleton axes, e.g. a leading batch/channel axis of size 1.
        if samples.ndim > 1:
            samples = samples.squeeze()
        return (sample_rate, samples), "Success: Audio generated successfully!"
    except Exception as e:
        print(f"ERROR: Failed during audio generation in clone_voice: {e}")
        print("Detailed error trace for audio generation in clone_voice:")
        traceback.print_exc()
        return None, f"Error during audio generation: {str(e)}. Check logs for more details."
def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """Resolve a reference clip from several input forms, then call clone_voice().

    ``reference_audio_url`` may be:

    * a ``data:audio/...`` URL with a base64 payload, decoded into a temporary file;
    * an ``http://`` or ``https://`` URL, downloaded into a temporary file;
    * a string path to an existing local file, used as-is;
    * an object with a ``name`` attribute (e.g. a temporary-file wrapper),
      whose ``name`` is used as the path;
    * anything else, which is converted with ``str()`` and tried as a path.

    Temporary files created here are deleted before returning, on both the
    success and the error path. Caller-provided paths are never deleted.

    Returns:
        The ``(audio, status)`` tuple from clone_voice(), or
        ``(None, "API Error: <message>")`` if any step raises.
    """
    import base64
    import tempfile
    from urllib.parse import urlparse

    temp_audio_path = None
    created_temp_file = False  # only files written by this function are deleted
    try:
        print("API call received by clone_voice_api:")
        print(f" Text: {text_to_speak}")
        print(f" Audio URL type: {type(reference_audio_url)}")
        print(f" Audio URL preview: {str(reference_audio_url)[:100]}...")
        print(f" Parameters: exag={exaggeration}, cfg={cfg_pace}, seed={random_seed}, temp={temperature}")

        is_str = isinstance(reference_audio_url, str)
        if is_str and reference_audio_url.startswith('data:audio'):
            print("Processing base64 audio data...")
            header, encoded = reference_audio_url.split(',', 1)
            audio_data = base64.b64decode(encoded)
            print(f"Decoded audio data size: {len(audio_data)} bytes")
            # The registered MIME type for MP3 is audio/mpeg, so accept both spellings.
            ext = '.mp3' if ('mp3' in header or 'mpeg' in header) else '.wav'
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
                # Record the path before writing so a failed write is still cleaned up.
                temp_audio_path = temp_file.name
                created_temp_file = True
                temp_file.write(audio_data)
            print(f"Created temporary audio file from base64: {temp_audio_path}")
        elif is_str and reference_audio_url.startswith(('http://', 'https://')):
            print("Processing HTTP audio URL...")
            # Imported lazily so the other input forms work even without `requests`.
            import requests

            # Without a timeout, requests.get can block forever on a stalled server.
            # NOTE(review): this fetches an arbitrary caller-supplied URL with no
            # size limit; if the endpoint is public, consider an allow-list (SSRF).
            response = requests.get(reference_audio_url, timeout=30)
            response.raise_for_status()
            # Inspect only the URL path so query strings (e.g. signed URLs)
            # do not hide the file extension.
            url_path = urlparse(reference_audio_url).path.lower()
            ext = '.mp3' if url_path.endswith('.mp3') else '.wav'
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
                temp_audio_path = temp_file.name
                created_temp_file = True
                temp_file.write(response.content)
            print(f"Created temporary audio file from URL: {temp_audio_path}")
        elif is_str and os.path.exists(reference_audio_url):
            print("Using direct file path provided as string...")
            temp_audio_path = reference_audio_url
        elif hasattr(reference_audio_url, 'name'):
            # File-like / wrapper objects that expose their on-disk path as `.name`.
            temp_audio_path = reference_audio_url.name
            print(f"Using file path from Gradio object: {temp_audio_path}")
        else:
            print(f"Warning: Unrecognized audio input type or path: {reference_audio_url}. Assuming it's a direct path.")
            temp_audio_path = str(reference_audio_url)

        if not temp_audio_path or not os.path.exists(temp_audio_path):
            raise ValueError(f"Failed to obtain a valid audio file path from input: {reference_audio_url}")

        print(f"Calling core clone_voice function with audio path: {temp_audio_path}")
        audio_output, status = clone_voice(text_to_speak, temp_audio_path, exaggeration, cfg_pace, random_seed, temperature)
        print(f"clone_voice returned: {type(audio_output)}, {status}")
        return audio_output, status
    except Exception as e:
        print(f"ERROR in clone_voice_api: {e}")
        traceback.print_exc()
        return None, f"API Error: {str(e)}"
    finally:
        # Runs after either return above. Remove only the temporary file this
        # call created; caller-owned paths are left alone.
        if created_temp_file and temp_audio_path and os.path.exists(temp_audio_path):
            try:
                os.unlink(temp_audio_path)
                print(f"Cleaned up temporary file: {temp_audio_path}")
            except OSError as e_clean:
                print(f"Failed to clean up temp file {temp_audio_path}: {e_clean}")
def main():
    """Build the Gradio interface around clone_voice() and launch the server.

    The six input components map positionally onto clone_voice's parameters
    (text, reference audio path, exaggeration, cfg_pace, random_seed,
    temperature). The two output components receive its ``(audio, status)``
    tuple. The server listens on 0.0.0.0:7860.
    """
    print("Starting Advanced Gradio interface...")
    iface = gr.Interface(
        fn=clone_voice,  # The UI and the Gradio API both call clone_voice directly
        inputs=[
            gr.Textbox(
                label="Text to Speak",
                placeholder="Enter the text you want the cloned voice to say...",
                lines=3,
            ),
            gr.Audio(
                type="filepath",  # clone_voice receives a path on disk, not raw samples
                label="Reference Audio (Upload a short .wav or .mp3 clip)",
                sources=["upload", "microphone"],
            ),
            gr.Slider(
                minimum=0.25,
                maximum=1.0,
                value=0.6,
                step=0.05,
                label="Exaggeration",
                info="Controls voice characteristic emphasis (0.5 = neutral, higher = more exaggerated)",
            ),
            gr.Slider(
                minimum=0.2,
                maximum=1.0,
                value=0.3,
                step=0.05,
                label="CFG/Pace",
                info="Classifier-free guidance weight (affects generation quality and pace)",
            ),
            gr.Number(
                value=0,
                label="Random Seed",
                info="Set to 0 for random results, or use a specific number for reproducible outputs",
                precision=0,
            ),
            gr.Slider(
                minimum=0.05,
                maximum=2.0,
                value=0.6,
                step=0.05,
                label="Temperature",
                info="Controls randomness in generation (lower = more consistent, higher = more varied)",
            ),
        ],
        outputs=[
            gr.Audio(label="Generated Audio", type="numpy"),
            gr.Textbox(label="Status", lines=2),
        ],
        title="🎙️ Advanced Chatterbox Voice Cloning",
        description="Clone any voice using advanced AI technology with fine-tuned controls.",
        # Examples leave the reference audio empty (None); the user must still supply a clip.
        examples=[
            ["Hello, this is a test of the voice cloning system.", None, 0.5, 0.5, 0, 0.8],
            ["The quick brown fox jumps over the lazy dog.", None, 0.7, 0.3, 42, 0.6],
            ["Welcome to our AI voice cloning service. We hope you enjoy the experience!", None, 0.4, 0.7, 123, 1.0],
        ],
        api_name="clone_voice",  # Endpoint name for programmatic API calls
    )
    iface.launch(
        server_name="0.0.0.0",  # Listen on all network interfaces
        server_port=7860,
        show_error=True,
        quiet=False,
        favicon_path=None,
        share=False,  # Set to True if you want a public link from your local machine
        auth=None,
    )


# Start the web UI only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()