# voice_cloning / app.py
# Author: ramimu — commit 73ab9c0 "Update app.py" (verified), 17.1 kB
import gradio as gr
import os
import traceback
import torch
from huggingface_hub import hf_hub_download
import shutil
import spaces
try:
    # Prefer values from a project-level config module when one is importable.
    from config import LOCAL_MODEL_PATH, MODEL_FILES, MODEL_REPO_ID
except ImportError:
    # No usable config module: fall back to the published model repository defaults.
    MODEL_REPO_ID = "ramimu/chatterbox-voice-cloning-model"
    MODEL_FILES = ["s3gen.pt", "t3_cfg.pt", "ve.pt", "tokenizer.json"]
    LOCAL_MODEL_PATH = "./chatterbox_model_files"
def _print_signature(label, func):
    """Print ``inspect.signature(func)`` prefixed by *label*, for start-up diagnostics.

    Best-effort: when no signature can be determined the object is skipped
    silently instead of aborting module import.
    """
    import inspect

    try:
        sig = inspect.signature(func)
    except (TypeError, ValueError):
        # inspect.signature raises TypeError for unsupported objects and
        # ValueError when no signature can be provided.
        return
    print(f"{label} signature: {sig}")


try:
    from chatterbox.tts import ChatterboxTTS
except ImportError as e:
    print(f"Failed to import ChatterboxTTS: {e}")
    print("Trying alternative import...")
    try:
        from chatterbox import ChatterboxTTS
    except ImportError as e2:
        print(f"Alternative import also failed: {e2}")
        chatterbox_available = False
    else:
        chatterbox_available = True
        print("Chatterbox TTS imported with alternative method")
else:
    chatterbox_available = True
    print("Chatterbox TTS imported successfully")
    # Log the installed ChatterboxTTS API so loader failures can be diagnosed
    # from the Space logs.
    print(f"ChatterboxTTS methods: {[name for name in dir(ChatterboxTTS) if not name.startswith('_')]}")
    _print_signature("ChatterboxTTS.__init__", ChatterboxTTS.__init__)
    for _factory_name in ("from_local", "from_pretrained"):
        if hasattr(ChatterboxTTS, _factory_name):
            _print_signature(f"ChatterboxTTS.{_factory_name}", getattr(ChatterboxTTS, _factory_name))

# Module-level model handle read by clone_voice(); the loading code assigns it
# and resets it to None when loading fails.
model = None
def download_model_files(repo_id=None, files=None, local_dir=None, cache_dir="./cache"):
    """Ensure every model file exists in a local directory, downloading missing ones.

    Each missing file is fetched with ``hf_hub_download`` into *cache_dir* and
    then copied into *local_dir*. The copy goes to a ``<name>.part`` sibling and
    is renamed into place with ``os.replace``. An interrupted copy therefore
    never leaves a truncated file that a later run would treat as complete.

    Args:
        repo_id: Hugging Face repo to download from (default: ``MODEL_REPO_ID``).
        files: File names to ensure (default: ``MODEL_FILES``).
        local_dir: Destination directory, created if needed
            (default: ``LOCAL_MODEL_PATH``).
        cache_dir: Hub cache directory passed to ``hf_hub_download``.

    Raises:
        Exception: whatever ``hf_hub_download`` or the copy raises for the
            first file that cannot be obtained; later files are not attempted.
    """
    repo_id = MODEL_REPO_ID if repo_id is None else repo_id
    files = MODEL_FILES if files is None else files
    local_dir = LOCAL_MODEL_PATH if local_dir is None else local_dir

    print(f"Checking for model files in {local_dir}...")
    os.makedirs(local_dir, exist_ok=True)
    for filename in files:
        local_path = os.path.join(local_dir, filename)
        if os.path.exists(local_path):
            print(f"✓ {filename} already exists locally")
            continue
        print(f"Downloading {filename} from {repo_id}...")
        # Stale .part files from an earlier failed run are simply overwritten.
        partial_path = f"{local_path}.part"
        try:
            downloaded_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir=cache_dir,
                force_download=False,
            )
            shutil.copy2(downloaded_path, partial_path)
            # Atomic rename within the same directory: local_path is either
            # absent or complete, never half-written.
            os.replace(partial_path, local_path)
            print(f"✓ Downloaded and copied {filename}")
        except Exception as e:
            print(f"✗ Failed to download {filename}: {e}")
            raise
    print("All model files are ready!")
