# khmer-tts / app.py
# Last commit 1f2a815 (mrrtmob): Update generate_speech function to accept
# max_new_tokens as a parameter and adjust default slider value to 1200
import math
import os
import time
from functools import wraps

# Third-party imports keep their original order: `spaces` is imported ahead of
# torch. NOTE(review): presumably required by ZeroGPU - confirm before re-sorting.
import spaces
from snac import SNAC
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download, login
from dotenv import load_dotenv
load_dotenv()
# Rate limiting
# Maps a caller key to the monotonic timestamp of its last accepted request.
last_request_time = {}
REQUEST_COOLDOWN = 30  # seconds that must pass between accepted requests


def rate_limit(func):
    """Decorate ``func`` so it runs at most once per ``REQUEST_COOLDOWN`` seconds.

    All callers share the single key ``"anonymous"``, so the cooldown is
    global rather than per user. A call arriving during the cooldown shows a
    Gradio warning and returns ``None`` without invoking ``func``.

    Args:
        func: The callable to throttle.

    Returns:
        A wrapper with the same signature as ``func``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        user_id = "anonymous"
        # A monotonic clock is immune to wall-clock adjustments (NTP, DST).
        now = time.monotonic()
        previous = last_request_time.get(user_id)
        if previous is not None:
            elapsed = now - previous
            if elapsed < REQUEST_COOLDOWN:
                # Round up so the message never says "wait 0 seconds" while
                # the request is still being rejected.
                remaining = math.ceil(REQUEST_COOLDOWN - elapsed)
                gr.Warning(f"Please wait {remaining} seconds before next request.")
                return None
        # The timestamp is recorded before the call, so a failing call still
        # starts a new cooldown.
        last_request_time[user_id] = now
        return func(*args, **kwargs)

    return wrapper
# Read the Hugging Face token from the environment and log in when one is set
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading SNAC model...")
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
print("SNAC model loaded successfully")

model_name = "mrrtmob/tts-khm-kore"
print(f"Downloading model files from {model_name}...")

# Only config, safetensors weights and tokenizer files are fetched;
# optimizer/scheduler/training artifacts are skipped.
wanted_files = [
    "config.json",
    "*.safetensors",
    "model.safetensors.index.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "vocab.json",
    "merges.txt",
]
skipped_files = [
    "optimizer.pt",
    "pytorch_model.bin",
    "training_args.bin",
    "scheduler.pt",
]
snapshot_download(
    repo_id=model_name,
    token=hf_token,
    allow_patterns=wanted_files,
    ignore_patterns=skipped_files,
)
print("Model files downloaded successfully")

print("Loading main model...")
# Causal LM loaded in bfloat16 and moved to the selected device
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    token=hf_token,
).to(device)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
print(f"Khmer TTS model loaded successfully to {device}")
# Process text prompt
def process_prompt(prompt, voice, tokenizer, device):
    """Tokenize ``"<voice>: <prompt>"`` and frame it with the special token ids.

    Args:
        prompt: Text to be spoken.
        voice: Voice name prefixed to the prompt.
        tokenizer: Callable returning an object with an ``input_ids`` tensor
            when called with ``return_tensors="pt"``.
        device: Device the returned tensors are moved to.

    Returns:
        ``(input_ids, attention_mask)``, both on ``device``. The mask is all
        ones because a single, unpadded sequence is produced.
    """
    text_ids = tokenizer(f"{voice}: {prompt}", return_tensors="pt").input_ids

    # Start of human
    prefix = torch.tensor([[128259]], dtype=torch.int64)
    # End of text, End of human
    suffix = torch.tensor([[128009, 128260]], dtype=torch.int64)

    # Final layout: SOH, tokenized text, EOT, EOH
    framed_ids = torch.cat((prefix, text_ids, suffix), dim=1)

    # No padding needed for single input
    mask = torch.ones_like(framed_ids)
    return framed_ids.to(device), mask.to(device)
# Parse output tokens to audio
def parse_output(generated_ids):
    """Extract the audio codes of the first sequence from generated token ids.

    Steps:
      1. Keep only the columns after the last occurrence of token 128257
         (presumably the start-of-audio marker - confirm against the model's
         vocabulary); if it never occurs, keep everything.
      2. Drop every 128258 token (the id used as ``eos_token_id`` during
         generation).
      3. Trim each row to a multiple of 7 tokens.
      4. Subtract 128266 so the token ids become zero-based codes.

    Args:
        generated_ids: 2-D integer tensor of shape ``(batch, seq_len)``.

    Returns:
        A list of plain Python ints for the first row, or ``[]`` when there
        are no rows or fewer than 7 usable tokens.
    """
    token_to_find = 128257
    token_to_remove = 128258
    code_offset = 128266
    frame_size = 7

    match_positions = (generated_ids == token_to_find).nonzero(as_tuple=True)
    if len(match_positions[1]) > 0:
        last_occurrence_idx = match_positions[1][-1].item()
        cropped = generated_ids[:, last_occurrence_idx + 1:]
    else:
        cropped = generated_ids

    code_lists = []
    for row in cropped:
        kept = row[row != token_to_remove]
        usable = (kept.size(0) // frame_size) * frame_size
        # One vectorized subtraction and a single transfer to Python ints,
        # instead of building a list of 0-dim tensors element by element.
        code_lists.append((kept[:usable] - code_offset).tolist())

    return code_lists[0] if code_lists else []
# Redistribute codes for audio generation
_CODES_PER_FRAME = 7   # each frame contributes 1 + 2 + 4 codes to the three layers
_CODEBOOK_SIZE = 4096  # per-slot offset; after removal every code must be in [0, 4096)


def redistribute_codes(code_list, snac_model):
    """Split a flat code list into the three SNAC layers and decode to audio.

    Each complete frame of 7 codes ``c0..c6`` is distributed as:
      * layer 1: ``c0``
      * layer 2: ``c1 - 4096``, ``c4 - 4*4096``
      * layer 3: ``c2 - 2*4096``, ``c3 - 3*4096``, ``c5 - 5*4096``, ``c6 - 6*4096``

    Only complete frames are used. A trailing partial frame would give the
    layers inconsistent lengths, so it is dropped.

    Args:
        code_list: Sequence of integer codes (Python ints or 0-dim tensors).
        snac_model: Model exposing ``parameters()`` and ``decode(codes)``.

    Returns:
        The decoded audio as a NumPy array, or ``None`` when there is no
        complete frame or any code falls outside its codebook range.
    """
    if not code_list:
        return None

    values = [int(code) for code in code_list]
    usable = (len(values) // _CODES_PER_FRAME) * _CODES_PER_FRAME
    if usable == 0:
        return None

    device = next(snac_model.parameters()).device
    frames = torch.tensor(values[:usable], dtype=torch.int64, device=device)
    frames = frames.view(-1, _CODES_PER_FRAME)
    # Remove the per-slot offset (slot k was shifted by k * 4096).
    frames = frames - torch.arange(_CODES_PER_FRAME, device=device) * _CODEBOOK_SIZE

    # Reject out-of-range codes here instead of letting them reach the
    # decoder's codebook lookup.
    if bool(((frames < 0) | (frames >= _CODEBOOK_SIZE)).any()):
        return None

    codes = [
        frames[:, 0].unsqueeze(0),
        frames[:, [1, 4]].reshape(1, -1),
        frames[:, [2, 3, 5, 6]].reshape(1, -1),
    ]

    # Decoding is inference only; skip building an autograd graph.
    with torch.no_grad():
        audio_hat = snac_model.decode(codes)
    return audio_hat.detach().squeeze().cpu().numpy()
# Simple character counter function (only called when needed)
def update_char_count(text):
    """Return the ``"Characters: N/150"`` label for ``text``.

    The text itself is never modified; empty or ``None`` input counts as 0.
    """
    if text:
        return f"Characters: {len(text)}/150"
    return "Characters: 0/150"
# Main generation function with rate limiting
@rate_limit
@spaces.GPU(duration=45)
def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200, voice="Elise", progress=gr.Progress()):
    """Generate speech audio for ``text``.

    Args:
        text: Input text; anything beyond 150 characters is truncated.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus sampling threshold passed to ``model.generate``.
        repetition_penalty: Repetition penalty passed to ``model.generate``.
        max_new_tokens: Upper bound on the number of generated tokens.
        voice: Voice name prefixed to the prompt.
        progress: Gradio progress tracker.

    Returns:
        ``(24000, audio_samples)`` on success, or ``None`` when the input is
        empty or no audio could be produced (a Gradio warning is shown).

    Raises:
        gr.Error: If generation or decoding raises; the original exception is
            chained as the cause.
    """
    # Guard against None as well as blank input.
    if not text or not text.strip():
        gr.Warning("Please enter some text to generate speech.")
        return None

    # Check length and truncate if needed
    if len(text) > 150:
        text = text[:150]
        gr.Warning("Text was truncated to 150 characters.")

    try:
        progress(0.1, "Processing text...")
        print(f"Generating speech for text: {text[:50]}...")
        input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)

        # Fall back to pad_token_id only when eos_token_id is truly missing;
        # a truthiness test would wrongly discard a valid id of 0.
        pad_token_id = tokenizer.eos_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.pad_token_id

        progress(0.3, "Generating speech tokens...")
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                # Guarantee an integer even if the UI hands over a float.
                max_new_tokens=int(max_new_tokens),
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=128258,
                pad_token_id=pad_token_id,
            )

        progress(0.6, "Processing speech tokens...")
        code_list = parse_output(generated_ids)
        if not code_list:
            gr.Warning("Failed to generate valid audio codes.")
            return None

        progress(0.8, "Converting to audio...")
        audio_samples = redistribute_codes(code_list, snac_model)
        if audio_samples is None:
            gr.Warning("Failed to convert codes to audio.")
            return None

        print("Speech generation completed successfully")
        return (24000, audio_samples)
    except Exception as e:
        error_msg = f"Error generating speech: {str(e)}"
        print(error_msg)
        # gr.Error is an exception class: it only reaches the user when it is
        # raised, so merely constructing it did nothing.
        raise gr.Error(error_msg) from e
# Examples - reduced for quota management
examples = [
["αž‡αŸ†αžšαžΆαž”αžŸαž½αžš <laugh> αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡ Kiri αž αžΎαž™αžαŸ’αž‰αž»αŸ†αž‡αžΆ AI αžŠαŸ‚αž›αž’αžΆαž…αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αŸ”"],
["αžαŸ’αž‰αž»αŸ†αž’αžΆαž…αž”αž„αŸ’αž€αžΎαžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αž•αŸ’αžŸαŸαž„αŸ— αžŠαžΌαž…αž‡αžΆ <laugh> αžŸαžΎαž…αŸ”"],
["αž˜αŸ’αžŸαž·αž›αž˜αž·αž‰ αžαŸ’αž‰αž»αŸ†αžƒαžΎαž‰αž†αŸ’αž˜αžΆαž˜αž½αž™αž€αŸ’αž”αžΆαž›αžŠαŸαž‰αž…αžΆαž”αŸ‹αž€αž“αŸ’αž‘αž»αž™αžαŸ’αž›αž½αž“αž―αž„αŸ” <laugh> αžœαžΆαž‚αž½αžšαž²αŸ’αž™αž’αžŸαŸ‹αžŸαŸ†αžŽαžΎαž…αžŽαžΆαžŸαŸ‹αŸ”"],
["αžαŸ’αž‰αž»αŸ†αžšαŸ€αž”αž…αŸ†αž˜αŸ’αž αžΌαž” αžŸαŸ’αžšαžΆαž”αŸ‹αžαŸ‚αž’αŸ’αžœαžΎαž‡αŸ’αžšαž»αŸ‡αž‚αŸ’αžšαžΏαž„αž‘αŸαžŸαž–αŸαž‰αž₯αžŠαŸ’αž‹αŸ” <chuckle> αžœαžΆαž”αŸ’αžšαž‘αžΆαž€αŸ‹αž’αžŸαŸ‹αž αžΎαž™αŸ”"],
["αžαŸ’αž„αŸƒαž“αŸαŸ‡αž αžαŸ‹αžŽαžΆαžŸαŸ‹ αž’αŸ’αžœαžΎαž€αžΆαžšαž–αŸαž‰αž˜αž½αž™αžαŸ’αž„αŸƒαŸ” <sigh> αž…αž„αŸ‹αž‘αŸ…αž•αŸ’αž‘αŸ‡αžŸαž˜αŸ’αžšαžΆαž€αž αžΎαž™αŸ”"],
]
# Emotive tags, already wrapped in backticks, that are joined into the
# Markdown "Tips" line of the page header.
EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
# Create custom CSS
# Stylesheet passed to gr.Blocks(css=...). The .generate-btn, .clear-btn and
# .char-counter classes are attached to components through elem_classes;
# .main-header is used by the <div> in the header Markdown.
css = """
.gradio-container {
max-width: 1200px;
margin: auto;
padding-top: 1.5rem;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
}
.generate-btn {
background: linear-gradient(45deg, #FF6B6B, #4ECDC4) !important;
border: none !important;
color: white !important;
font-weight: bold !important;
}
.clear-btn {
background: linear-gradient(45deg, #95A5A6, #BDC3C7) !important;
border: none !important;
color: white !important;
}
.char-counter {
font-size: 12px;
color: #666;
text-align: right;
margin-top: 5px;
}
"""
# Create Gradio interface
with gr.Blocks(title="Khmer Text-to-Speech", css=css, theme=gr.themes.Soft()) as demo:
gr.Markdown(f"""
<div class="main-header">
# 🎡 Khmer Text-to-Speech
**αž˜αŸ‰αžΌαžŠαŸ‚αž›αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‡αžΆαžŸαŸ†αž›αŸαž„**
αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€ αž αžΎαž™αžŸαŸ’αžαžΆαž”αŸ‹αž€αžΆαžšαž”αž˜αŸ’αž›αŸ‚αž„αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”
πŸ’‘ **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
</div>
""")
with gr.Row():
with gr.Column(scale=2):
text_input = gr.Textbox(
label="Enter Khmer text (αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžš) - Max 150 characters",
placeholder="αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€αž“αŸ…αž‘αžΈαž“αŸαŸ‡... (αž’αžαž·αž”αžšαž˜αžΆ ៑αŸ₯០ αžαž½αž’αž€αŸ’αžŸαžš)",
lines=4,
max_lines=6,
interactive=True,
max_length=150 # Built-in Gradio character limit
)
# Simple character counter
char_info = gr.Textbox(
value="Characters: 0/150",
interactive=False,
show_label=False,
container=False,
elem_classes=["char-counter"]
)
# Advanced Settings
with gr.Accordion("πŸ”§ Advanced Settings", open=False):
with gr.Row():
temperature = gr.Slider(
minimum=0.1, maximum=1.5, value=0.6, step=0.05,
label="Temperature",
info="Higher values create more expressive speech"
)
top_p = gr.Slider(
minimum=0.1, maximum=1.0, value=0.95, step=0.05,
label="Top P",
info="Nucleus sampling threshold"
)
with gr.Row():
repetition_penalty = gr.Slider(
minimum=1.0, maximum=2.0, value=1.1, step=0.05,
label="Repetition Penalty",
info="Higher values discourage repetitive patterns"
)
max_new_tokens = gr.Slider(
minimum=100, maximum=2000, value=1200, step=50,
label="Max Length",
info="Maximum length of generated audio"
)
with gr.Row():
submit_btn = gr.Button("🎀 Generate Speech", variant="primary", size="lg", elem_classes=["generate-btn"])
clear_btn = gr.Button("πŸ—‘οΈ Clear", size="lg", elem_classes=["clear-btn"])
with gr.Column(scale=1):
audio_output = gr.Audio(
label="Generated Speech (αžŸαŸ†αž›αŸαž„αžŠαŸ‚αž›αž”αž„αŸ’αž€αžΎαžαž‘αžΎαž„)",
type="numpy",
show_label=True,
interactive=False
)
# Set up examples (NO GPU function calls)
gr.Examples(
examples=examples,
inputs=[text_input],
cache_examples=False,
label="πŸ“ Example Texts (αž’αžαŸ’αžαž”αž‘αž‚αŸ†αžšαžΌ) - Click example then press Generate"
)
# Character counter - only updates when focus lost or generation clicked
text_input.blur(fn=update_char_count, inputs=[text_input], outputs=[char_info])

# Components shared by both generation triggers (button click and Enter key)
generation_inputs = [text_input, temperature, top_p, repetition_penalty, max_new_tokens]
generation_outputs = [audio_output, char_info]

# Set up event handlers
# Handlers stay as lambdas and keep their registration order, so Gradio's
# automatically derived endpoint names are unaffected.
submit_btn.click(
    fn=lambda text, temp, nucleus_p, rep_pen, max_tok: [
        generate_speech(text, temp, nucleus_p, rep_pen, max_tok),
        update_char_count(text),
    ],
    inputs=generation_inputs,
    outputs=generation_outputs,
    show_progress=True,
)

# Reset the textbox, the audio player and the counter label
clear_btn.click(
    fn=lambda: ("", None, "Characters: 0/150"),
    inputs=[],
    outputs=[text_input, audio_output, char_info],
)

# Add keyboard shortcut
text_input.submit(
    fn=lambda text, temp, nucleus_p, rep_pen, max_tok: [
        generate_speech(text, temp, nucleus_p, rep_pen, max_tok),
        update_char_count(text),
    ],
    inputs=generation_inputs,
    outputs=generation_outputs,
    show_progress=True,
)
# Launch with embed-friendly optimizations
if __name__ == "__main__":
    print("Starting Gradio interface...")

    # Small queue for embeds; one user at a time
    queued_demo = demo.queue(max_size=3, default_concurrency_limit=1)

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "ssr_mode": False,
        "auth_message": "Login to HuggingFace recommended for better GPU quota",
    }
    queued_demo.launch(**launch_options)