mrrtmob committed
Commit 8bdeb28 · 1 Parent(s): 6c2bc94
Files changed (1)
  1. app.py +114 -335
app.py CHANGED
@@ -3,268 +3,116 @@ from snac import SNAC
  import torch
  import gradio as gr
  from transformers import AutoModelForCausalLM, AutoTokenizer
- from huggingface_hub import snapshot_download, login
  from dotenv import load_dotenv
- import os
- import re
- import numpy as np
-
  load_dotenv()
-
- # Setup Hugging Face authentication
- def setup_auth():
-     hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
-     if hf_token:
-         try:
-             login(token=hf_token, add_to_git_credential=False)
-             print("✅ Successfully logged in to Hugging Face")
-             return True
-         except Exception as e:
-             print(f"⚠️ Failed to login to Hugging Face: {e}")
-             return False
-     else:
-         print("⚠️ No HF token found. Running as anonymous user.")
-         return False
-
- # Setup authentication before anything else
- auth_success = setup_auth()
-
  # Check if CUDA is available
  device = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"Using device: {device}")
- print(f"Authentication status: {'✅ Logged in' if auth_success else '❌ Anonymous'}")
-
- # Global variables to store models
- snac_model = None
- model = None
- tokenizer = None
-
- def load_models():
-     global snac_model, model, tokenizer
-
-     print("Loading SNAC model...")
-     snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
-     snac_model = snac_model.to(device)
-
-     model_name = "mrrtmob/tts-khm-4"
-
-     print("Downloading model files...")
-     snapshot_download(
-         repo_id=model_name,
-         allow_patterns=[
-             "config.json",
-             "*.safetensors",
-             "model.safetensors.index.json",
-             "tokenizer.json",
-             "tokenizer_config.json",
-             "special_tokens_map.json",
-             "vocab.json",
-             "merges.txt"
-         ],
-         ignore_patterns=[
-             "optimizer.pt",
-             "pytorch_model.bin",
-             "training_args.bin",
-             "scheduler.pt"
-         ]
-     )
-
-     print("Loading main model...")
-     if device == "cuda":
-         model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             torch_dtype=torch.bfloat16,
-             low_cpu_mem_usage=True
-         )
-         model = model.to(device)
-     else:
-         model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             torch_dtype=torch.float32
-         )
-
-     print("Loading tokenizer...")
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-     if tokenizer.pad_token is None:
-         tokenizer.pad_token = tokenizer.eos_token
-
-     print(f"Khmer TTS model loaded to {device}")
-
- # Load models at startup
- load_models()
-
- def split_text_by_punctuation(text, max_chars=200):
-     sentence_endings = r'[αŸ”!?]'
-     clause_separators = r'[,;:]'
-
-     sentences = re.split(f'({sentence_endings})', text)
-
-     combined_sentences = []
-     for i in range(0, len(sentences), 2):
-         sentence = sentences[i]
-         if i + 1 < len(sentences):
-             sentence += sentences[i + 1]
-         if sentence.strip():
-             combined_sentences.append(sentence.strip())
-
-     if len(combined_sentences) <= 1:
-         parts = re.split(f'({clause_separators})', text)
-         combined_sentences = []
-         for i in range(0, len(parts), 2):
-             part = parts[i]
-             if i + 1 < len(parts):
-                 part += parts[i + 1]
-             if part.strip():
-                 combined_sentences.append(part.strip())
-
-     final_chunks = []
-     for sentence in combined_sentences:
-         if len(sentence) <= max_chars:
-             final_chunks.append(sentence)
-         else:
-             words = sentence.split()
-             current_chunk = ""
-
-             for word in words:
-                 test_chunk = current_chunk + " " + word if current_chunk else word
-                 if len(test_chunk) <= max_chars:
-                     current_chunk = test_chunk
-                 else:
-                     if current_chunk:
-                         final_chunks.append(current_chunk)
-                     current_chunk = word
-
-             if current_chunk:
-                 final_chunks.append(current_chunk)
-
-     return [chunk for chunk in final_chunks if chunk.strip()]
-
- def split_text_by_tokens(text, max_tokens=150):
-     global tokenizer
-
-     tokens = tokenizer.encode(text)
-
-     if len(tokens) <= max_tokens:
-         return [text]
-
-     chunks = []
-     words = text.split()
-     current_chunk = ""
-
-     for word in words:
-         test_chunk = current_chunk + " " + word if current_chunk else word
-         test_tokens = tokenizer.encode(test_chunk)
-
-         if len(test_tokens) <= max_tokens:
-             current_chunk = test_chunk
-         else:
-             if current_chunk:
-                 chunks.append(current_chunk)
-             current_chunk = word
-
-     if current_chunk:
-         chunks.append(current_chunk)
-
-     return chunks
-
  def process_prompt(prompt, voice, tokenizer, device):
      prompt = f"{voice}: {prompt}"
      input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-     start_token = torch.tensor([[128259]], dtype=torch.int64)
-     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
-     modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
      attention_mask = torch.ones_like(modified_input_ids)
      return modified_input_ids.to(device), attention_mask.to(device)
-
  def parse_output(generated_ids):
      token_to_find = 128257
      token_to_remove = 128258
-     token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
      if len(token_indices[1]) > 0:
          last_occurrence_idx = token_indices[1][-1].item()
          cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
      else:
          cropped_tensor = generated_ids
-
      processed_rows = []
      for row in cropped_tensor:
          masked_row = row[row != token_to_remove]
          processed_rows.append(masked_row)
-
      code_lists = []
      for row in processed_rows:
          row_length = row.size(0)
          new_length = (row_length // 7) * 7
          trimmed_row = row[:new_length]
-         trimmed_row = [max(0, t - 128266) for t in trimmed_row]
          code_lists.append(trimmed_row)
-
-     return code_lists[0] if code_lists and len(code_lists[0]) > 0 else []
-
- def redistribute_codes(code_list, snac_model):
-     if not code_list or len(code_list) < 7:
-         return np.zeros(12000)

-     device = next(snac_model.parameters()).device
      layer_1 = []
      layer_2 = []
      layer_3 = []

      try:
-         for i in range((len(code_list))//7):
-             layer_1.append(max(0, code_list[7*i]))
-             layer_2.append(max(0, code_list[7*i+1]-4096))
-             layer_3.append(max(0, code_list[7*i+2]-(2*4096)))
-             layer_3.append(max(0, code_list[7*i+3]-(3*4096)))
-             layer_2.append(max(0, code_list[7*i+4]-(4*4096)))
-             layer_3.append(max(0, code_list[7*i+5]-(5*4096)))
-             layer_3.append(max(0, code_list[7*i+6]-(6*4096)))
-
-         codes = [
-             torch.tensor(layer_1, device=device).unsqueeze(0),
-             torch.tensor(layer_2, device=device).unsqueeze(0),
-             torch.tensor(layer_3, device=device).unsqueeze(0)
-         ]
-
-         with torch.no_grad():
-             audio_hat = snac_model.decode(codes)
-         return audio_hat.detach().squeeze().cpu().numpy()
-     except Exception as e:
-         print(f"Error in redistribute_codes: {e}")
-         return np.zeros(12000)
-
- def combine_audio_chunks(audio_chunks, pause_duration=0.3):
-     if not audio_chunks:
-         return np.array([])
-
-     pause_samples = int(24000 * pause_duration)
-     pause = np.zeros(pause_samples)
-
-     combined_audio = []
-     for i, chunk in enumerate(audio_chunks):
-         if len(chunk) > 0:
-             combined_audio.append(chunk)
-             if i < len(audio_chunks) - 1:
-                 combined_audio.append(pause)
-
-     if combined_audio:
-         return np.concatenate(combined_audio)
-     else:
-         return np.array([])
-
- @spaces.GPU(duration=60) # Reduced duration to be more conservative
- def generate_speech_chunk(text_chunk, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=600, voice="Elise"):
-     """Generate speech for a single chunk"""
-     global model, tokenizer, snac_model
-
-     if not text_chunk.strip():
-         return np.array([])
-
-     try:
-         input_ids, attention_mask = process_prompt(text_chunk, voice, tokenizer, device)

          with torch.no_grad():
              generated_ids = model.generate(
                  input_ids=input_ids,
@@ -276,143 +124,77 @@ def generate_speech_chunk(text_chunk, temperature=0.6, top_p=0.95, repetition_pe
                  repetition_penalty=repetition_penalty,
                  num_return_sequences=1,
                  eos_token_id=128258,
-                 pad_token_id=tokenizer.eos_token_id,
-                 use_cache=True
              )

          code_list = parse_output(generated_ids)

-         if not code_list:
-             return np.array([])
-
          audio_samples = redistribute_codes(code_list, snac_model)
-         return audio_samples
-
-     except Exception as e:
-         print(f"Error generating speech chunk: {e}")
-         return np.array([])
-
- def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=600,
-                     voice="Elise", split_method="punctuation", max_chars=150, max_tokens=100,
-                     pause_duration=0.3, progress=gr.Progress()):
-     """Main function to generate speech with text splitting"""
-
-     if not text.strip():
-         return None
-
-     try:
-         progress(0.05, "Splitting text...")
-
-         if split_method == "punctuation":
-             text_chunks = split_text_by_punctuation(text, max_chars)
-         elif split_method == "tokens":
-             text_chunks = split_text_by_tokens(text, max_tokens)
-         else:
-             text_chunks = [text]
-
-         progress(0.1, f"Processing {len(text_chunks)} chunks...")
-         print(f"Split text into {len(text_chunks)} chunks:")
-         for i, chunk in enumerate(text_chunks):
-             print(f"Chunk {i+1}: {chunk[:50]}...")
-
-         audio_chunks = []
-         for i, chunk in enumerate(text_chunks):
-             progress(0.1 + 0.7 * (i / len(text_chunks)), f"Generating chunk {i+1}/{len(text_chunks)}...")
-
-             audio = generate_speech_chunk(
-                 chunk, temperature, top_p, repetition_penalty, max_new_tokens, voice
-             )
-
-             if len(audio) > 0:
-                 audio_chunks.append(audio)
-                 print(f"Generated audio for chunk {i+1}: {len(audio)} samples ({len(audio)/24000:.2f}s)")
-
-         if not audio_chunks:
-             return None
-
-         progress(0.9, "Combining audio chunks...")
-         final_audio = combine_audio_chunks(audio_chunks, pause_duration)
-
-         progress(1.0, "Complete!")
-         print(f"Final audio: {len(final_audio)} samples ({len(final_audio)/24000:.2f}s)")
-
-         return (24000, final_audio)

      except Exception as e:
          print(f"Error generating speech: {e}")
-         import traceback
-         traceback.print_exc()
          return None
-
- # [Rest of your Gradio interface code remains the same]
  examples = [
- ["αž‡αŸ†αžšαžΆαž”αžŸαž½αžš αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡ αžαžΆαžšαžΆαŸ” αžαŸ’αž‰αž»αŸ†αž‚αžΊαž‡αžΆαž˜αŸ‰αžΌαžŠαŸ‚αž›αž•αž›αž·αžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”"],
  ["αžαŸ’αž‰αž»αŸ†αž’αžΆαž…αž”αž„αŸ’αž€αžΎαžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αž•αŸ’αžŸαŸαž„αŸ— αžŠαžΌαž…αž‡αžΆ <laugh> αžŸαžΎαž… ឬ <sigh> αžαž”αŸ‹αžŠαž„αŸ’αž αžΎαž˜αŸ”"],
 
 
 
  ]
-
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
-
  with gr.Blocks(title="Khmer Text-to-Speech") as demo:
      gr.Markdown(f"""
      # 🎡 Khmer Text-to-Speech
  **αž˜αŸ‰αžΌαžŠαŸ‚αž›αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‡αžΆαžŸαŸ†αž›αŸαž„**
-     Authentication: {'✅ Pro Account' if auth_success else '❌ Anonymous (Limited)'}

  αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€ αž αžΎαž™αžŸαŸ’αžαžΆαž”αŸ‹αž€αžΆαžšαž”αž˜αŸ’αž›αŸ‚αž„αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”
 
  πŸ’‘ **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
- ✨ **New**: Supports long text with automatic splitting!
-     """)

      text_input = gr.Textbox(
- label="Enter Khmer text (αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžš)",
- placeholder="αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€αž“αŸ…αž‘αžΈαž“αŸαŸ‡... (αž’αžΆαž…αžœαŸ‚αž„αž”αžΆαž“)",
-         lines=6
      )

-     with gr.Accordion("πŸ“ Text Splitting Options", open=True):
-         split_method = gr.Radio(
-             choices=[
-                 ("Split by punctuation (recommended)", "punctuation"),
-                 ("Split by token count", "tokens"),
-                 ("No splitting", "none")
-             ],
-             value="punctuation",
-             label="Text splitting method"
-         )
-
-         with gr.Row():
-             max_chars = gr.Slider(
-                 minimum=50, maximum=300, value=150, step=25,
-                 label="Max characters per chunk"
-             )
-             max_tokens = gr.Slider(
-                 minimum=50, maximum=200, value=100, step=25,
-                 label="Max tokens per chunk"
-             )
-
-         pause_duration = gr.Slider(
-             minimum=0.0, maximum=1.0, value=0.3, step=0.1,
-             label="Pause between chunks (seconds)"
-         )
 
 
      with gr.Accordion("🔧 Advanced Settings", open=False):
          with gr.Row():
              temperature = gr.Slider(
                  minimum=0.1, maximum=1.5, value=0.6, step=0.05,
-                 label="Temperature"
              )
              top_p = gr.Slider(
                  minimum=0.1, maximum=1.0, value=0.95, step=0.05,
-                 label="Top P"
              )
          with gr.Row():
              repetition_penalty = gr.Slider(
                  minimum=1.0, maximum=2.0, value=1.1, step=0.05,
-                 label="Repetition Penalty"
              )
              max_new_tokens = gr.Slider(
-                 minimum=100, maximum=800, value=600, step=100,
-                 label="Max tokens per chunk"
              )

      with gr.Row():
@@ -420,11 +202,12 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
          clear_btn = gr.Button("🗑️ Clear", size="lg")

      audio_output = gr.Audio(
- label="Generated Speech (αžŸαŸ†αž›αŸαž„αžŠαŸ‚αž›αž”αž„αŸ’αž€αžΎαžαž‘αžΎαž„)",
          type="numpy",
          show_label=True
      )

      gr.Examples(
          examples=examples,
          inputs=[text_input],
@@ -433,10 +216,10 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
          cache_examples=False,
      )

      submit_btn.click(
          fn=generate_speech,
-         inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens,
-                 gr.State("Elise"), split_method, max_chars, max_tokens, pause_duration],
          outputs=audio_output
      )
@@ -445,10 +228,6 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
          inputs=[],
          outputs=[text_input, audio_output]
      )
-
  if __name__ == "__main__":
-     demo.queue(max_size=5).launch(
-         share=False,
-         server_name="0.0.0.0",
-         server_port=7860
-     )
app.py AFTER

  import torch
  import gradio as gr
  from transformers import AutoModelForCausalLM, AutoTokenizer
+ from huggingface_hub import snapshot_download
  from dotenv import load_dotenv
  load_dotenv()
  # Check if CUDA is available
  device = "cuda" if torch.cuda.is_available() else "cpu"
+ print("Loading SNAC model...")
+ snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+ snac_model = snac_model.to(device)
+ model_name = "mrrtmob/tts-khm-kore"
+ # Download only model config and safetensors
+ snapshot_download(
+     repo_id=model_name,
+     allow_patterns=[
+         "config.json",
+         "*.safetensors",
+         "model.safetensors.index.json",
+     ],
+     ignore_patterns=[
+         "optimizer.pt",
+         "pytorch_model.bin",
+         "training_args.bin",
+         "scheduler.pt",
+         "tokenizer.json",
+         "tokenizer_config.json",
+         "special_tokens_map.json",
+         "vocab.json",
+         "merges.txt",
+         "tokenizer.*"
+     ]
+ )
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+ model.to(device)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ print(f"Khmer TTS model loaded to {device}")
+ # Process text prompt
  def process_prompt(prompt, voice, tokenizer, device):
      prompt = f"{voice}: {prompt}"
      input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+     start_token = torch.tensor([[128259]], dtype=torch.int64) # Start of human
+     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64) # End of text, End of human
+
+     modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1) # SOH SOT Text EOT EOH
+
+     # No padding needed for single input
      attention_mask = torch.ones_like(modified_input_ids)
+
      return modified_input_ids.to(device), attention_mask.to(device)
+ # Parse output tokens to audio
  def parse_output(generated_ids):
      token_to_find = 128257
      token_to_remove = 128258

+     token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
      if len(token_indices[1]) > 0:
          last_occurrence_idx = token_indices[1][-1].item()
          cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
      else:
          cropped_tensor = generated_ids
+
      processed_rows = []
      for row in cropped_tensor:
          masked_row = row[row != token_to_remove]
          processed_rows.append(masked_row)
+
      code_lists = []
      for row in processed_rows:
          row_length = row.size(0)
          new_length = (row_length // 7) * 7
          trimmed_row = row[:new_length]
+         trimmed_row = [t - 128266 for t in trimmed_row]
          code_lists.append(trimmed_row)

+     return code_lists[0] # Return just the first one for single sample
+ # Redistribute codes for audio generation
+ def redistribute_codes(code_list, snac_model):
+     device = next(snac_model.parameters()).device # Get the device of SNAC model
+
      layer_1 = []
      layer_2 = []
      layer_3 = []
+     for i in range((len(code_list)+1)//7):
+         layer_1.append(code_list[7*i])
+         layer_2.append(code_list[7*i+1]-4096)
+         layer_3.append(code_list[7*i+2]-(2*4096))
+         layer_3.append(code_list[7*i+3]-(3*4096))
+         layer_2.append(code_list[7*i+4]-(4*4096))
+         layer_3.append(code_list[7*i+5]-(5*4096))
+         layer_3.append(code_list[7*i+6]-(6*4096))
+
+     # Move tensors to the same device as the SNAC model
+     codes = [
+         torch.tensor(layer_1, device=device).unsqueeze(0),
+         torch.tensor(layer_2, device=device).unsqueeze(0),
+         torch.tensor(layer_3, device=device).unsqueeze(0)
+     ]
+
+     audio_hat = snac_model.decode(codes)
+     return audio_hat.detach().squeeze().cpu().numpy() # Always return CPU numpy array
+ # Main generation function
+ @spaces.GPU()
+ def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200, voice="Elise", progress=gr.Progress()):
+     if not text.strip():
+         return None

      try:
+         progress(0.1, "Processing text...")
+         input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)

+         progress(0.3, "Generating speech tokens...")
          with torch.no_grad():
              generated_ids = model.generate(
                  input_ids=input_ids,
                  repetition_penalty=repetition_penalty,
                  num_return_sequences=1,
                  eos_token_id=128258,
              )

+         progress(0.6, "Processing speech tokens...")
          code_list = parse_output(generated_ids)

+         progress(0.8, "Converting to audio...")
          audio_samples = redistribute_codes(code_list, snac_model)

+         return (24000, audio_samples) # Return sample rate and audio
      except Exception as e:
          print(f"Error generating speech: {e}")
          return None
+ # Examples for the UI - Khmer text examples
  examples = [
+ ["αž‡αŸ†αžšαžΆαž”αžŸαž½αžš αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡ Kiri αž αžΎαž™αžαŸ’αž‰αž»αŸ†αž‚αžΊαž‡αžΆαž˜αŸ‰αžΌαžŠαŸ‚αž›αž•αž›αž·αžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”"],
  ["αžαŸ’αž‰αž»αŸ†αž’αžΆαž…αž”αž„αŸ’αž€αžΎαžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αž•αŸ’αžŸαŸαž„αŸ— αžŠαžΌαž…αž‡αžΆ <laugh> αžŸαžΎαž… ឬ <sigh> αžαž”αŸ‹αžŠαž„αŸ’αž αžΎαž˜αŸ”"],
+ ["αžαŸ’αž‰αž»αŸ†αžšαžŸαŸ‹αž“αŸ…αž€αŸ’αž“αž»αž„αž‘αžΈαž€αŸ’αžšαž»αž„αž—αŸ’αž“αŸ†αž–αŸαž‰ αž αžΎαž™αž˜αžΆαž“αž…αžšαžΆαž…αžšαžŽαŸ <gasp> αž…αŸ’αžšαžΎαž“αžŽαžΆαžŸαŸ‹αŸ”"],
+ ["αž–αŸαž›αžαŸ’αž›αŸ‡ αž–αŸαž›αžαŸ’αž‰αž»αŸ†αž“αž·αž™αžΆαž™αž…αŸ’αžšαžΎαž“αž–αŸαž€ αžαŸ’αž‰αž»αŸ†αžαŸ’αžšαžΌαžœ <cough> αžŸαž»αŸ†αž‘αŸ„αžŸαŸ”"],
+ ["αž€αžΆαžšαž“αž·αž™αžΆαž™αž“αŸ…αž…αŸ†αž–αŸ„αŸ‡αž˜αž»αžαžŸαžΆαž’αžΆαžšαžŽαŸˆ αž’αžΆαž…αž˜αžΆαž“αž€αžΆαžšαž–αž·αž”αžΆαž€αŸ” <groan> αž”αŸ‰αž»αž“αŸ’αžαŸ‚αž”αžΎαž αžΆαžαŸ‹ αž‚αŸαž’αžΆαž…αž’αŸ’αžœαžΎαž”αžΆαž“αŸ”"],
  ]
+ # Available voices (commented out for simpler UI)
+ # VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe", "jing", "Elise"]
+ # Available Emotive Tags
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
+ # Create Gradio interface
  with gr.Blocks(title="Khmer Text-to-Speech") as demo:
      gr.Markdown(f"""
  # 🎡 Khmer Text-to-Speech
  **αž˜αŸ‰αžΌαžŠαŸ‚αž›αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‡αžΆαžŸαŸ†αž›αŸαž„**
 
 
  αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€ αž αžΎαž™αžŸαŸ’αžαžΆαž”αŸ‹αž€αžΆαžšαž”αž˜αŸ’αž›αŸ‚αž„αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”
+
  πŸ’‘ **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
+     """)

      text_input = gr.Textbox(
+ label="Enter Khmer text (αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžš)",
+ placeholder="αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€αž“αŸ…αž‘αžΈαž“αŸαŸ‡...",
+         lines=4
      )

+     # Voice selector (commented out)
+     # voice = gr.Dropdown(
+     #     choices=VOICES,
+     #     value="tara",
+ # label="Voice (αžŸαŸ†αž›αŸαž„)"
+     # )

+     # Advanced Settings
      with gr.Accordion("🔧 Advanced Settings", open=False):
          with gr.Row():
              temperature = gr.Slider(
                  minimum=0.1, maximum=1.5, value=0.6, step=0.05,
+                 label="Temperature",
+                 info="Higher values create more expressive speech"
              )
              top_p = gr.Slider(
                  minimum=0.1, maximum=1.0, value=0.95, step=0.05,
+                 label="Top P",
+                 info="Nucleus sampling threshold"
              )
          with gr.Row():
              repetition_penalty = gr.Slider(
                  minimum=1.0, maximum=2.0, value=1.1, step=0.05,
+                 label="Repetition Penalty",
+                 info="Higher values discourage repetitive patterns"
              )
              max_new_tokens = gr.Slider(
+                 minimum=100, maximum=2000, value=1200, step=100,
+                 label="Max Length",
+                 info="Maximum length of generated audio"
              )

      with gr.Row():
          clear_btn = gr.Button("🗑️ Clear", size="lg")

      audio_output = gr.Audio(
+ label="Generated Speech (αžŸαŸ†αž›αŸαž„αžŠαŸ‚αž›αž”αž„αŸ’αž€αžΎαžαž‘αžΎαž„)",
          type="numpy",
          show_label=True
      )

+     # Set up examples (NO CACHE)
      gr.Examples(
          examples=examples,
          inputs=[text_input],
          cache_examples=False,
      )

+     # Set up event handlers
      submit_btn.click(
          fn=generate_speech,
+         inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens],
          outputs=audio_output
      )

          inputs=[],
          outputs=[text_input, audio_output]
      )
+ # Launch the app
  if __name__ == "__main__":
+     demo.queue().launch(share=False)
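
For reference, a minimal client-side sketch of exercising the updated generate_speech endpoint (text plus the four sampling controls). The Space id and api_name below are placeholders, not taken from this commit; confirm both with client.view_api() before relying on them.

    # Hypothetical usage sketch; Space id and endpoint name are assumptions.
    from gradio_client import Client

    client = Client("your-username/your-space")  # placeholder Space id
    result = client.predict(
        "Hello from the Khmer TTS demo",   # text
        0.6,                               # temperature
        0.95,                              # top_p
        1.1,                               # repetition_penalty
        1200,                              # max_new_tokens
        api_name="/generate_speech",       # assumed endpoint name; verify with client.view_api()
    )
    print(result)  # typically a local path to the generated audio file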