def _load_model_manually(model_dir, device):
    """Last-resort loader: build ``ChatterboxTTS`` directly from checkpoint files.

    Reads ``s3gen.pt``, ``ve.pt`` and ``t3_cfg.pt`` from *model_dir* with
    ``torch.load`` (mapped to CPU), plus ``tokenizer.json``. It wraps the
    tokenizer JSON in ``EnTokenizer`` when possible. It then tries the keyword
    constructor and, if that fails, a positional argument order.

    Args:
        model_dir: Directory holding the checkpoint files.
        device: Device string forwarded to the constructor.

    Returns:
        The constructed ``ChatterboxTTS`` instance.

    Raises:
        Exception: the error from loading the component files, or the
            keyword-constructor error when both constructor attempts fail.
    """
    import json
    import pathlib

    model_path = pathlib.Path(model_dir)
    print("Manual loading with correct constructor signature...")
    s3gen_path = model_path / "s3gen.pt"
    ve_path = model_path / "ve.pt"
    tokenizer_path = model_path / "tokenizer.json"
    t3_cfg_path = model_path / "t3_cfg.pt"
    cpu = torch.device("cpu")
    try:
        # NOTE(security): torch.load unpickles arbitrary Python objects, so these
        # files must come from a trusted repository. Consider weights_only=True
        # once the checkpoints are confirmed to contain only tensors/state dicts.
        print(f" Loading s3gen from: {s3gen_path}")
        s3gen = torch.load(s3gen_path, map_location=cpu)
        print(f" Loading ve from: {ve_path}")
        ve = torch.load(ve_path, map_location=cpu)
        print(f" Loading t3_cfg from: {t3_cfg_path}")
        t3_cfg = torch.load(t3_cfg_path, map_location=cpu)
        print(f" Loading tokenizer from: {tokenizer_path}")
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(tokenizer_path, "r", encoding="utf-8") as f:
            tokenizer_data = json.load(f)
    except Exception as load_error:
        # Without the components the constructor retries below cannot run.
        print(f"Manual loading failed: {load_error}")
        raise

    try:
        from chatterbox.models.tokenizers.tokenizer import EnTokenizer
        tokenizer = EnTokenizer.from_dict(tokenizer_data)
        print(" Created EnTokenizer from JSON data")
    except Exception as tok_error:
        print(f" Could not create EnTokenizer: {tok_error}")
        tokenizer = tokenizer_data  # fall back to passing the raw JSON data

    print(" Creating ChatterboxTTS instance with correct signature...")
    try:
        tts = ChatterboxTTS(t3=t3_cfg, s3gen=s3gen, ve=ve, tokenizer=tokenizer, device=device)
    except Exception as kw_error:
        print(f"Manual loading failed: {kw_error}")
        print("Trying alternative parameter order...")
        try:
            tts = ChatterboxTTS(s3gen, ve, tokenizer, t3_cfg, device)
        except Exception as pos_error:
            print(f"Alternative parameter order failed: {pos_error}")
            raise kw_error from pos_error
        print("Chatterbox model loaded with alternative parameter order.")
        return tts
    print("Chatterbox model loaded successfully with manual constructor.")
    return tts


