mrrtmob committed
Commit 63d778a · 1 Parent(s): 02daaea
Files changed (1):
  1. app.py +210 -855
app.py CHANGED
@@ -1,876 +1,231 @@
- import gradio as gr
  import torch
- import numpy as np
- import os
- import locale
-
- # Set UTF-8 encoding
- locale.getpreferredencoding = lambda: "UTF-8"
-
- # Try different import methods for unsloth
- try:
-     from unsloth import FastLanguageModel
-     UNSLOTH_AVAILABLE = True
- except ImportError:
-     try:
-         # Fallback import
-         import unsloth
-         from unsloth import FastLanguageModel
-         UNSLOTH_AVAILABLE = True
-     except ImportError:
-         print("Warning: Unsloth not available, using transformers fallback")
-         from transformers import AutoModelForCausalLM, AutoTokenizer
-         UNSLOTH_AVAILABLE = False
-
- # Import SNAC
- try:
-     from snac import SNAC
-     SNAC_AVAILABLE = True
- except ImportError:
-     print("Error: SNAC not available")
-     SNAC_AVAILABLE = False
-
- class TTSKhmerModel:
-     def __init__(self):
-         self.model = None
-         self.tokenizer = None
-         self.snac_model = None
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-         self.current_model = None
-         print(f"Using device: {self.device}")
-         print(f"Unsloth available: {UNSLOTH_AVAILABLE}")
-         print(f"SNAC available: {SNAC_AVAILABLE}")
-
-     def load_models(self, model_name="mrrtmob/tts-khm"):
-         """Load the TTS model and SNAC model"""
-         try:
-             if not SNAC_AVAILABLE:
-                 return False, "SNAC model not available"
-
-             # Check if we need to reload the model
-             if self.current_model != model_name:
-                 print(f"Loading TTS model: {model_name}...")
-
-                 if UNSLOTH_AVAILABLE:
-                     # Use unsloth
-                     self.model, self.tokenizer = FastLanguageModel.from_pretrained(
-                         model_name=model_name,
-                         max_seq_length=2048,
-                         dtype=None,
-                         load_in_4bit=False if self.device == "cuda" else True,
-                     )
-                     # Enable inference mode
-                     FastLanguageModel.for_inference(self.model)
-                 else:
-                     # Fallback to transformers
-                     self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-                     self.model = AutoModelForCausalLM.from_pretrained(
-                         model_name,
-                         torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
-                         device_map="auto" if self.device == "cuda" else None
-                     )
-
-                 self.current_model = model_name
-                 print(f"TTS model '{model_name}' loaded successfully!")
-
-             # Load SNAC model if not already loaded
-             if self.snac_model is None:
-                 print("Loading SNAC model...")
-                 self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
-                 # Keep SNAC on CPU to save VRAM
-                 self.snac_model = self.snac_model.to("cpu")
-                 print("SNAC model loaded successfully!")
-
-             return True, f"Model '{model_name}' loaded successfully"
-
-         except Exception as e:
-             error_msg = f"Error loading model '{model_name}': {e}"
-             print(error_msg)
-             return False, error_msg
-
-     def redistribute_codes(self, code_list):
-         """Convert code list to audio using SNAC decoder"""
-         layer_1 = []
-         layer_2 = []
-         layer_3 = []
-
-         for i in range((len(code_list)+1)//7):
-             if 7*i < len(code_list):
-                 layer_1.append(code_list[7*i])
-             if 7*i+1 < len(code_list):
-                 layer_2.append(code_list[7*i+1]-4096)
-             if 7*i+2 < len(code_list):
-                 layer_3.append(code_list[7*i+2]-(2*4096))
-             if 7*i+3 < len(code_list):
-                 layer_3.append(code_list[7*i+3]-(3*4096))
-             if 7*i+4 < len(code_list):
-                 layer_2.append(code_list[7*i+4]-(4*4096))
-             if 7*i+5 < len(code_list):
-                 layer_3.append(code_list[7*i+5]-(5*4096))
-             if 7*i+6 < len(code_list):
-                 layer_3.append(code_list[7*i+6]-(6*4096))
-
-         codes = [
-             torch.tensor(layer_1).unsqueeze(0),
-             torch.tensor(layer_2).unsqueeze(0),
-             torch.tensor(layer_3).unsqueeze(0)
-         ]
-
-         # Move SNAC to GPU temporarily for decoding if available
-         if self.device == "cuda":
-             self.snac_model = self.snac_model.to("cuda")
-             codes = [c.to("cuda") for c in codes]
-
-         # Decode audio
-         with torch.no_grad():
-             audio_hat = self.snac_model.decode(codes)
-
-         # Move back to CPU to save memory
-         if self.device == "cuda":
-             audio_hat = audio_hat.cpu()
-             self.snac_model = self.snac_model.to("cpu")
-             torch.cuda.empty_cache()
-
-         return audio_hat
-
-     def generate_speech(self, text, voice="Elise", temperature=0.6, top_p=0.95):
-         """Generate speech from text"""
-         if not self.model or not self.tokenizer or not self.snac_model:
-             return None, "Models not loaded properly"
-
-         try:
-             # Prepare prompt
-             prompt = f"{voice}: {text}" if voice else text
-
-             # Tokenize
-             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
-
-             # Add special tokens
-             start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
-             end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
-
-             # Combine tokens
-             modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
-
-             # Create attention mask
-             attention_mask = torch.ones_like(modified_input_ids)
-
-             # Move to device
-             input_ids = modified_input_ids.to(self.device)
-             attention_mask = attention_mask.to(self.device)
-
-             # Generate
-             with torch.no_grad():
-                 generated_ids = self.model.generate(
-                     input_ids=input_ids,
-                     attention_mask=attention_mask,
-                     max_new_tokens=1200,
-                     do_sample=True,
-                     temperature=temperature,
-                     top_p=top_p,
-                     repetition_penalty=1.1,
-                     num_return_sequences=1,
-                     eos_token_id=128258,
-                     use_cache=True,
-                     pad_token_id=self.tokenizer.eos_token_id
-                 )
-
-             # Clear GPU cache
-             if self.device == "cuda":
-                 torch.cuda.empty_cache()
-
-             # Process generated tokens
-             token_to_find = 128257
-             token_to_remove = 128258
-
-             # Find last occurrence of token_to_find
-             token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
-             if len(token_indices[1]) > 0:
-                 last_occurrence_idx = token_indices[1][-1].item()
-                 cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
-             else:
-                 cropped_tensor = generated_ids
-
-             # Remove unwanted tokens
-             row = cropped_tensor[0]
-             row = row[row != token_to_remove]
-
-             # Process codes
-             row_length = row.size(0)
-             new_length = (row_length // 7) * 7
-             trimmed_row = row[:new_length]
-             code_list = [t.item() - 128266 for t in trimmed_row]
-
-             if len(code_list) == 0:
-                 return None, "No valid audio tokens generated"
-
-             # Generate audio
-             audio_tensor = self.redistribute_codes(code_list)
-             audio_array = audio_tensor.detach().squeeze().cpu().numpy()
-
-             # Convert to proper format for Gradio
-             sample_rate = 24000
-             return (sample_rate, audio_array), "✅ Speech generated successfully!"
-
-         except Exception as e:
-             return None, f"❌ Error generating speech: {str(e)}"
-
- # Initialize the model
- tts_model = TTSKhmerModel()
-
- def initialize_models(model_name):
-     """Initialize models on startup"""
-     print("Initializing models...")
-     success, message = tts_model.load_models(model_name)
-     gpu_info = f"GPU available: {torch.cuda.is_available()}"
-     if torch.cuda.is_available():
-         gpu_info += f" ({torch.cuda.get_device_name(0)})"

-     if success:
-         return f"✅ {message}! {gpu_info}"
-     else:
-         return f"❌ {message}. {gpu_info}"
-
- def change_model(model_name):
-     """Change the TTS model"""
-     if not model_name.strip():
-         return "⚠️ Please enter a valid model name"

-     success, message = tts_model.load_models(model_name.strip())
-     return message
-
- def text_to_speech(text, voice, temperature, top_p):
-     """Gradio interface function"""
-     if not text.strip():
-         return None, "⚠️ Please enter some text"

-     if not SNAC_AVAILABLE:
-         return None, "❌ SNAC model not available. Please check installation."

-     print(f"Generating speech for: {text[:50]}...")
-     audio_output, message = tts_model.generate_speech(text, voice, temperature, top_p)
-     return audio_output, message
-
- # Elegant and smooth CSS
- custom_css = """
- /* Import Google Fonts */
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
-
- /* Root variables for consistent theming */
- :root {
-     --primary-gradient: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-     --secondary-gradient: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
-     --tertiary-gradient: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
-     --quaternary-gradient: linear-gradient(135deg, #43e97b 0%, #38f9d7 100%);
-     --glass-bg: rgba(255, 255, 255, 0.1);
-     --glass-border: rgba(255, 255, 255, 0.2);
-     --text-primary: #2d3748;
-     --text-secondary: #4a5568;
-     --shadow-light: 0 4px 20px rgba(0, 0, 0, 0.08);
-     --shadow-medium: 0 8px 30px rgba(0, 0, 0, 0.12);
-     --shadow-heavy: 0 12px 40px rgba(0, 0, 0, 0.15);
-     --border-radius: 16px;
-     --transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
- }
-
- /* Global styling */
- * {
-     font-family: 'Inter', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
-     transition: var(--transition);
- }
-
- .gradio-container {
-     background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
-     min-height: 100vh;
-     padding: 20px;
- }
-
- /* Header Section */
- .header-container {
-     text-align: center;
-     background: var(--glass-bg);
-     backdrop-filter: blur(20px);
-     border: 1px solid var(--glass-border);
-     border-radius: 24px;
-     padding: 2.5rem;
-     margin-bottom: 2rem;
-     box-shadow: var(--shadow-medium);
-     position: relative;
-     overflow: hidden;
- }
-
- .header-container::before {
-     content: '';
-     position: absolute;
-     top: 0;
-     left: 0;
-     right: 0;
-     bottom: 0;
-     background: var(--primary-gradient);
-     opacity: 0.1;
-     z-index: -1;
- }
-
- .main-title {
-     font-size: 3rem;
-     font-weight: 700;
-     background: var(--primary-gradient);
-     -webkit-background-clip: text;
-     -webkit-text-fill-color: transparent;
-     background-clip: text;
-     margin: 0 0 1rem 0;
-     line-height: 1.2;
- }
-
- .subtitle {
-     font-size: 1.25rem;
-     color: var(--text-secondary);
-     margin: 0 0 0.5rem 0;
-     font-weight: 500;
- }
-
- .feature-badges {
-     display: flex;
-     justify-content: center;
-     gap: 1rem;
-     flex-wrap: wrap;
-     margin-top: 1.5rem;
- }
-
- .badge {
-     background: var(--glass-bg);
-     backdrop-filter: blur(10px);
-     border: 1px solid var(--glass-border);
-     padding: 0.5rem 1rem;
-     border-radius: 50px;
-     font-size: 0.875rem;
-     font-weight: 500;
-     color: var(--text-primary);
-     box-shadow: var(--shadow-light);
- }
-
- /* Card styling */
- .glass-card {
-     background: var(--glass-bg);
-     backdrop-filter: blur(20px);
-     border: 1px solid var(--glass-border);
-     border-radius: var(--border-radius);
-     padding: 1.5rem;
-     margin: 1rem 0;
-     box-shadow: var(--shadow-light);
-     transition: var(--transition);
- }
-
- .glass-card:hover {
-     box-shadow: var(--shadow-medium);
-     transform: translateY(-2px);
- }
-
- .card-title {
-     font-size: 1.25rem;
-     font-weight: 600;
-     color: var(--text-primary);
-     margin: 0 0 1rem 0;
-     display: flex;
-     align-items: center;
-     gap: 0.5rem;
- }
-
- /* Input styling */
- .smooth-input textarea,
- .smooth-input input {
-     background: rgba(255, 255, 255, 0.7) !important;
-     backdrop-filter: blur(10px) !important;
-     border: 2px solid transparent !important;
-     border-radius: 12px !important;
-     padding: 1rem !important;
-     font-size: 1rem !important;
-     transition: var(--transition) !important;
-     box-shadow: var(--shadow-light) !important;
- }
-
- .smooth-input textarea:focus,
- .smooth-input input:focus {
-     border-color: #667eea !important;
-     box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
-     transform: translateY(-1px) !important;
- }
-
- /* Button styling */
- .generate-button {
-     background: var(--primary-gradient) !important;
-     border: none !important;
-     border-radius: 50px !important;
-     padding: 1rem 2rem !important;
-     font-size: 1rem !important;
-     font-weight: 600 !important;
-     color: white !important;
-     box-shadow: var(--shadow-medium) !important;
-     transition: var(--transition) !important;
-     text-transform: none !important;
-     letter-spacing: 0.5px !important;
-     min-height: 50px !important;
- }
-
- .generate-button:hover {
-     transform: translateY(-2px) !important;
-     box-shadow: var(--shadow-heavy) !important;
- }
-
- .model-button {
-     background: var(--tertiary-gradient) !important;
-     border: none !important;
-     border-radius: 12px !important;
-     padding: 0.75rem 1.5rem !important;
-     font-size: 0.875rem !important;
-     font-weight: 500 !important;
-     color: white !important;
-     box-shadow: var(--shadow-light) !important;
-     transition: var(--transition) !important;
- }
-
- .model-button:hover {
-     transform: translateY(-1px) !important;
-     box-shadow: var(--shadow-medium) !important;
- }
-
- /* Dropdown styling */
- .smooth-dropdown select {
-     background: rgba(255, 255, 255, 0.7) !important;
-     backdrop-filter: blur(10px) !important;
-     border: 2px solid transparent !important;
-     border-radius: 12px !important;
-     padding: 0.75rem 1rem !important;
-     font-weight: 500 !important;
-     color: var(--text-primary) !important;
-     transition: var(--transition) !important;
- }
-
- .smooth-dropdown select:focus {
-     border-color: #667eea !important;
-     box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
- }
-
- /* Slider styling */
- .smooth-slider {
-     background: rgba(255, 255, 255, 0.5) !important;
-     border-radius: 12px !important;
-     padding: 1rem !important;
-     margin: 0.5rem 0 !important;
- }
-
- .smooth-slider input[type="range"] {
-     background: var(--quaternary-gradient) !important;
-     height: 6px !important;
-     border-radius: 3px !important;
- }
-
- /* Status display */
- .status-display {
-     background: rgba(255, 255, 255, 0.8) !important;
-     border: none !important;
-     border-radius: 12px !important;
-     padding: 1rem !important;
-     font-weight: 500 !important;
-     text-align: center !important;
-     box-shadow: var(--shadow-light) !important;
- }
-
- /* Audio player */
- .audio-container {
-     background: rgba(255, 255, 255, 0.6) !important;
-     border-radius: 16px !important;
-     padding: 1rem !important;
-     box-shadow: var(--shadow-light) !important;
-     backdrop-filter: blur(10px) !important;
- }
-
- /* Examples section */
- .examples-grid {
-     display: grid;
-     gap: 1rem;
-     margin-top: 1rem;
- }
-
- .example-card {
-     background: rgba(255, 255, 255, 0.4);
-     border: 1px solid var(--glass-border);
-     border-radius: 12px;
-     padding: 1rem;
-     cursor: pointer;
-     transition: var(--transition);
-     backdrop-filter: blur(5px);
- }
-
- .example-card:hover {
-     background: rgba(255, 255, 255, 0.6);
-     transform: translateY(-1px);
-     box-shadow: var(--shadow-light);
- }
-
- /* Info section */
- .info-grid {
-     display: grid;
-     grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
-     gap: 1rem;
-     margin-top: 1rem;
- }
-
- .info-item {
-     background: rgba(255, 255, 255, 0.3);
-     border-radius: 12px;
-     padding: 1rem;
-     backdrop-filter: blur(5px);
-     border: 1px solid var(--glass-border);
- }
-
- .info-title {
-     font-size: 1rem;
-     font-weight: 600;
-     margin: 0 0 0.5rem 0;
-     color: var(--text-primary);
- }
-
- .info-content {
-     font-size: 0.875rem;
-     color: var(--text-secondary);
-     line-height: 1.5;
- }
-
- /* Accordion styling */
- .accordion-container {
-     background: rgba(255, 255, 255, 0.3) !important;
-     border-radius: 12px !important;
-     border: 1px solid var(--glass-border) !important;
-     box-shadow: var(--shadow-light) !important;
- }
-
- /* Animation for loading states */
- @keyframes pulse {
-     0%, 100% { opacity: 1; }
-     50% { opacity: 0.7; }
- }
-
- .loading {
-     animation: pulse 2s infinite;
- }
-
- /* Responsive adjustments */
- @media (max-width: 768px) {
-     .main-title {
-         font-size: 2rem;
-     }

-     .feature-badges {
-         flex-direction: column;
-         align-items: center;
-     }

-     .info-grid {
-         grid-template-columns: 1fr;
-     }
- }
-
- /* Smooth scrolling */
- html {
-     scroll-behavior: smooth;
- }
-
- /* Custom scrollbar */
- ::-webkit-scrollbar {
-     width: 8px;
- }
-
- ::-webkit-scrollbar-track {
-     background: rgba(255, 255, 255, 0.1);
-     border-radius: 4px;
- }
-
- ::-webkit-scrollbar-thumb {
-     background: var(--primary-gradient);
-     border-radius: 4px;
- }
-
- ::-webkit-scrollbar-thumb:hover {
-     background: var(--secondary-gradient);
- }
- """
-
- # Create the enhanced Gradio interface
- with gr.Blocks(
-     title="🎤 Advanced Khmer TTS Studio",
-     theme=gr.themes.Soft(
-         primary_hue="blue",
-         secondary_hue="emerald",
-         neutral_hue="slate",
-         font=gr.themes.GoogleFont("Inter")
-     ),
-     css=custom_css
- ) as demo:

-     # Beautiful header
-     gr.HTML("""
-     <div class="header-container">
-         <h1 class="main-title">🎤 Advanced Khmer TTS Studio</h1>
-         <p class="subtitle">Professional AI-Powered Khmer Speech Synthesis Platform</p>
-         <div class="feature-badges">
-             <span class="badge">🎯 Multi-Model Support</span>
-             <span class="badge">🚀 Real-time Processing</span>
-             <span class="badge">🎭 Multiple Voices</span>
-             <span class="badge">⚡ GPU Accelerated</span>
-         </div>
-     </div>
-     """)

-     # Model selection section
-     with gr.Row():
-         model_input = gr.Textbox(
-             label="🤖 Model Selection",
-             placeholder="Enter HuggingFace model name (e.g., mrrtmob/tts-khm)",
-             value="mrrtmob/tts-khm",
-             elem_classes=["smooth-input"],
-             info="Enter any compatible TTS model from HuggingFace"
-         )
-         model_load_btn = gr.Button(
-             "🔄 Load Model",
-             elem_classes=["model-button"],
-             scale=0
-         )

      with gr.Row():
-         # Input Section
-         with gr.Column(scale=2):
-             gr.HTML('<div class="glass-card"><h2 class="card-title">📝 Text Input & Configuration</h2>')
-
              text_input = gr.Textbox(
-                 label="📖 Text to Synthesize",
-                 placeholder="សូមបញ្ចូលអត្ថបទភាសាខ្មែរនៅទីនេះ...",
-                 lines=5,
-                 value="សួស្ដី ខ្ញុំគឺជា AI អាចនិយាយភាសាខ្មែរបាន",
-                 elem_classes=["smooth-input"]
              )
-
-             with gr.Row():
-                 voice_dropdown = gr.Dropdown(
-                     label="🎭 Voice Model",
-                     choices=["Elise", "Jing", "Default"],
-                     value="Elise",
-                     elem_classes=["smooth-dropdown"],
-                     info="Select your preferred voice character"
-                 )
-
-             with gr.Accordion("⚙️ Advanced Parameters", open=False, elem_classes=["accordion-container"]):
-                 gr.HTML('<div style="padding: 1rem;">')
-                 with gr.Row():
-                     temperature = gr.Slider(
-                         minimum=0.1,
-                         maximum=1.0,
-                         value=0.6,
-                         step=0.1,
-                         label="🌡️ Temperature",
-                         info="Controls randomness (0.1 = consistent, 1.0 = creative)",
-                         elem_classes=["smooth-slider"]
-                     )
-                     top_p = gr.Slider(
-                         minimum=0.1,
-                         maximum=1.0,
-                         value=0.95,
-                         step=0.05,
-                         label="🎯 Top P",
-                         info="Controls diversity (0.1 = focused, 1.0 = diverse)",
-                         elem_classes=["smooth-slider"]
-                     )
-                 gr.HTML('</div>')
-
-             generate_btn = gr.Button(
-                 "🎵 Generate Speech",
-                 size="lg",
-                 elem_classes=["generate-button"]
              )
-             gr.HTML('</div>')
-
-         # Output Section
-         with gr.Column(scale=1):
-             gr.HTML('<div class="glass-card"><h2 class="card-title">🔊 Audio Output</h2>')

-             status_text = gr.Textbox(
-                 label="📊 System Status",
-                 value="🔄 Ready to load model...",
-                 interactive=False,
-                 elem_classes=["status-display"]
-             )
-
-             audio_output = gr.Audio(
-                 label="🎵 Generated Speech",
-                 type="numpy",
-                 elem_classes=["audio-container"]
-             )

-             gr.HTML("""
-             <div style="background: rgba(255, 255, 255, 0.2); backdrop-filter: blur(10px);
-                         border-radius: 12px; padding: 1rem; margin-top: 1rem; text-align: center;">
-                 <h4 style="margin: 0 0 0.5rem 0; color: #2d3748;">💡 Quick Tips</h4>
-                 <p style="margin: 0; font-size: 0.875rem; color: #4a5568; line-height: 1.5;">
-                     🎧 Use headphones for optimal experience<br>
-                     ⚡ Processing typically takes 15-45 seconds<br>
-                     🔧 Adjust parameters for different results
-                 </p>
-             </div>
-             """)
-             gr.HTML('</div>')
-
-     # Event handlers
-     model_load_btn.click(
-         fn=change_model,
-         inputs=[model_input],
-         outputs=[status_text]
      )

-     generate_btn.click(
-         fn=text_to_speech,
-         inputs=[text_input, voice_dropdown, temperature, top_p],
-         outputs=[audio_output, status_text]
      )
-
-     # Initialize with default model
-     demo.load(
-         fn=lambda: initialize_models("mrrtmob/tts-khm"),
-         outputs=[status_text]
-     )
-
-     # Enhanced Examples Section
-     gr.HTML("""
-     <div class="glass-card" style="margin-top: 2rem;">
-         <h2 class="card-title">📚 Example Texts</h2>
-         <p style="color: #4a5568; margin-bottom: 1rem;">Click any example below to try it instantly!</p>
-     </div>
-     """)

-     with gr.Row():
-         with gr.Column():
-             gr.Examples(
-                 examples=[
-                     # Basic greetings
-                     "សួស្ដី អ្នកសុខសប្បាយទេ? ខ្ញុំគឺជា AI",
-                     "ជំរាបសួរ សូមស្វាគមន៍មកកាន់ប្រព័ន្ធ TTS",
-
-                     # Cultural content
-                     "ប្រទេសកម្ពុជាមានប្រាសាទអង្គរវត្តដ៏ល្បី",
-                     "បុណ្យចូលឆ្នាំខ្មែរគឺជាបុណ្យធំបំផុត",
-
-                     # Educational
-                     "ការអប់រំគឺជាមូលដ្ឋានសំខាន់នៃការអភិវឌ្ឍន៍",
-                     "បច្ចេកវិទ្យាកំពុងផ្លាស់ប្ដូរពិភពលោក",
-                 ],
-                 inputs=[text_input],
-                 label="🌟 Popular Examples"
-             )
-
-         with gr.Column():
-             gr.Examples(
-                 examples=[
-                     # Technology
-                     "ការរៀនម៉ាស៊ីននិង AI កំពុងរីកចម្រើន",
-                     "បណ្ដាញសង្គមបានផ្លាស់ប្ដូរជីវិតយើង",
-
-                     # Literature
-                     "ព្រះអាទិត្យរះនៅពេលព្រឹក ធ្វើឱ្យផ្ទៃទឹកស្រស់ស្អាត",
-                     "ក្រុមសត្វស្លាបបានហោះហើរនៅលំអង",
-
-                     # Information
-                     "ការពារបរិស្ថានគឺជាទំនួលខុសត្រូវរួម",
-                     "ព័ត៌មានគឺជាកម្លាំងនៃការអភិវឌ្ឍន៍",
-                 ],
-                 inputs=[text_input],
-                 label="🎭 Creative Examples"
-             )
-
-     # Enhanced Information Section
-     gr.HTML("""
-     <div class="glass-card" style="margin-top: 2rem;">
-         <h2 class="card-title">📊 System Information & Guidelines</h2>
-         <div class="info-grid">
-             <div class="info-item">
-                 <div class="info-title">🔧 System Status</div>
-                 <div class="info-content">
-                     <strong>Unsloth:</strong> """ + ('✅ Available' if UNSLOTH_AVAILABLE else '❌ Not Available') + """<br>
-                     <strong>SNAC:</strong> """ + ('✅ Available' if SNAC_AVAILABLE else '❌ Not Available') + """<br>
-                     <strong>GPU:</strong> """ + ('✅ Available' if torch.cuda.is_available() else '❌ CPU Only') + """<br>
-                     <strong>Device:</strong> """ + ('CUDA' if torch.cuda.is_available() else 'CPU') + """
-                 </div>
-             </div>
-             <div class="info-item">
-                 <div class="info-title">🎭 Voice Profiles</div>
-                 <div class="info-content">
-                     <strong>Elise:</strong> Clear, professional, news-style<br>
-                     <strong>Jing:</strong> Warm, conversational, friendly<br>
-                     <strong>Default:</strong> Standard neutral synthesis<br>
-                     <em>Each voice has unique characteristics</em>
-                 </div>
-             </div>
-             <div class="info-item">
-                 <div class="info-title">🤖 Model Support</div>
-                 <div class="info-content">
-                     <strong>Current:</strong> mrrtmob/tts-khm (default)<br>
-                     <strong>Custom:</strong> Any HuggingFace TTS model<br>
-                     <strong>Format:</strong> username/model-name<br>
-                     <em>Models are cached after first load</em>
-                 </div>
-             </div>
-             <div class="info-item">
-                 <div class="info-title">💡 Best Practices</div>
-                 <div class="info-content">
-                     • Use proper Khmer Unicode text<br>
-                     • Keep sentences under 100 characters<br>
-                     • Lower temperature = more consistent<br>
-                     • Higher Top P = more natural variation<br>
-                     • Test different voice models for variety
-                 </div>
-             </div>
-             <div class="info-item">
-                 <div class="info-title">⚡ Performance Tips</div>
-                 <div class="info-content">
-                     • GPU acceleration automatically detected<br>
-                     • Models are loaded once and cached<br>
-                     • First generation may take longer<br>
-                     • SNAC decoding optimized for memory<br>
-                     • Batch processing not yet supported
-                 </div>
-             </div>
-             <div class="info-item">
-                 <div class="info-title">🔧 Technical Details</div>
-                 <div class="info-content">
-                     <strong>Sample Rate:</strong> 24 kHz<br>
-                     <strong>Format:</strong> WAV (numpy array)<br>
-                     <strong>Max Tokens:</strong> 1200 new tokens<br>
-                     <strong>Sequence Length:</strong> 2048 tokens<br>
-                     <strong>Audio Quality:</strong> High-fidelity
-                 </div>
-             </div>
-         </div>
-     </div>
-     """)
-
-     # Footer
-     gr.HTML("""
-     <div style="text-align: center; margin-top: 2rem; padding: 2rem;
-                 background: rgba(255, 255, 255, 0.1); backdrop-filter: blur(10px);
-                 border-radius: 16px; border: 1px solid rgba(255, 255, 255, 0.2);">
-         <h3 style="color: #2d3748; margin-bottom: 1rem;">🌟 Advanced Khmer TTS Studio</h3>
-         <p style="color: #4a5568; margin: 0; font-size: 0.875rem;">
-             Built with ❤️ for the Khmer community • Powered by state-of-the-art AI<br>
-             Supporting multiple models • Professional-grade speech synthesis
-         </p>
-     </div>
-     """)
-
  if __name__ == "__main__":
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         show_api=False,
-         share=False,
-         favicon_path=None,
-         ssl_verify=False,
-         inbrowser=True
-     )
+ import spaces
+ from snac import SNAC
  import torch
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from huggingface_hub import snapshot_download
+ from dotenv import load_dotenv
+ load_dotenv()
+ # Check if CUDA is available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print("Loading SNAC model...")
+ snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+ snac_model = snac_model.to(device)
+ model_name = "mrrtmob/tts-khm"
+ # Download only model config and safetensors
+ snapshot_download(
+     repo_id=model_name,
+     allow_patterns=[
+         "config.json",
+         "*.safetensors",
+         "model.safetensors.index.json",
+     ],
+     ignore_patterns=[
+         "optimizer.pt",
+         "pytorch_model.bin",
+         "training_args.bin",
+         "scheduler.pt",
+         "tokenizer.json",
+         "tokenizer_config.json",
+         "special_tokens_map.json",
+         "vocab.json",
+         "merges.txt",
+         "tokenizer.*"
+     ]
+ )
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+ model.to(device)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ print(f"Khmer TTS model loaded to {device}")
+ # Process text prompt
+ def process_prompt(prompt, voice, tokenizer, device):
+     prompt = f"{voice}: {prompt}"
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids

+     start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
+     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human

+     modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)  # SOH SOT Text EOT EOH

+     # No padding needed for single input
+     attention_mask = torch.ones_like(modified_input_ids)

+     return modified_input_ids.to(device), attention_mask.to(device)
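For reference, the framing above wraps the tokenized prompt in fixed special-token IDs (128259 = start of human turn, 128009 = end of text, 128260 = end of human turn), which appears to follow the Orpheus-style convention for SNAC-based TTS models. A minimal sketch of the resulting layout, using hypothetical stand-in text token IDs (101-103):

import torch

text_ids = torch.tensor([[101, 102, 103]])                     # stand-in for tokenizer output
soh = torch.tensor([[128259]], dtype=torch.int64)              # start of human turn
eot_eoh = torch.tensor([[128009, 128260]], dtype=torch.int64)  # end of text, end of human
framed = torch.cat([soh, text_ids, eot_eoh], dim=1)
print(framed.tolist())  # [[128259, 101, 102, 103, 128009, 128260]]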
+ # Parse output tokens to audio
+ def parse_output(generated_ids):
+     token_to_find = 128257
+     token_to_remove = 128258

+     token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
+     if len(token_indices[1]) > 0:
+         last_occurrence_idx = token_indices[1][-1].item()
+         cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
+     else:
+         cropped_tensor = generated_ids
+     processed_rows = []
+     for row in cropped_tensor:
+         masked_row = row[row != token_to_remove]
+         processed_rows.append(masked_row)
+     code_lists = []
+     for row in processed_rows:
+         row_length = row.size(0)
+         new_length = (row_length // 7) * 7
+         trimmed_row = row[:new_length]
+         trimmed_row = [t - 128266 for t in trimmed_row]
+         code_lists.append(trimmed_row)
+
+     return code_lists[0]  # Return just the first one for single sample
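The trimming above drops any incomplete trailing frame so the decoder always receives whole 7-code frames, and 128266 is the offset of the first audio code in the vocabulary. A quick sanity check of that logic with hypothetical token values:

import torch

row = torch.tensor([128266 + k for k in range(9)])   # 9 audio tokens: 1 full frame + 2 extra
new_length = (row.size(0) // 7) * 7                  # 9 -> 7
codes = [t.item() - 128266 for t in row[:new_length]]
print(codes)  # [0, 1, 2, 3, 4, 5, 6] -- exactly one complete frame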
+ # Redistribute codes for audio generation
+ def redistribute_codes(code_list, snac_model):
+     device = next(snac_model.parameters()).device  # Get the device of SNAC model

+     layer_1 = []
+     layer_2 = []
+     layer_3 = []
+     for i in range((len(code_list)+1)//7):
+         layer_1.append(code_list[7*i])
+         layer_2.append(code_list[7*i+1]-4096)
+         layer_3.append(code_list[7*i+2]-(2*4096))
+         layer_3.append(code_list[7*i+3]-(3*4096))
+         layer_2.append(code_list[7*i+4]-(4*4096))
+         layer_3.append(code_list[7*i+5]-(5*4096))
+         layer_3.append(code_list[7*i+6]-(6*4096))
+
+     # Move tensors to the same device as the SNAC model
+     codes = [
+         torch.tensor(layer_1, device=device).unsqueeze(0),
+         torch.tensor(layer_2, device=device).unsqueeze(0),
+         torch.tensor(layer_3, device=device).unsqueeze(0)
+     ]

+     audio_hat = snac_model.decode(codes)
+     return audio_hat.detach().squeeze().cpu().numpy()  # Always return CPU numpy array
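As the loop above encodes it, each flat 7-code frame is distributed across SNAC's three codebook layers: frame position 0 goes to layer 1, positions 1 and 4 to layer 2, and positions 2, 3, 5, 6 to layer 3, after removing the k*4096 offset that keeps position k in its own ID range. A worked example with one assumed frame:

# One hypothetical frame of flat codes; position k carries an offset of k*4096
frame = [10, 4096 + 11, 2*4096 + 12, 3*4096 + 13, 4*4096 + 14, 5*4096 + 15, 6*4096 + 16]
# redistribute_codes(frame, snac_model) would build:
#   layer_1 = [10]               (position 0)
#   layer_2 = [11, 14]           (positions 1 and 4)
#   layer_3 = [12, 13, 15, 16]   (positions 2, 3, 5 and 6)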
+ # Main generation function
+ @spaces.GPU()
+ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
+     if not text.strip():
+         return None

+     try:
+         progress(0.1, "Processing text...")
+         input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
+
+         progress(0.3, "Generating speech tokens...")
+         with torch.no_grad():
+             generated_ids = model.generate(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 max_new_tokens=max_new_tokens,
+                 do_sample=True,
+                 temperature=temperature,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 num_return_sequences=1,
+                 eos_token_id=128258,
+             )
+
+         progress(0.6, "Processing speech tokens...")
+         code_list = parse_output(generated_ids)
+
+         progress(0.8, "Converting to audio...")
+         audio_samples = redistribute_codes(code_list, snac_model)
+
+         return (24000, audio_samples)  # Return sample rate and audio
+     except Exception as e:
+         print(f"Error generating speech: {e}")
+         return None
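A quick way to smoke-test this function locally, outside the Gradio UI. This sketch assumes the spaces decorator degrades to a no-op off Hugging Face Spaces, that the default gr.Progress() is safe to call outside an event, and that soundfile is installed; all three are assumptions, not part of this commit:

# Hypothetical local smoke test (not part of the app)
import soundfile as sf  # assumed extra dependency

result = generate_speech("សួស្ដី", "tara", 0.6, 0.95, 1.1, 1200)
if result is not None:
    sample_rate, samples = result
    sf.write("output.wav", samples, sample_rate)  # 24 kHz mono WAV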
+ # Examples for the UI - Khmer text examples
+ examples = [
+     ["ជំរាបសួរ ខ្ញុំឈ្មោះ តារា ហើយខ្ញុំគឺជាម៉ូដែលផលិតសំលេងនិយាយ។", "tara", 0.6, 0.95, 1.1, 1200],
+     ["ខ្ញុំអាចបង្កើតសំលេងនិយាយផ្សេងៗ ដូចជា <laugh> សើច ឬ <sigh> ថប់ដង្ហើម។", "dan", 0.7, 0.95, 1.1, 1200],
+     ["ខ្ញុំរស់នៅក្នុងទីក្រុងភ្នំពេញ ហើយមានប៉ារ៉ាម៉ែត្រ <gasp> ច្រើនណាស់។", "leah", 0.6, 0.9, 1.2, 1200],
+     ["ពេលខ្លះ ពេលខ្ញុំនិយាយច្រើនពេក ខ្ញុំត្រូវ <cough> សុំទោស។", "leo", 0.65, 0.9, 1.1, 1200],
+     ["ការនិយាយនៅចំពោះមុខសាធារណៈ អាចមានការពិបាក។ <groan> ប៉ុន្តែបើហាត់ហាន គេអាចធ្វើបាន។", "jess", 0.7, 0.95, 1.1, 1200],
+     ["ការឡើងភ្នំពិតជាហត់ណត់ ប៉ុន្តែទេសភាពពីលើនេះ ពិតជាស្រស់ស្អាត! <sigh> គួរឱ្យធ្វើ។", "mia", 0.65, 0.9, 1.15, 1200],
+     ["តើអ្នកបានឮរឿងកំប្លែងនេះយ៉ាងណា? <laugh> ខ្ញុំមិនអាចបញ្ឈប់ការសើចបាននោះទេ។", "zac", 0.7, 0.95, 1.1, 1200],
+     ["បន្ទាប់ពីរត់ម៉ារ៉ាតុងរួច ខ្ញុំហត់ណាស់ <yawn> ហើយត្រូវការសម្រាក។", "zoe", 0.6, 0.95, 1.1, 1200]
+ ]
+ # Available voices
+ VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe", "jing", "Elise"]
+ # Available Emotive Tags
+ EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
+ # Create Gradio interface
+ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
+     gr.Markdown(f"""
+     # 🎵 Khmer Text-to-Speech (ម៉ូដែលបម្លែងអត្ថបទជាសំលេង)
+     Enter your Khmer text below and hear it converted to natural-sounding speech.
+
+     បញ្ចូលអត្ថបទខ្មែររបស់អ្នកខាងក្រោម ហើយស្តាប់ការបម្លែងទៅជាសំលេងនិយាយធម្មជាតិ។

+     ## Tips for better prompts (គន្លឹះសម្រាប់ការប្រើប្រាស់ដ៏ល្អ):
+     - Add paralinguistic elements like {", ".join(EMOTIVE_TAGS)} for more human-like speech
+     - Longer text prompts generally work better than very short phrases
+     - អត្ថបទវែងជាទូទៅមានលទ្ធផលល្អជាងអត្ថបទខ្លី
+     - Increasing `repetition_penalty` and `temperature` makes the model speak faster
+     """)
      with gr.Row():
+         with gr.Column(scale=3):
              text_input = gr.Textbox(
+                 label="Text to speak (អត្ថបទដើម្បីនិយាយ)",
+                 placeholder="បញ្ចូលអត្ថបទខ្មែររបស់អ្នកនៅទីនេះ...",
+                 lines=5
              )
+             voice = gr.Dropdown(
+                 choices=VOICES,
+                 value="tara",
+                 label="Voice (សំលេង)"
              )

+             with gr.Accordion("Advanced Settings (ការកំណត់កម្រិតខ្ពស់)", open=False):
+                 temperature = gr.Slider(
+                     minimum=0.1, maximum=1.5, value=0.6, step=0.05,
+                     label="Temperature",
+                     info="Higher values (0.7-1.0) create more expressive but less stable speech"
+                 )
+                 top_p = gr.Slider(
+                     minimum=0.1, maximum=1.0, value=0.95, step=0.05,
+                     label="Top P",
+                     info="Nucleus sampling threshold"
+                 )
+                 repetition_penalty = gr.Slider(
+                     minimum=1.0, maximum=2.0, value=1.1, step=0.05,
+                     label="Repetition Penalty",
+                     info="Higher values discourage repetitive patterns"
+                 )
+                 max_new_tokens = gr.Slider(
+                     minimum=100, maximum=2000, value=1200, step=100,
+                     label="Max Length",
+                     info="Maximum length of generated audio (in tokens)"
+                 )

+             with gr.Row():
+                 submit_btn = gr.Button("Generate Speech (បង្កើតសំលេង)", variant="primary")
+                 clear_btn = gr.Button("Clear (លុប)")
+
+         with gr.Column(scale=2):
+             audio_output = gr.Audio(label="Generated Speech (សំលេងដែលបង្កើតឡើង)", type="numpy")
+
+     # Set up examples
+     gr.Examples(
+         examples=examples,
+         inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
+         outputs=audio_output,
+         fn=generate_speech,
+         cache_examples=True,
      )

+     # Set up event handlers
+     submit_btn.click(
+         fn=generate_speech,
+         inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
+         outputs=audio_output
      )

+     clear_btn.click(
+         fn=lambda: (None, None),
+         inputs=[],
+         outputs=[text_input, audio_output]
+     )
+ # Launch the app
  if __name__ == "__main__":
+     demo.queue().launch(share=False, ssr_mode=False)
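For anyone running the new app.py locally: its imports imply roughly the following dependencies. This is a sketch inferred from the imports, not the Space's actual requirements.txt:

# requirements.txt (inferred; versions unpinned)
spaces
snac
torch
gradio
transformers
huggingface_hub
python-dotenv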