File size: 14,164 Bytes
5193c5e 3d2ce0a 8909289 4b72318 8909289 ecf63b4 3d2ce0a 844e3a3 5193c5e 3d2ce0a 80bd4c9 3d2ce0a 80bd4c9 3d2ce0a 5193c5e c0c6352 844e3a3 ecf63b4 5193c5e 8bdeb28 80bd4c9 5193c5e 8bdeb28 80bd4c9 814bea6 80bd4c9 8bdeb28 c0c6352 8bdeb28 c0c6352 8bdeb28 c0c6352 8bdeb28 80bd4c9 5193c5e 80bd4c9 5193c5e 731d214 5193c5e e173776 5193c5e 80bd4c9 5193c5e c0c6352 5193c5e c0c6352 5193c5e 8bdeb28 844e3a3 d1e3c74 844e3a3 80bd4c9 814bea6 80bd4c9 d1e3c74 844e3a3 5193c5e 80bd4c9 844e3a3 ecf63b4 8bdeb28 814bea6 844e3a3 ecf63b4 844e3a3 814bea6 844e3a3 814bea6 844e3a3 8bdeb28 844e3a3 814bea6 e173776 5193c5e 80bd4c9 8bdeb28 c0c6352 e173776 844e3a3 c0c6352 8bdeb28 c0c6352 8bdeb28 e173776 5193c5e 54096db 3d2ce0a 80bd4c9 3d2ce0a 80bd4c9 1f2a815 8bdeb28 3d2ce0a 8bdeb28 c0c6352 54096db 3d2ce0a 4b72318 8bdeb28 54096db 3d2ce0a 54096db c0c6352 8bdeb28 844e3a3 80bd4c9 ecf63b4 8909289 ecf63b4 c0c6352 ecf63b4 c0c6352 8bdeb28 844e3a3 c0c6352 3d2ce0a c0c6352 8bdeb28 80bd4c9 c0c6352 3d2ce0a c0c6352 80bd4c9 e173776 c0c6352 63d778a 3d2ce0a 63d778a 5193c5e 80bd4c9 8909289 54096db 25f78d4 c0c6352 8909289 5193c5e 844e3a3 5193c5e 54096db 3d2ce0a 80bd4c9 3d2ce0a 8bdeb28 3d2ce0a 844e3a3 80bd4c9 844e3a3 c0c6352 844e3a3 c0c6352 844e3a3 3d2ce0a 80bd4c9 c0c6352 814bea6 80bd4c9 c6ae943 54096db 844e3a3 c0c6352 54096db 3d2ce0a 80bd4c9 c0c6352 80bd4c9 c0c6352 80bd4c9 c0c6352 80bd4c9 c0c6352 fb01dbf 80bd4c9 c0c6352 80bd4c9 c0c6352 844e3a3 c0c6352 814bea6 8909289 3d2ce0a 814bea6 3d2ce0a 54096db c0c6352 54096db 844e3a3 54096db 8bdeb28 54096db c0c6352 8909289 c0c6352 8909289 54096db 844e3a3 54096db 3d2ce0a 80bd4c9 3d2ce0a 54096db 3d2ce0a 54096db 3d2ce0a 8909289 5193c5e 80bd4c9 8909289 80bd4c9 c0c6352 80bd4c9 c0c6352 3d2ce0a c0c6352 80bd4c9 c0c6352 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 |
import os
import time
from functools import wraps
import spaces
from snac import SNAC
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download, login
from dotenv import load_dotenv
load_dotenv()  # pull HF_TOKEN (and any other settings) from a local .env, if present