def _load_chatterbox_model(model_dir, device):
    """Load ``ChatterboxTTS``, trying each strategy in turn until one succeeds.

    Order:
    1. ``ChatterboxTTS.from_local(model_dir, device)``.
    2. ``ChatterboxTTS.from_pretrained(device)``, which is not given *model_dir*.
    3. ``_load_model_manually(model_dir, device)``.

    Failures of the first two are logged and swallowed. The manual loader's
    exception propagates to the caller.
    """
    try:
        tts = ChatterboxTTS.from_local(model_dir, device)
    except Exception as local_error:
        print(f"from_local attempt failed: {local_error}")
    else:
        print("Chatterbox model loaded successfully using from_local method.")
        return tts
    try:
        tts = ChatterboxTTS.from_pretrained(device)
    except Exception as pretrained_error:
        print(f"from_pretrained failed: {pretrained_error}")
    else:
        print("Chatterbox model loaded successfully with from_pretrained.")
        return tts
    return _load_model_manually(model_dir, device)


if chatterbox_available:
    print("Downloading model files from Hugging Face Hub...")
    try:
        download_model_files()
    except Exception as e:
        # Continue anyway: loading is still attempted below, and the
        # from_pretrained fallback does not read LOCAL_MODEL_PATH.
        print(f"ERROR: Failed to download model files: {e}")
        print("Model loading will fail without these files.")
    print(f"Attempting to load Chatterbox model from local directory: {LOCAL_MODEL_PATH}")
    if not os.path.exists(LOCAL_MODEL_PATH):
        print(f"ERROR: Local model directory not found at {LOCAL_MODEL_PATH}")
        print("Please ensure the model files were downloaded successfully.")
    else:
        print(f"Contents of {LOCAL_MODEL_PATH}: {os.listdir(LOCAL_MODEL_PATH)}")
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device}")
            model = _load_chatterbox_model(LOCAL_MODEL_PATH, device)
        except Exception as e:
            print(f"ERROR: Failed to load Chatterbox model from local directory: {e}")
            print("Detailed error trace:")
            traceback.print_exc()
            model = None
else:
    print("ERROR: Chatterbox TTS library not available")
@spaces.GPU
def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """Synthesize *text_to_speak* in the voice of a reference recording.

    Args:
        text_to_speak: Text to synthesize; empty or whitespace-only text is rejected.
        reference_audio_path: Path to the reference clip (.wav or .mp3).
        exaggeration: Forwarded to ``model.generate(exaggeration=...)``.
        cfg_pace: Forwarded to ``model.generate(cfg_weight=...)``.
        random_seed: When > 0, seeds torch (and CUDA, if available) for
            reproducible output. 0, negative values or None leave the RNG
            unseeded.
        temperature: Forwarded to ``model.generate(temperature=...)``.

    Returns:
        A pair ``(audio, status_message)``. *audio* takes one of three forms:
        - ``None`` on any error;
        - the string ``model.generate`` returned, when it returned a string;
        - ``(sample_rate, numpy_array)``, where multi-dimensional output has its
          singleton dimensions squeezed away.
    """
    if not chatterbox_available:
        return None, "Error: Chatterbox TTS library not available. Please check installation."
    if model is None:
        return None, "Error: Model not loaded. Please check the logs for details."
    if not text_to_speak or not text_to_speak.strip():
        return None, "Error: Please enter some text to speak."
    if reference_audio_path is None:
        return None, "Error: Please upload a reference audio file (.wav or .mp3)."
    try:
        print("clone_voice function called:")
        print(f" Text: '{text_to_speak}'")
        print(f" Audio Path: '{reference_audio_path}'")
        print(f" Exaggeration: {exaggeration}")
        print(f" CFG/Pace: {cfg_pace}")
        print(f" Random Seed: {random_seed}")
        print(f" Temperature: {temperature}")
        # The UI/API may send None (empty number field) or a float; normalize
        # so the comparison and torch.manual_seed both accept it.
        seed = int(random_seed or 0)
        if seed > 0:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
        output_wav_data = model.generate(
            text=text_to_speak,
            audio_prompt_path=reference_audio_path,
            exaggeration=exaggeration,
            cfg_weight=cfg_pace,
            temperature=temperature,
        )
        # 24000 is the fallback rate used when the model exposes no `sr` attribute.
        sample_rate = getattr(model, "sr", 24000)
        print(f"Audio generated successfully by clone_voice. Output data type: {type(output_wav_data)}, Sample rate: {sample_rate}")
        if isinstance(output_wav_data, str):
            return output_wav_data, "Success: Audio generated successfully!"
        if isinstance(output_wav_data, torch.Tensor):
            # detach() first so tensors that still track gradients can be converted.
            output_wav_data = output_wav_data.detach().cpu().numpy()
        if output_wav_data.ndim > 1:
            output_wav_data = output_wav_data.squeeze()
        return (sample_rate, output_wav_data), "Success: Audio generated successfully!"
    except Exception as e:
        print(f"ERROR: Failed during audio generation in clone_voice: {e}")
        print("Detailed error trace for audio generation in clone_voice:")
        traceback.print_exc()
        return None, f"Error during audio generation: {str(e)}. Check logs for more details."
