Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,664 Bytes
9d593b2 3dab9c0 9d593b2 3dab9c0 9d593b2 3dab9c0 9d593b2 3dab9c0 9d593b2 3dab9c0 9d593b2 3dab9c0 c1fca8e 3dab9c0 2e214c5 9d593b2 3dab9c0 9d593b2 2e214c5 9d593b2 3dab9c0 9d593b2 3dab9c0 2e214c5 9d593b2 3dab9c0 9d593b2 f094d94 58ffee2 2bd4215 58ffee2 9d593b2 c1fca8e 9d593b2 2e214c5 9d593b2 58ffee2 9d593b2 2e214c5 9d593b2 3dab9c0 fae012e 3dab9c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import random
import numpy as np
import torch
from chatterbox.src.chatterbox.tts import ChatterboxTTS
import gradio as gr
import spaces # <<< IMPORT THIS
# Pick the compute device once at import time; the rest of the module reads
# DEVICE instead of querying torch again.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Log the selection so it shows up in the application logs.
print(f"🚀 Running on device: {DEVICE}")
def set_seed(seed: int):
    """Seed every random number generator used during generation.

    Seeds PyTorch (CPU and, when CUDA is available, all CUDA devices),
    Python's ``random`` module and NumPy's global legacy generator.

    Args:
        seed: Any integer. NumPy only accepts seeds in ``[0, 2**32 - 1]``,
            so the value is reduced modulo ``2**32`` for NumPy alone; the
            other generators receive ``seed`` unchanged.
    """
    torch.manual_seed(seed)
    # Ask torch directly so this helper does not depend on a module global.
    if torch.cuda.is_available():
        # manual_seed_all covers every CUDA device, including the current one.
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    # Negative or >= 2**32 seeds would make np.random.seed raise ValueError.
    np.random.seed(seed % 2**32)
# NOTE: On ZeroGPU, the function that actually calls the model must be
# decorated with @spaces.GPU. Decorating a separate model-loading function is
# optional; it is usually better to load the model once (globally or via
# gr.State) and decorate only the function that runs inference.
@spaces.GPU
def generate(model_obj, text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw):
    """Synthesize speech for ``text`` and return it with the model state.

    The function is decorated with ``@spaces.GPU`` so that it runs inside the
    GPU context on a ZeroGPU Space.

    Args:
        model_obj: A loaded ``ChatterboxTTS`` instance, or ``None`` to load
            one on ``DEVICE`` during this call.
        text: Text to synthesize.
        audio_prompt_path: File path of the reference audio prompt.
        exaggeration: Forwarded to ``model_obj.generate`` as ``exaggeration``.
        temperature: Forwarded to ``model_obj.generate`` as ``temperature``.
        seed_num: Random seed. ``0`` or ``None`` (a cleared number field)
            leaves the random generators unseeded.
        cfgw: Forwarded to ``model_obj.generate`` as ``cfg_weight``.

    Returns:
        ``(model_obj, (sample_rate, waveform))`` where ``sample_rate`` is
        ``model_obj.sr`` and ``waveform`` is a NumPy array. The model is
        returned so the caller can store it back into ``gr.State``.
    """
    if model_obj is None:
        # Lazy load on first use; this happens inside the GPU-allocated call.
        print("Model is None, attempting to load...")
        model_obj = ChatterboxTTS.from_pretrained(DEVICE)
        print(f"Model loaded on device: {getattr(model_obj, 'device', 'unknown')}")

    # gr.Number yields None when the field is cleared; treat that like 0
    # (no seeding) instead of crashing on int(None).
    if seed_num:
        set_seed(int(seed_num))

    print(f"Generating audio for text: '{text}' on device: {DEVICE}")
    wav = model_obj.generate(
        text,
        audio_prompt_path=audio_prompt_path,
        exaggeration=exaggeration,
        temperature=temperature,
        cfg_weight=cfgw,
    )
    print("Audio generation complete.")

    # Tensor.numpy() raises for CUDA or grad-tracking tensors, so detach and
    # move to CPU first; both are no-ops for a plain CPU tensor.
    # NOTE(review): assumes wav is a torch tensor with a leading dim of size 1.
    audio = wav.squeeze(0).detach().cpu().numpy()
    return (model_obj, (model_obj.sr, audio))
# Build the Gradio UI: inputs on the left column, generated audio on the right.
with gr.Blocks() as demo:
    # The model is kept in gr.State and starts as None, so the first call to
    # `generate` loads it (inside the @spaces.GPU-decorated function) and
    # returns it to be stored here for later calls.
    model_state = gr.State(None)
    with gr.Row():
        # Left column: synthesis inputs and the trigger button.
        with gr.Column():
            text = gr.Textbox(value="What does the fox say?", label="Text to synthesize")
            ref_wav = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/wav7604828.wav")
            exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5)
            cfg_weight = gr.Slider(0.2, 1, step=.05, label="CFG/Pace", value=0.5)
            # Less frequently used sampling controls, collapsed by default.
            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5, step=.05, label="temperature", value=.8)
            run_btn = gr.Button("Generate", variant="primary")
        # Right column: playback of the synthesized audio.
        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")
    # Inputs are listed in the same order as the parameters of `generate`;
    # the outputs receive the (model, audio) pair it returns.
    run_btn.click(
        fn=generate,
        inputs=[
            model_state,
            text,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
        ],
        outputs=[model_state, audio_output],
    )
# launch() is called without share=True: Hugging Face Spaces already serves
# the app publicly, and share=True only triggers a warning there.
demo.queue(
    max_size=50,
    default_concurrency_limit=1,  # one generation at a time by default
).launch()