File size: 14,164 Bytes
5193c5e
3d2ce0a
 
8909289
 
4b72318
8909289
ecf63b4
3d2ce0a
844e3a3
 
5193c5e
3d2ce0a
 
80bd4c9
3d2ce0a
 
 
 
 
 
 
 
 
 
 
80bd4c9
3d2ce0a
 
 
 
 
 
5193c5e
 
 
 
c0c6352
844e3a3
ecf63b4
5193c5e
8bdeb28
 
 
80bd4c9
5193c5e
8bdeb28
80bd4c9
814bea6
80bd4c9
8bdeb28
 
c0c6352
8bdeb28
 
 
 
c0c6352
 
 
 
 
8bdeb28
 
 
 
 
c0c6352
8bdeb28
 
80bd4c9
5193c5e
80bd4c9
 
5193c5e
 
 
731d214
5193c5e
e173776
5193c5e
80bd4c9
5193c5e
 
c0c6352
5193c5e
c0c6352
5193c5e
8bdeb28
844e3a3
d1e3c74
844e3a3
80bd4c9
 
 
814bea6
80bd4c9
d1e3c74
844e3a3
5193c5e
80bd4c9
844e3a3
ecf63b4
 
8bdeb28
814bea6
844e3a3
 
 
ecf63b4
844e3a3
814bea6
844e3a3
 
 
 
814bea6
844e3a3
 
 
 
 
8bdeb28
844e3a3
814bea6
e173776
5193c5e
80bd4c9
8bdeb28
c0c6352
 
 
e173776
844e3a3
 
 
c0c6352
8bdeb28
c0c6352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bdeb28
 
 
 
 
 
e173776
5193c5e
54096db
 
 
 
 
3d2ce0a
80bd4c9
3d2ce0a
80bd4c9
1f2a815
8bdeb28
3d2ce0a
8bdeb28
c0c6352
54096db
 
 
 
3d2ce0a
4b72318
8bdeb28
54096db
3d2ce0a
54096db
c0c6352
8bdeb28
844e3a3
80bd4c9
ecf63b4
 
 
 
 
 
8909289
ecf63b4
 
c0c6352
ecf63b4
c0c6352
8bdeb28
844e3a3
c0c6352
 
3d2ce0a
c0c6352
 
8bdeb28
80bd4c9
c0c6352
 
3d2ce0a
c0c6352
 
80bd4c9
e173776
c0c6352
63d778a
3d2ce0a
 
 
63d778a
5193c5e
80bd4c9
8909289
54096db
25f78d4
c0c6352
 
 
8909289
5193c5e
844e3a3
5193c5e
54096db
3d2ce0a
 
 
 
 
 
80bd4c9
 
 
 
 
 
 
 
 
 
 
 
 
 
3d2ce0a
 
 
 
 
 
 
 
 
8bdeb28
3d2ce0a
844e3a3
80bd4c9
 
844e3a3
 
c0c6352
844e3a3
c0c6352
844e3a3
3d2ce0a
 
80bd4c9
c0c6352
 
 
 
814bea6
 
80bd4c9
 
c6ae943
54096db
844e3a3
c0c6352
54096db
 
 
 
 
 
 
 
3d2ce0a
80bd4c9
c0c6352
 
 
 
80bd4c9
 
c0c6352
 
 
80bd4c9
 
c0c6352
 
 
 
80bd4c9
 
c0c6352
 
fb01dbf
80bd4c9
 
c0c6352
 
 
80bd4c9
 
c0c6352
 
 
 
 
 
 
844e3a3
c0c6352
814bea6
8909289
 
 
3d2ce0a
814bea6
3d2ce0a
 
54096db
 
 
 
 
 
c0c6352
54096db
844e3a3
54096db
 
 
 
8bdeb28
54096db
c0c6352
8909289
c0c6352
8909289
54096db
844e3a3
54096db
3d2ce0a
 
80bd4c9
3d2ce0a
54096db
 
 
 
3d2ce0a
54096db
3d2ce0a
8909289
5193c5e
80bd4c9
8909289
80bd4c9
c0c6352
80bd4c9
 
c0c6352
3d2ce0a
 
c0c6352
 
80bd4c9
 
c0c6352
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
import os
import time
from functools import wraps
import spaces
from snac import SNAC
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download, login
from dotenv import load_dotenv
load_dotenv()