# Rate limiting state: maps a caller id to the timestamp of its last request.
# NOTE(review): only the "anonymous" key is ever used (see rate_limit below),
# so the cooldown is effectively global across all users of this Space.
last_request_time = {}
REQUEST_COOLDOWN = 30  # seconds a caller must wait between generation requests
def rate_limit(func):
    """Decorator enforcing a cooldown between calls to the wrapped function.

    All callers currently share the single "anonymous" key in the module-level
    last_request_time dict, so the cooldown is global. When called again before
    REQUEST_COOLDOWN seconds have elapsed, a Gradio warning is shown and None
    is returned without invoking the wrapped function.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        user_id = "anonymous"
        now = time.time()
        previous = last_request_time.get(user_id)
        if previous is not None and (now - previous) < REQUEST_COOLDOWN:
            remaining = int(REQUEST_COOLDOWN - (now - previous))
            gr.Warning(f"Please wait {remaining} seconds before next request.")
            return None
        # Record this call's timestamp, then run the real function.
        last_request_time[user_id] = now
        return func(*args, **kwargs)
    return wrapper
# Get HF token from environment variables (optional; needed for gated/private repos)
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Check if CUDA is available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading SNAC model...")
# SNAC is the neural audio codec that decodes generated code tokens back
# into a waveform (the 24 kHz variant, matching the 24000 sample rate used
# in generate_speech below).
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac_model = snac_model.to(device)
print("SNAC model loaded successfully")

model_name = "mrrtmob/tts-khm-kore"
print(f"Downloading model files from {model_name}...")
# Download only model config and safetensors with token; training artifacts
# (optimizer/scheduler state, legacy .bin weights) are explicitly excluded
# to keep the download small.
snapshot_download(
    repo_id=model_name,
    token=hf_token,
    allow_patterns=[
        "config.json",
        "*.safetensors",
        "model.safetensors.index.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt"
    ],
    ignore_patterns=[
        "optimizer.pt",
        "pytorch_model.bin",
        "training_args.bin",
        "scheduler.pt"
    ]
)
print("Model files downloaded successfully")

print("Loading main model...")
# Load model and tokenizer with token; bfloat16 halves memory vs float32.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    token=hf_token
)
model = model.to(device)
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=hf_token
)
print(f"Khmer TTS model loaded successfully to {device}")
# Process text prompt
def process_prompt(prompt, voice, tokenizer, device):
    """Build model-ready inputs for a single "<voice>: <text>" prompt.

    The tokenized prompt is framed with the special marker tokens
    128259 (start of human) and 128009, 128260 (end of text, end of human),
    then moved to `device` together with an all-ones attention mask.

    Returns:
        (input_ids, attention_mask) as 2-D int64 tensors on `device`.
    """
    tagged = f"{voice}: {prompt}"
    text_ids = tokenizer(tagged, return_tensors="pt").input_ids
    soh = torch.tensor([[128259]], dtype=torch.int64)              # Start of human
    eot_eoh = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
    framed = torch.cat([soh, text_ids, eot_eoh], dim=1)            # SOH SOT Text EOT EOH
    # Single sequence, no padding: the mask is simply all ones.
    mask = torch.ones_like(framed)
    return framed.to(device), mask.to(device)
# Parse output tokens to audio
def parse_output(generated_ids):
    """Extract SNAC audio codes from the model's generated token ids.

    The model emits audio codes after the last start-of-speech marker
    (128257) and terminates with an end-of-speech token (128258). This
    keeps only the tokens after the final marker, drops the terminator,
    trims to a whole number of 7-token frames, and rebases the ids into
    the SNAC codebook range by subtracting the 128266 offset.

    Args:
        generated_ids: 2-D tensor of token ids, shape (batch, seq_len).

    Returns:
        List of plain int codes for the first batch row, or [] when no
        codes were produced.
    """
    token_to_find = 128257    # start-of-speech marker
    token_to_remove = 128258  # end-of-speech terminator
    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
    if len(token_indices[1]) > 0:
        # Keep only tokens after the LAST marker occurrence.
        last_occurrence_idx = token_indices[1][-1].item()
        cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
    else:
        cropped_tensor = generated_ids
    code_lists = []
    for row in cropped_tensor:
        masked_row = row[row != token_to_remove]
        # Trim to whole 7-token SNAC frames.
        new_length = (masked_row.size(0) // 7) * 7
        trimmed_row = masked_row[:new_length]
        # FIX: vectorized rebase + .tolist() yields plain ints instead of a
        # Python list of 0-dim tensors. The old per-element comprehension was
        # slow and made redistribute_codes re-wrap tensors with torch.tensor,
        # which emits a copy-construct warning.
        code_lists.append((trimmed_row - 128266).tolist())
    return code_lists[0] if code_lists else []
# Redistribute codes for audio generation
def redistribute_codes(code_list, snac_model):
    """Split a flat 7-codes-per-frame list into SNAC's three layers and decode.

    Within each frame of 7 codes, positions map to layers 1,2,3,3,2,3,3 and
    position i is rebased by subtracting i*4096 (each position has its own
    4096-entry codebook slice).

    Returns:
        1-D numpy array of audio samples, or None when there is nothing to
        decode.
    """
    if not code_list:
        return None
    device = next(snac_model.parameters()).device
    layer_1, layer_2, layer_3 = [], [], []
    # Target layer for each of the 7 positions within one frame.
    targets = (layer_1, layer_2, layer_3, layer_3, layer_2, layer_3, layer_3)
    total = len(code_list)
    for frame in range((total + 1) // 7):
        base = 7 * frame
        for offset, layer in enumerate(targets):
            idx = base + offset
            if idx < total:
                layer.append(code_list[idx] - offset * 4096)
    if not layer_1:
        return None
    codes = [
        torch.tensor(layer_1, device=device).unsqueeze(0),
        torch.tensor(layer_2, device=device).unsqueeze(0),
        torch.tensor(layer_3, device=device).unsqueeze(0)
    ]
    audio_hat = snac_model.decode(codes)
    return audio_hat.detach().squeeze().cpu().numpy()
# Simple character counter function (only called when needed)
def update_char_count(text):
    """Return a "Characters: N/150" label; N is 0 for empty or None input."""
    return f"Characters: {len(text) if text else 0}/150"
# Main generation function with rate limiting
@rate_limit
@spaces.GPU(duration=45)
def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200, voice="Elise", progress=gr.Progress()):
    """Generate Khmer speech audio from text.

    Pipeline: build prompt tokens -> sample speech tokens from the LM ->
    parse SNAC codes -> decode to audio.

    Returns:
        (24000, samples) suitable for a numpy-type gr.Audio, or None when
        the input is empty/rate-limited or no valid audio was produced.

    Raises:
        gr.Error: on any unexpected failure, so the UI shows the message.
    """
    if not text.strip():
        gr.Warning("Please enter some text to generate speech.")
        return None
    # Enforce the UI's 150-character budget server-side as well.
    if len(text) > 150:
        text = text[:150]
        gr.Warning("Text was truncated to 150 characters.")
    try:
        progress(0.1, "Processing text...")
        print(f"Generating speech for text: {text[:50]}...")
        input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
        progress(0.3, "Generating speech tokens...")
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=128258,  # end-of-speech token stops generation
                pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id else tokenizer.pad_token_id
            )
        progress(0.6, "Processing speech tokens...")
        code_list = parse_output(generated_ids)
        if not code_list:
            gr.Warning("Failed to generate valid audio codes.")
            return None
        progress(0.8, "Converting to audio...")
        audio_samples = redistribute_codes(code_list, snac_model)
        if audio_samples is None:
            gr.Warning("Failed to convert codes to audio.")
            return None
        print("Speech generation completed successfully")
        return (24000, audio_samples)  # SNAC 24 kHz sample rate
    except Exception as e:
        error_msg = f"Error generating speech: {str(e)}"
        print(error_msg)
        # BUG FIX: gr.Error is an exception class; merely constructing it
        # (the old `gr.Error(error_msg)` call) never displayed anything.
        # It must be raised for Gradio to surface the error in the UI.
        raise gr.Error(error_msg) from e
# Examples - reduced for quota management
# NOTE(review): the Khmer example strings below appear mojibake-garbled in
# this copy of the file (the original Khmer characters did not survive an
# encoding round-trip); restore them from the upstream source if possible.
examples = [
    ["ααααΆααα½α <laugh> αααα»αααααα Kiri α αΎααααα»αααΆ AI αααα’αΆα ααααααα’ααααααα ααΆαααααα"],
    ["αααα»αα’αΆα αααααΎαααααααα·ααΆααααααα ααΌα ααΆ <laugh> ααΎα α"],
    ["αααα·ααα·α αααα»αααΎαααααΆαα½αααααΆααααα αΆααααααα»ααααα½αα―αα <laugh> ααΆαα½αα²ααα’ααααααΎα ααΆααα"],
    ["αααα»ααααα αααα αΌα ααααΆααααααααΎαααα»αααααΏαααααααα₯αααα <chuckle> ααΆαααα‘αΆααα’ααα αΎαα"],
    ["αααααααα ααααΆαα ααααΎααΆαααααα½αααααα <sigh> α αααα αααααααααΆαα αΎαα"],
]
EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
# Create custom CSS
css = """
.gradio-container {
max-width: 1200px;
margin: auto;
padding-top: 1.5rem;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
}
.generate-btn {
background: linear-gradient(45deg, #FF6B6B, #4ECDC4) !important;
border: none !important;
color: white !important;
font-weight: bold !important;
}
.clear-btn {
background: linear-gradient(45deg, #95A5A6, #BDC3C7) !important;
border: none !important;
color: white !important;
}
.char-counter {
font-size: 12px;
color: #666;
text-align: right;
margin-top: 5px;
}
"""
# Create Gradio interface
# NOTE(review): Khmer text in labels/placeholders below appears
# mojibake-garbled in this copy of the file — verify against upstream.
with gr.Blocks(title="Khmer Text-to-Speech", css=css, theme=gr.themes.Soft()) as demo:
    # Page header with usage tips, including the supported emotive tags.
    gr.Markdown(f"""
<div class="main-header">
# π΅ Khmer Text-to-Speech
**αααΌαααααααααα’αααααααΆααααα**
αααα
αΌαα’ααααααααααααααα’ααα α αΎαααααΆααααΆααααααααα
ααΆααααααα·ααΆαα
π‘ **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
</div>
""")
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Enter Khmer text (αααα αΌαα’αααααααααα) - Max 150 characters",
                placeholder="αααα αΌαα’ααααααααααααααα’ααααα ααΈααα... (α’αα·ααααΆ α‘α₯α αα½α’αααα)",
                lines=4,
                max_lines=6,
                interactive=True,
                max_length=150  # Built-in Gradio character limit
            )
            # Simple character counter shown as read-only text under the input.
            char_info = gr.Textbox(
                value="Characters: 0/150",
                interactive=False,
                show_label=False,
                container=False,
                elem_classes=["char-counter"]
            )
            # Advanced Settings: sampling knobs forwarded to model.generate.
            with gr.Accordion("π§ Advanced Settings", open=False):
                with gr.Row():
                    temperature = gr.Slider(
                        minimum=0.1, maximum=1.5, value=0.6, step=0.05,
                        label="Temperature",
                        info="Higher values create more expressive speech"
                    )
                    top_p = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                        label="Top P",
                        info="Nucleus sampling threshold"
                    )
                with gr.Row():
                    repetition_penalty = gr.Slider(
                        minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                        label="Repetition Penalty",
                        info="Higher values discourage repetitive patterns"
                    )
                    max_new_tokens = gr.Slider(
                        minimum=100, maximum=2000, value=1200, step=50,
                        label="Max Length",
                        info="Maximum length of generated audio"
                    )
            with gr.Row():
                submit_btn = gr.Button("π€ Generate Speech", variant="primary", size="lg", elem_classes=["generate-btn"])
                clear_btn = gr.Button("ποΈ Clear", size="lg", elem_classes=["clear-btn"])
        with gr.Column(scale=1):
            audio_output = gr.Audio(
                label="Generated Speech (αααααααααααααΎαα‘αΎα)",
                type="numpy",
                show_label=True,
                interactive=False
            )
    # Set up examples (NO GPU function calls — examples are not cached, so
    # clicking one only fills the textbox; the user presses Generate).
    gr.Examples(
        examples=examples,
        inputs=[text_input],
        cache_examples=False,
        label="π Example Texts (α’αααααααααΌ) - Click example then press Generate"
    )
    # Character counter - only updates when focus lost or generation clicked
    text_input.blur(
        fn=update_char_count,
        inputs=[text_input],
        outputs=[char_info]
    )
    # Set up event handlers. The lambda fans one click into two outputs:
    # the generated audio and the refreshed character count.
    submit_btn.click(
        fn=lambda text, temp, top_p, rep_pen, max_tok: [
            generate_speech(text, temp, top_p, rep_pen, max_tok),
            update_char_count(text)
        ],
        inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens],
        outputs=[audio_output, char_info],
        show_progress=True
    )
    clear_btn.click(
        fn=lambda: ("", None, "Characters: 0/150"),
        inputs=[],
        outputs=[text_input, audio_output, char_info]
    )
    # Add keyboard shortcut: Enter in the textbox triggers the same pipeline
    # as the Generate button.
    text_input.submit(
        fn=lambda text, temp, top_p, rep_pen, max_tok: [
            generate_speech(text, temp, top_p, rep_pen, max_tok),
            update_char_count(text)
        ],
        inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens],
        outputs=[audio_output, char_info],
        show_progress=True
    )
# Launch with embed-friendly optimizations
if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.queue(
        max_size=3,  # Small queue for embeds
        default_concurrency_limit=1  # One user at a time
    ).launch(
        server_name="0.0.0.0",  # listen on all interfaces (needed in Spaces/containers)
        server_port=7860,
        share=False,
        show_error=True,  # surface server-side errors in the UI
        ssr_mode=False,
        auth_message="Login to HuggingFace recommended for better GPU quota"
    )