import os
import time
from functools import wraps

# Import `spaces` before torch-dependent packages so ZeroGPU can manage CUDA
import spaces
from snac import SNAC
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download, login
from dotenv import load_dotenv

load_dotenv()

# Rate limiting
last_request_time = {}
REQUEST_COOLDOWN = 30  # seconds between requests


def rate_limit(func):
    """Reject calls that arrive within REQUEST_COOLDOWN seconds of the previous one."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        # All callers share one key, so the cooldown is global rather than per user
        user_id = "anonymous"
        current_time = time.time()

        if user_id in last_request_time:
            time_since_last = current_time - last_request_time[user_id]
            if time_since_last < REQUEST_COOLDOWN:
                remaining = int(REQUEST_COOLDOWN - time_since_last)
                gr.Warning(f"Please wait {remaining} seconds before the next request.")
                return None

        last_request_time[user_id] = current_time
        return func(*args, **kwargs)

    return wrapper


# Get HF token from environment variables
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading SNAC model...")
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac_model = snac_model.to(device)
print("SNAC model loaded successfully")

model_name = "mrrtmob/tts-khm-kore"
print(f"Downloading model files from {model_name}...")

# Download only the config, safetensors weights and tokenizer files (skip training artifacts)
snapshot_download(
    repo_id=model_name,
    token=hf_token,
    allow_patterns=[
        "config.json",
        "*.safetensors",
        "model.safetensors.index.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
    ],
    ignore_patterns=[
        "optimizer.pt",
        "pytorch_model.bin",
        "training_args.bin",
        "scheduler.pt",
    ],
)
print("Model files downloaded successfully")

print("Loading main model...")
# Load model and tokenizer with token
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    token=hf_token,
)
model = model.to(device)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=hf_token,
)
print(f"Khmer TTS model loaded successfully to {device}")


# Process text prompt
def process_prompt(prompt, voice, tokenizer, device):
    prompt = f"{voice}: {prompt}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human

    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)  # SOH SOT Text EOT EOH

    # No padding needed for single input
    attention_mask = torch.ones_like(modified_input_ids)

    return modified_input_ids.to(device), attention_mask.to(device)


# Extract SNAC audio codes from the generated token IDs
def parse_output(generated_ids):
    token_to_find = 128257       # start-of-speech marker
    token_to_remove = 128258     # end-of-speech marker
    audio_token_offset = 128266  # ID of the first audio-code token

    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)

    if len(token_indices[1]) > 0:
        # Keep only the tokens after the last start-of-speech marker
        last_occurrence_idx = token_indices[1][-1].item()
        cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
    else:
        # No speech was generated: the remaining tokens are prompt text, not audio codes
        return []

    code_lists = []
    for row in cropped_tensor:
        row = row[row != token_to_remove]
        # SNAC frames are 7 codes long; drop any incomplete trailing frame
        new_length = (row.size(0) // 7) * 7
        trimmed_row = row[:new_length]
        # Convert to plain Python ints relative to the first audio token
        code_lists.append((trimmed_row - audio_token_offset).tolist())

    return code_lists[0] if code_lists else []


# Split the flat code list into SNAC's three codebook layers and decode to audio.
# Each 7-code frame holds 1 layer-1 code, 2 layer-2 codes and 4 layer-3 codes;
# the code at frame position p carries an offset of p * 4096, removed here.
def redistribute_codes(code_list, snac_model):
    if not code_list:
        return None

    device = next(snac_model.parameters()).device

    layer_1 = []
    layer_2 = []
    layer_3 = []

    # parse_output() trims code_list to a multiple of 7, so every frame is complete
    for i in range(len(code_list) // 7):
        frame = code_list[7 * i:7 * i + 7]
        layer_1.append(frame[0])
        layer_2.append(frame[1] - 4096)
        layer_3.append(frame[2] - (2 * 4096))
        layer_3.append(frame[3] - (3 * 4096))
        layer_2.append(frame[4] - (4 * 4096))
        layer_3.append(frame[5] - (5 * 4096))
        layer_3.append(frame[6] - (6 * 4096))

    if not layer_1:
        return None

    codes = [
        torch.tensor(layer_1, device=device).unsqueeze(0),
        torch.tensor(layer_2, device=device).unsqueeze(0),
        torch.tensor(layer_3, device=device).unsqueeze(0),
    ]

    with torch.no_grad():
        audio_hat = snac_model.decode(codes)
    return audio_hat.detach().squeeze().cpu().numpy()


# Character counter label (updated on blur and after generation)
def update_char_count(text):
    """Return the character-count label; the text itself is not modified."""
    count = len(text) if text else 0
    return f"Characters: {count}/150"


# Main generation function with rate limiting
@rate_limit
@spaces.GPU(duration=45)
def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200, voice="Elise", progress=gr.Progress()):
    if not text.strip():
        gr.Warning("Please enter some text to generate speech.")
        return None

    # Check length and truncate if needed
    if len(text) > 150:
        text = text[:150]
        gr.Warning("Text was truncated to 150 characters.")

    try:
        progress(0.1, "Processing text...")
        print(f"Generating speech for text: {text[:50]}...")

        input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)

        progress(0.3, "Generating speech tokens...")
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=int(max_new_tokens),  # slider values may arrive as floats
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=128258,  # stop at the end-of-speech marker
                pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
            )

        progress(0.6, "Processing speech tokens...")
        code_list = parse_output(generated_ids)

        if not code_list:
            gr.Warning("Failed to generate valid audio codes.")
            return None

        progress(0.8, "Converting to audio...")
        audio_samples = redistribute_codes(code_list, snac_model)

        if audio_samples is None:
            gr.Warning("Failed to convert codes to audio.")
            return None

        print("Speech generation completed successfully")
        return (24000, audio_samples)  # SNAC 24 kHz sample rate

    except Exception as e:
        error_msg = f"Error generating speech: {str(e)}"
        print(error_msg)
        # gr.Error only reaches the UI when it is raised, not merely constructed
        raise gr.Error(error_msg) from e


# Examples - reduced for quota management
examples = [
    # "Hello, my name is Kiri, and I am an AI that can convert text into speech."
    ["ជំរាបសួរ ខ្ញុំឈ្មោះ Kiri ហើយខ្ញុំជា AI ដែលអាចបម្លែងអត្ថបទទៅជាសំលេង។"],
    # "I can produce various kinds of speech, such as laughing."
    ["ខ្ញុំអាចបង្កើតសំលេងនិយាយផ្សេងៗ ដូចជា សើច។"],
    # "Yesterday I saw a cat chasing its own tail. It was so funny."
    ["ម្សិលមិញ ខ្ញុំឃើញឆ្មាមួយក្បាលដេញចាប់កន្ទុយខ្លួនឯង។ វាគួរឲ្យអស់សំណើចណាស់។"],
    # "I was cooking and suddenly spilled spices all over the floor. It's a mess now."
    ["ខ្ញុំរៀបចំម្ហូប ស្រាប់តែធ្វើជ្រុះគ្រឿងទេសពេញឥដ្ឋ។ វាប្រឡាក់អស់ហើយ។"],
    # "I'm so tired today after working all day. I want to go home and rest."
    ["ថ្ងៃនេះហត់ណាស់ ធ្វើការពេញមួយថ្ងៃ។ ចង់ទៅផ្ទះសម្រាកហើយ។"],
]

EMOTIVE_TAGS = ["``", "``", "``", "``", "``", "``", "``", "``"]

# Create custom CSS
css = """
.gradio-container {
    max-width: 1200px;
    margin: auto;
    padding-top: 1.5rem;
}
.main-header {
    text-align: center;
    margin-bottom: 2rem;
}
.generate-btn {
    background: linear-gradient(45deg, #FF6B6B, #4ECDC4) !important;
    border: none !important;
    color: white !important;
    font-weight: bold !important;
}
.clear-btn {
    background: linear-gradient(45deg, #95A5A6, #BDC3C7) !important;
    border: none !important;
    color: white !important;
}
.char-counter {
    font-size: 12px;
    color: #666;
    text-align: right;
    margin-top: 5px;
}
"""

# Create Gradio interface
with gr.Blocks(title="Khmer Text-to-Speech", css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
    # 🎵 Khmer Text-to-Speech

    **ម៉ូដែលបម្លែងអត្ថបទជាសំលេង**

    បញ្ចូលអត្ថបទខ្មែររបស់អ្នក ហើយស្តាប់ការបម្លែងទៅជាសំលេងនិយាយ។

    💡 **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
    """)

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Enter Khmer text (បញ្ចូលអត្ថបទខ្មែរ) - Max 150 characters",
                placeholder="បញ្ចូលអត្ថបទខ្មែររបស់អ្នកនៅទីនេះ... (អតិបរមា ១៥០ តួអក្សរ)",
                lines=4,
                max_lines=6,
                interactive=True,
                max_length=150,  # Built-in Gradio character limit
            )

            # Simple character counter
            char_info = gr.Textbox(
                value="Characters: 0/150",
                interactive=False,
                show_label=False,
                container=False,
                elem_classes=["char-counter"],
            )

            # Advanced Settings
            with gr.Accordion("🔧 Advanced Settings", open=False):
                with gr.Row():
                    temperature = gr.Slider(
                        minimum=0.1, maximum=1.5, value=0.6, step=0.05,
                        label="Temperature",
                        info="Higher values create more expressive speech",
                    )
                    top_p = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                        label="Top P",
                        info="Nucleus sampling threshold",
                    )
                with gr.Row():
                    repetition_penalty = gr.Slider(
                        minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                        label="Repetition Penalty",
                        info="Higher values discourage repetitive patterns",
                    )
                    max_new_tokens = gr.Slider(
                        minimum=100, maximum=2000, value=1200, step=50,
                        label="Max Length",
                        info="Maximum number of speech tokens to generate",
                    )

            with gr.Row():
                submit_btn = gr.Button("🎤 Generate Speech", variant="primary", size="lg", elem_classes=["generate-btn"])
                clear_btn = gr.Button("🗑️ Clear", size="lg", elem_classes=["clear-btn"])

        with gr.Column(scale=1):
            audio_output = gr.Audio(
                label="Generated Speech (សំលេងដែលបង្កើតឡើង)",
                type="numpy",
                show_label=True,
                interactive=False,
            )

    # Set up examples (NO GPU function calls)
    gr.Examples(
        examples=examples,
        inputs=[text_input],
        cache_examples=False,
        label="📝 Example Texts (អត្ថបទគំរូ) - Click example then press Generate",
    )

    # Character counter - only updates when focus is lost or generation is triggered
    text_input.blur(
        fn=update_char_count,
        inputs=[text_input],
        outputs=[char_info],
    )

    # Generate audio and refresh the character counter in a single event
    def generate_and_count(text, temp, top_p_value, rep_pen, max_tok):
        return generate_speech(text, temp, top_p_value, rep_pen, max_tok), update_char_count(text)

    # Set up event handlers
    submit_btn.click(
        fn=generate_and_count,
        inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens],
        outputs=[audio_output, char_info],
        show_progress="full",
    )

    clear_btn.click(
        fn=lambda: ("", None, "Characters: 0/150"),
        inputs=[],
        outputs=[text_input, audio_output, char_info],
    )

    # Keyboard shortcut: submitting the textbox also generates speech
    text_input.submit(
        fn=generate_and_count,
        inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens],
        outputs=[audio_output, char_info],
        show_progress="full",
    )

# Launch with embed-friendly optimizations
if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.queue(
        max_size=3,                   # Small queue for embeds
        default_concurrency_limit=1,  # One user at a time
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        ssr_mode=False,
        # Only shown on a login page, i.e. when `auth` is also passed to launch()
        auth_message="Login to HuggingFace recommended for better GPU quota",
    )