# Rate limiting state: maps a client key to the wall-clock time of its last
# accepted request. Only one key ("anonymous") is ever used because the app
# has no per-user identity, so the cooldown is effectively global.
last_request_time = {}
# Minimum number of seconds that must elapse between accepted requests.
REQUEST_COOLDOWN = 30

def rate_limit(func):
    """Decorator that enforces a global cooldown between calls to *func*.

    Every caller shares a single bucket ("anonymous") since the Space has
    no per-user identity. A call made inside the cooldown window shows a
    Gradio warning and returns None without invoking *func*.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        bucket = "anonymous"
        now = time.time()

        previous = last_request_time.get(bucket)
        if previous is not None:
            elapsed = now - previous
            if elapsed < REQUEST_COOLDOWN:
                remaining = int(REQUEST_COOLDOWN - elapsed)
                gr.Warning(f"Please wait {remaining} seconds before next request.")
                return None

        last_request_time[bucket] = now
        return func(*args, **kwargs)
    return wrapper

# Get HF token from environment variables (loaded from .env above); log in
# so gated/private repos can be downloaded.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Check if CUDA is available; all models are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# SNAC is the neural audio codec that turns generated code tokens back into
# a waveform. The 24khz variant matches the (24000, samples) output below.
print("Loading SNAC model...")
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac_model = snac_model.to(device)
print("SNAC model loaded successfully")

model_name = "mrrtmob/tts-khm-kore"
print(f"Downloading model files from {model_name}...")

# Download only model config and safetensors with token.
# ignore_patterns skips training-only artifacts to save bandwidth/disk.
snapshot_download(
    repo_id=model_name,
    token=hf_token,
    allow_patterns=[
        "config.json",
        "*.safetensors",
        "model.safetensors.index.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt"
    ],
    ignore_patterns=[
        "optimizer.pt",
        "pytorch_model.bin",
        "training_args.bin",
        "scheduler.pt"
    ]
)
print("Model files downloaded successfully")

print("Loading main model...")
# Load model and tokenizer with token.
# bfloat16 halves memory vs float32 and is sufficient for inference.
model = AutoModelForCausalLM.from_pretrained(
    model_name, 
    torch_dtype=torch.bfloat16,
    token=hf_token
)
model = model.to(device)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=hf_token
)
print(f"Khmer TTS model loaded successfully to {device}")

# Process text prompt
def process_prompt(prompt, voice, tokenizer, device):
    """Tokenize a voice-prefixed prompt and frame it with special tokens.

    The text is prefixed with "<voice>: " then wrapped as
    SOH (128259) + text ids + EOT (128009) + EOH (128260).

    Returns:
        (input_ids, attention_mask) tensors moved to *device*.
    """
    tagged_text = f"{voice}: {prompt}"
    text_ids = tokenizer(tagged_text, return_tensors="pt").input_ids

    # 128259 = start-of-human; 128009 = end-of-text; 128260 = end-of-human.
    soh = torch.tensor([[128259]], dtype=torch.int64)
    eot_eoh = torch.tensor([[128009, 128260]], dtype=torch.int64)
    framed = torch.cat([soh, text_ids, eot_eoh], dim=1)

    # Single sequence -> no padding, so the mask is all ones.
    mask = torch.ones_like(framed)
    return framed.to(device), mask.to(device)

# Parse output tokens to audio
def parse_output(generated_ids):
    """Extract SNAC audio codes from raw generated token ids.

    Crops everything up to and including the last start-of-audio marker
    (128257), drops end-of-audio tokens (128258), trims to a whole number
    of 7-token SNAC frames, and rebases ids into codebook space by
    subtracting 128266.

    Returns:
        A list of scalar tensors for the first sequence, or [] when no
        sequence was produced.
    """
    START_AUDIO = 128257
    END_AUDIO = 128258
    CODE_OFFSET = 128266

    hits = (generated_ids == START_AUDIO).nonzero(as_tuple=True)
    if len(hits[1]) > 0:
        # Keep only tokens after the final audio marker.
        audio_part = generated_ids[:, hits[1][-1].item() + 1:]
    else:
        audio_part = generated_ids

    code_lists = []
    for seq in audio_part:
        kept = seq[seq != END_AUDIO]
        # Trim to a multiple of 7 so frames are complete.
        usable = (kept.size(0) // 7) * 7
        code_lists.append([tok - CODE_OFFSET for tok in kept[:usable]])

    return code_lists[0] if code_lists else []

# Redistribute codes for audio generation
def redistribute_codes(code_list, snac_model):
    """Re-pack a flat 7-token code stream into SNAC's three layers and decode.

    Each 7-token frame maps to layers as [L1, L2, L3, L3, L2, L3, L3],
    with the token in slot *k* rebased by subtracting k * 4096.

    Returns:
        A numpy waveform, or None when there is nothing decodable.
    """
    if not code_list:
        return None

    target = next(snac_model.parameters()).device
    coarse, medium, fine = [], [], []

    # (slot offset within a frame, destination layer), in original order.
    slot_layout = (
        (0, coarse), (1, medium), (2, fine), (3, fine),
        (4, medium), (5, fine), (6, fine),
    )
    total = len(code_list)
    for frame in range((total + 1) // 7):
        base = 7 * frame
        for slot, layer in slot_layout:
            idx = base + slot
            if idx < total:  # guard against a trailing partial frame
                layer.append(code_list[idx] - slot * 4096)

    if not coarse:
        return None

    codes = [
        torch.tensor(coarse, device=target).unsqueeze(0),
        torch.tensor(medium, device=target).unsqueeze(0),
        torch.tensor(fine, device=target).unsqueeze(0),
    ]
    waveform = snac_model.decode(codes)
    return waveform.detach().squeeze().cpu().numpy()

# Simple character counter function (only called when needed)
def update_char_count(text):
    """Return a 'Characters: N/150' label for the input counter widget."""
    used = 0 if not text else len(text)
    return f"Characters: {used}/150"

# Main generation function with rate limiting
@rate_limit
@spaces.GPU(duration=45)
def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200, voice="Elise", progress=gr.Progress()):
    """Generate Khmer speech audio from text.

    Full pipeline: prompt framing -> autoregressive token sampling with the
    LM -> SNAC decoding to a waveform. Rate-limited by the shared 30 s
    cooldown and allotted 45 s of ZeroGPU time per call.

    Args:
        text: Input text; truncated to 150 characters.
        temperature, top_p, repetition_penalty: Sampling controls passed
            straight through to ``model.generate``.
        max_new_tokens: Upper bound on generated audio tokens.
        voice: Speaker name prefixed to the prompt.
        progress: Gradio progress tracker (default-instance injection is
            Gradio's documented idiom).

    Returns:
        A ``(24000, samples)`` tuple for ``gr.Audio``, or None on any
        failure (warnings/errors are surfaced via the Gradio UI).
    """
    if not text.strip():
        gr.Warning("Please enter some text to generate speech.")
        return None
    
    # Check length and truncate if needed (mirrors the textbox max_length).
    if len(text) > 150:
        text = text[:150]
        gr.Warning("Text was truncated to 150 characters.")
    
    try:
        progress(0.1, "Processing text...")
        print(f"Generating speech for text: {text[:50]}...")
        
        input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
        
        progress(0.3, "Generating speech tokens...")
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=128258,  # end-of-audio token stops generation
                pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id else tokenizer.pad_token_id
            )
        
        progress(0.6, "Processing speech tokens...")
        code_list = parse_output(generated_ids)
        
        if not code_list:
            gr.Warning("Failed to generate valid audio codes.")
            return None
        
        progress(0.8, "Converting to audio...")
        audio_samples = redistribute_codes(code_list, snac_model)
        
        if audio_samples is None:
            gr.Warning("Failed to convert codes to audio.")
            return None
        
        print("Speech generation completed successfully")
        return (24000, audio_samples)  # 24 kHz matches the SNAC 24khz codec
        
    except Exception as e:
        # Broad catch is deliberate: any failure becomes a UI error rather
        # than a crashed worker.
        error_msg = f"Error generating speech: {str(e)}"
        print(error_msg)
        gr.Error(error_msg)
        return None

# Examples - reduced for quota management. Each entry is a one-element list
# matching the single `text_input` component; several include emotive tags.
examples = [
    ["αž‡αŸ†αžšαžΆαž”αžŸαž½αžš <laugh> αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡ Kiri αž αžΎαž™αžαŸ’αž‰αž»αŸ†αž‡αžΆ AI αžŠαŸ‚αž›αž’αžΆαž…αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αŸ”"],
    ["αžαŸ’αž‰αž»αŸ†αž’αžΆαž…αž”αž„αŸ’αž€αžΎαžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αž•αŸ’αžŸαŸαž„αŸ— αžŠαžΌαž…αž‡αžΆ <laugh> αžŸαžΎαž…αŸ”"],
    ["αž˜αŸ’αžŸαž·αž›αž˜αž·αž‰ αžαŸ’αž‰αž»αŸ†αžƒαžΎαž‰αž†αŸ’αž˜αžΆαž˜αž½αž™αž€αŸ’αž”αžΆαž›αžŠαŸαž‰αž…αžΆαž”αŸ‹αž€αž“αŸ’αž‘αž»αž™αžαŸ’αž›αž½αž“αž―αž„αŸ” <laugh> αžœαžΆαž‚αž½αžšαž²αŸ’αž™αž’αžŸαŸ‹αžŸαŸ†αžŽαžΎαž…αžŽαžΆαžŸαŸ‹αŸ”"],
    ["αžαŸ’αž‰αž»αŸ†αžšαŸ€αž”αž…αŸ†αž˜αŸ’αž αžΌαž” αžŸαŸ’αžšαžΆαž”αŸ‹αžαŸ‚αž’αŸ’αžœαžΎαž‡αŸ’αžšαž»αŸ‡αž‚αŸ’αžšαžΏαž„αž‘αŸαžŸαž–αŸαž‰αž₯αžŠαŸ’αž‹αŸ” <chuckle> αžœαžΆαž”αŸ’αžšαž‘αžΆαž€αŸ‹αž’αžŸαŸ‹αž αžΎαž™αŸ”"],
    ["αžαŸ’αž„αŸƒαž“αŸαŸ‡αž αžαŸ‹αžŽαžΆαžŸαŸ‹ αž’αŸ’αžœαžΎαž€αžΆαžšαž–αŸαž‰αž˜αž½αž™αžαŸ’αž„αŸƒαŸ” <sigh> αž…αž„αŸ‹αž‘αŸ…αž•αŸ’αž‘αŸ‡αžŸαž˜αŸ’αžšαžΆαž€αž αžΎαž™αŸ”"],
]

# Emotive tags the model understands, rendered as inline code in the header.
EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]

# Custom CSS injected into the Blocks app: page width/centering, gradient
# styling for the generate/clear buttons, and the small grey character
# counter under the textbox. (Comments are kept out of the string itself so
# the shipped stylesheet stays byte-identical.)
css = """
.gradio-container {
    max-width: 1200px;
    margin: auto;
    padding-top: 1.5rem;
}
.main-header {
    text-align: center;
    margin-bottom: 2rem;
}
.generate-btn {
    background: linear-gradient(45deg, #FF6B6B, #4ECDC4) !important;
    border: none !important;
    color: white !important;
    font-weight: bold !important;
}
.clear-btn {
    background: linear-gradient(45deg, #95A5A6, #BDC3C7) !important;
    border: none !important;
    color: white !important;
}
.char-counter {
    font-size: 12px;
    color: #666;
    text-align: right;
    margin-top: 5px;
}
"""

# Create Gradio interface. Layout: header markdown, then a two-column row
# (input/settings on the left, audio output on the right), examples, and
# event wiring at the bottom.
with gr.Blocks(title="Khmer Text-to-Speech", css=css, theme=gr.themes.Soft()) as demo:
    # Header with title, Khmer subtitle, and the supported emotive tags.
    gr.Markdown(f"""
    <div class="main-header">
    
    # 🎡 Khmer Text-to-Speech
    **αž˜αŸ‰αžΌαžŠαŸ‚αž›αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‡αžΆαžŸαŸ†αž›αŸαž„**
    
    αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€ αž αžΎαž™αžŸαŸ’αžαžΆαž”αŸ‹αž€αžΆαžšαž”αž˜αŸ’αž›αŸ‚αž„αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”
    
    πŸ’‘ **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
    
    </div>
    """)    
    
    with gr.Row():
        # Left column: text input, character counter, sampling controls,
        # and the action buttons.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Enter Khmer text (αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžš) - Max 150 characters",
                placeholder="αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€αž“αŸ…αž‘αžΈαž“αŸαŸ‡... (αž’αžαž·αž”αžšαž˜αžΆ ៑αŸ₯០ αžαž½αž’αž€αŸ’αžŸαžš)",
                lines=4,
                max_lines=6,
                interactive=True,
                max_length=150  # Built-in Gradio character limit
            )
            
            # Simple character counter (read-only label updated by events below)
            char_info = gr.Textbox(
                value="Characters: 0/150",
                interactive=False,
                show_label=False,
                container=False,
                elem_classes=["char-counter"]
            )
            
            # Advanced Settings: sampling parameters forwarded to generate_speech
            with gr.Accordion("πŸ”§ Advanced Settings", open=False):
                with gr.Row():
                    temperature = gr.Slider(
                        minimum=0.1, maximum=1.5, value=0.6, step=0.05,
                        label="Temperature",
                        info="Higher values create more expressive speech"
                    )
                    top_p = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                        label="Top P",
                        info="Nucleus sampling threshold"
                    )
                with gr.Row():
                    repetition_penalty = gr.Slider(
                        minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                        label="Repetition Penalty",
                        info="Higher values discourage repetitive patterns"
                    )
                    max_new_tokens = gr.Slider(
                        minimum=100, maximum=2000, value=1200, step=50,
                        label="Max Length",
                        info="Maximum length of generated audio"
                    )
            
            with gr.Row():
                submit_btn = gr.Button("🎀 Generate Speech", variant="primary", size="lg", elem_classes=["generate-btn"])
                clear_btn = gr.Button("πŸ—‘οΈ Clear", size="lg", elem_classes=["clear-btn"])
        
        # Right column: generated audio player.
        with gr.Column(scale=1):
            audio_output = gr.Audio(
                label="Generated Speech (αžŸαŸ†αž›αŸαž„αžŠαŸ‚αž›αž”αž„αŸ’αž€αžΎαžαž‘αžΎαž„)",
                type="numpy",
                show_label=True,
                interactive=False
            )
    
    # Set up examples (NO GPU function calls: cache_examples=False means
    # clicking an example only fills the textbox).
    gr.Examples(
        examples=examples,
        inputs=[text_input],
        cache_examples=False,
        label="πŸ“ Example Texts (αž’αžαŸ’αžαž”αž‘αž‚αŸ†αžšαžΌ) - Click example then press Generate"
    )
    
    # Character counter - only updates when focus lost or generation clicked
    # (avoids firing a server event on every keystroke).
    text_input.blur(
        fn=update_char_count,
        inputs=[text_input],
        outputs=[char_info]
    )
    
    # Set up event handlers: button click generates audio and refreshes the
    # counter in a single round trip.
    submit_btn.click(
        fn=lambda text, temp, top_p, rep_pen, max_tok: [
            generate_speech(text, temp, top_p, rep_pen, max_tok),
            update_char_count(text)
        ],
        inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens],
        outputs=[audio_output, char_info],
        show_progress=True
    )
    
    # Reset input, audio, and counter in one shot.
    clear_btn.click(
        fn=lambda: ("", None, "Characters: 0/150"),
        inputs=[],
        outputs=[text_input, audio_output, char_info]
    )
    
    # Add keyboard shortcut: Enter in the textbox mirrors the Generate button.
    text_input.submit(
        fn=lambda text, temp, top_p, rep_pen, max_tok: [
            generate_speech(text, temp, top_p, rep_pen, max_tok),
            update_char_count(text)
        ],
        inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens],
        outputs=[audio_output, char_info],
        show_progress=True
    )

# Launch with embed-friendly optimizations: a tiny request queue and a
# single concurrent worker keep the (rate-limited) GPU path serialized.
if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.queue(
        max_size=3,  # Small queue for embeds
        default_concurrency_limit=1  # One user at a time
    ).launch(
        server_name="0.0.0.0",  # listen on all interfaces (container/Spaces)
        server_port=7860,  # standard HF Spaces port
        share=False,
        show_error=True,  # surface server errors in the UI
        ssr_mode=False,
        auth_message="Login to HuggingFace recommended for better GPU quota"
    )