def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """API-friendly wrapper around ``clone_voice`` that accepts several audio input forms.

    *reference_audio_url* may be any of the following:
      * a ``data:audio/...`` base64 URI, decoded into a temporary file;
      * an ``http...`` URL, downloaded into a temporary file;
      * a string path to an existing file, used as-is;
      * an object with a ``.name`` attribute (e.g. a Gradio file wrapper),
        whose ``.name`` is used as the path;
      * anything else, which is ``str()``-ed and tried as a path.

    Temporary files created by this call are deleted before it returns,
    on both success and error.

    Returns:
        The ``(audio, status)`` pair from ``clone_voice``, or
        ``(None, "API Error: ...")`` if preparing the audio or generation raised.
    """
    import base64
    import tempfile
    from urllib.parse import urlparse

    temp_audio_path = None
    owns_temp_file = False  # True only for files this call created (data URI / URL)
    try:
        print("API call received by clone_voice_api:")
        print(f" Text: {text_to_speak}")
        print(f" Audio URL type: {type(reference_audio_url)}")
        print(f" Audio URL preview: {str(reference_audio_url)[:100]}...")
        print(f" Parameters: exag={exaggeration}, cfg={cfg_pace}, seed={random_seed}, temp={temperature}")

        if isinstance(reference_audio_url, str) and reference_audio_url.startswith("data:audio"):
            print("Processing base64 audio data...")
            header, encoded = reference_audio_url.split(",", 1)
            audio_data = base64.b64decode(encoded)
            print(f"Decoded audio data size: {len(audio_data)} bytes")
            # MP3 data URIs normally carry the registered MIME type audio/mpeg.
            ext = ".mp3" if ("mp3" in header or "mpeg" in header) else ".wav"
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
                temp_audio_path = temp_file.name
                owns_temp_file = True
                temp_file.write(audio_data)
            print(f"Created temporary audio file from base64: {temp_audio_path}")
        elif isinstance(reference_audio_url, str) and reference_audio_url.startswith("http"):
            import requests  # imported lazily: only this branch needs it

            print("Processing HTTP audio URL...")
            # NOTE(security): this fetches an arbitrary caller-supplied URL from the
            # server (SSRF risk) and buffers the whole body in memory; restrict
            # hosts and response size if this endpoint is exposed publicly.
            response = requests.get(reference_audio_url, timeout=30)
            response.raise_for_status()
            # Look only at the URL path so query strings/fragments don't hide the extension.
            url_path = urlparse(reference_audio_url).path.lower()
            ext = ".mp3" if url_path.endswith(".mp3") else ".wav"
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
                temp_audio_path = temp_file.name
                owns_temp_file = True
                temp_file.write(response.content)
            print(f"Created temporary audio file from URL: {temp_audio_path}")
        elif isinstance(reference_audio_url, str) and os.path.exists(reference_audio_url):
            print("Using direct file path provided as string...")
            temp_audio_path = reference_audio_url
        elif hasattr(reference_audio_url, "name"):
            # e.g. a Gradio temporary-file wrapper exposing its path as `.name`
            temp_audio_path = reference_audio_url.name
            print(f"Using file path from Gradio object: {temp_audio_path}")
        else:
            print(f"Warning: Unrecognized audio input type or path: {reference_audio_url}. Assuming it's a direct path.")
            temp_audio_path = str(reference_audio_url)

        if not temp_audio_path or not os.path.exists(temp_audio_path):
            raise ValueError(f"Failed to obtain a valid audio file path from input: {reference_audio_url}")

        print(f"Calling core clone_voice function with audio path: {temp_audio_path}")
        audio_output, status = clone_voice(text_to_speak, temp_audio_path, exaggeration, cfg_pace, random_seed, temperature)
        print(f"clone_voice returned: {type(audio_output)}, {status}")
        return audio_output, status
    except Exception as e:
        print(f"ERROR in clone_voice_api: {e}")
        traceback.print_exc()
        return None, f"API Error: {str(e)}"
    finally:
        # Single cleanup point for both the success and the error path.
        if owns_temp_file and os.path.exists(temp_audio_path):
            try:
                os.unlink(temp_audio_path)
                print(f"Cleaned up temporary file: {temp_audio_path}")
            except OSError as e_clean:
                print(f"Failed to clean up temp file {temp_audio_path}: {e_clean}")
def main():
    """Build the Gradio interface around ``clone_voice`` and start the web server.

    Both the UI and Gradio's API call ``clone_voice`` directly, and the endpoint
    is published under the API name ``clone_voice``.
    """
    print("Starting Advanced Gradio interface...")
    iface = gr.Interface(
        fn=clone_voice,
        inputs=[
            gr.Textbox(
                label="Text to Speak",
                placeholder="Enter the text you want the cloned voice to say...",
                lines=3,
            ),
            gr.Audio(
                type="filepath",  # Gradio saves the upload/recording and passes its path
                label="Reference Audio (Upload a short .wav or .mp3 clip)",
                sources=["upload", "microphone"],
            ),
            gr.Slider(
                minimum=0.25,
                maximum=1.0,
                value=0.6,
                step=0.05,
                label="Exaggeration",
                info="Controls voice characteristic emphasis (0.5 = neutral, higher = more exaggerated)",
            ),
            gr.Slider(
                minimum=0.2,
                maximum=1.0,
                value=0.3,
                step=0.05,
                label="CFG/Pace",
                info="Classifier-free guidance weight (affects generation quality and pace)",
            ),
            gr.Number(
                value=0,
                label="Random Seed",
                info="Set to 0 for random results, or use a specific number for reproducible outputs",
                precision=0,
            ),
            gr.Slider(
                minimum=0.05,
                maximum=2.0,
                value=0.6,
                step=0.05,
                label="Temperature",
                info="Controls randomness in generation (lower = more consistent, higher = more varied)",
            ),
        ],
        outputs=[
            gr.Audio(label="Generated Audio", type="numpy"),
            gr.Textbox(label="Status", lines=2),
        ],
        title="🎙️ Advanced Chatterbox Voice Cloning",
        description="Clone any voice using advanced AI technology with fine-tuned controls.",
        # Example rows leave the reference audio empty; users still supply their own clip.
        examples=[
            ["Hello, this is a test of the voice cloning system.", None, 0.5, 0.5, 0, 0.8],
            ["The quick brown fox jumps over the lazy dog.", None, 0.7, 0.3, 42, 0.6],
            ["Welcome to our AI voice cloning service. We hope you enjoy the experience!", None, 0.4, 0.7, 123, 1.0],
        ],
        api_name="clone_voice",  # endpoint name shown in the API docs
    )
    iface.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        show_error=True,
        quiet=False,
        favicon_path=None,
        share=False,  # set True for a temporary public share link when running locally
        auth=None,
    )


if __name__ == "__main__":
    main()