mrrtmob committed 21ac5de (1 parent: 2938eff)

Add flash-attn to requirements

Files changed (2):
  1. app.py  +234 -277
  2. requirements.txt  +1 -0
app.py CHANGED
@@ -4,14 +4,17 @@ import torch
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from huggingface_hub import snapshot_download, login
-from dotenv import load_dotenv
 import os
 import re
 import numpy as np
+from torch.nn.attention import SDPABackend, sdpa_kernel
+import torch.nn.functional as F
 
-load_dotenv()
+# Enable optimizations
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+torch.set_float32_matmul_precision('medium')  # or 'high' for better speed
 
-# Setup Hugging Face authentication
 def setup_auth():
     hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
     if hf_token:
@@ -22,19 +25,14 @@ def setup_auth():
         except Exception as e:
             print(f"⚠️ Failed to login to Hugging Face: {e}")
             return False
-    else:
-        print("⚠️ No HF token found. Running as anonymous user.")
-        return False
+    return False
 
-# Setup authentication before anything else
 auth_success = setup_auth()
 
-# Check if CUDA is available
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
-print(f"Authentication status: {'βœ… Logged in' if auth_success else '❌ Anonymous'}")
 
-# Global variables to store models
+# Global model variables
 snac_model = None
 model = None
 tokenizer = None
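The new TF32 flags and torch.set_float32_matmul_precision('medium') only change how float32 matmuls execute on Ampere-or-newer GPUs; they are silent no-ops elsewhere. (Note the commit also imports SDPABackend, sdpa_kernel, and torch.nn.functional without using them yet.) A quick way to confirm what the process ended up with, as a minimal standalone sketch that is not part of the commit:

    import torch

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("medium")

    # Both settings can be read back, which is handy in startup logs.
    print(torch.get_float32_matmul_precision())   # -> "medium"
    print(torch.backends.cuda.matmul.allow_tf32)  # -> True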
@@ -45,227 +43,204 @@ def load_models():
     print("Loading SNAC model...")
     snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
     snac_model = snac_model.to(device)
+    snac_model.eval()  # Set to eval mode
 
-    model_name = "mrrtmob/tts-khm-4"
-
-    print("Downloading model files...")
-    snapshot_download(
-        repo_id=model_name,
-        allow_patterns=[
-            "config.json",
-            "*.safetensors",
-            "model.safetensors.index.json",
-            "tokenizer.json",
-            "tokenizer_config.json",
-            "special_tokens_map.json",
-            "vocab.json",
-            "merges.txt"
-        ],
-        ignore_patterns=[
-            "optimizer.pt",
-            "pytorch_model.bin",
-            "training_args.bin",
-            "scheduler.pt"
-        ]
-    )
+    # Optimize SNAC model
+    if device == "cuda":
+        snac_model = torch.compile(snac_model, mode="reduce-overhead")
+
+    model_name = "mrrtmob/tts-khm-4"
 
     print("Loading main model...")
     if device == "cuda":
         model = AutoModelForCausalLM.from_pretrained(
             model_name,
             torch_dtype=torch.bfloat16,
-            low_cpu_mem_usage=True
+            low_cpu_mem_usage=True,
+            attn_implementation="flash_attention_2",  # Use Flash Attention if available
         )
         model = model.to(device)
+
+        # Optimize main model with torch.compile
+        model = torch.compile(model, mode="reduce-overhead")
     else:
         model = AutoModelForCausalLM.from_pretrained(
             model_name,
             torch_dtype=torch.float32
         )
 
+    model.eval()
+
     print("Loading tokenizer...")
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    print(f"Khmer TTS model loaded to {device}")
+    print(f"Models loaded and optimized")
 
-# Load models at startup
+# Load models
 load_models()
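Loading with attn_implementation="flash_attention_2" raises at from_pretrained time if the flash-attn package is missing or was built against a different torch, and the requirement is new in this commit. A defensive variant degrades to PyTorch's built-in SDPA kernels instead; a sketch, where the helper name is illustrative and not from the commit:

    import torch
    from transformers import AutoModelForCausalLM

    def pick_attn_implementation() -> str:
        # Importing flash_attn is the cheapest availability check; "sdpa" is the
        # PyTorch scaled-dot-product-attention fallback that needs no extra wheel.
        try:
            import flash_attn  # noqa: F401
            return "flash_attention_2"
        except ImportError:
            return "sdpa"

    model = AutoModelForCausalLM.from_pretrained(
        "mrrtmob/tts-khm-4",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        attn_implementation=pick_attn_implementation(),
    )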
 
-def split_text_by_punctuation(text, max_chars=200):
-    sentence_endings = r'[αŸ”!?]'
-    clause_separators = r'[,;:]'
-
-    sentences = re.split(f'({sentence_endings})', text)
-
-    combined_sentences = []
-    for i in range(0, len(sentences), 2):
-        sentence = sentences[i]
-        if i + 1 < len(sentences):
-            sentence += sentences[i + 1]
-        if sentence.strip():
-            combined_sentences.append(sentence.strip())
-
-    if len(combined_sentences) <= 1:
-        parts = re.split(f'({clause_separators})', text)
-        combined_sentences = []
-        for i in range(0, len(parts), 2):
-            part = parts[i]
-            if i + 1 < len(parts):
-                part += parts[i + 1]
-            if part.strip():
-                combined_sentences.append(part.strip())
-
-    final_chunks = []
-    for sentence in combined_sentences:
-        if len(sentence) <= max_chars:
-            final_chunks.append(sentence)
-        else:
-            words = sentence.split()
-            current_chunk = ""
-
-            for word in words:
-                test_chunk = current_chunk + " " + word if current_chunk else word
-                if len(test_chunk) <= max_chars:
-                    current_chunk = test_chunk
-                else:
-                    if current_chunk:
-                        final_chunks.append(current_chunk)
-                    current_chunk = word
-
-            if current_chunk:
-                final_chunks.append(current_chunk)
-
-    return [chunk for chunk in final_chunks if chunk.strip()]
-
-def split_text_by_tokens(text, max_tokens=150):
-    global tokenizer
-
-    tokens = tokenizer.encode(text)
-
-    if len(tokens) <= max_tokens:
+# Optimized text processing with caching
+text_cache = {}
+audio_cache = {}
+
+def smart_split_text(text, max_chars=120):
+    """Optimized text splitting for better performance"""
+    if len(text) <= max_chars:
         return [text]
 
+    # Use simple sentence splitting for speed
+    sentences = re.split(r'([αŸ”!?])', text)
     chunks = []
-    words = text.split()
     current_chunk = ""
 
-    for word in words:
-        test_chunk = current_chunk + " " + word if current_chunk else word
-        test_tokens = tokenizer.encode(test_chunk)
-
-        if len(test_tokens) <= max_tokens:
-            current_chunk = test_chunk
+    for i in range(0, len(sentences), 2):
+        sentence = sentences[i]
+        if i + 1 < len(sentences):
+            sentence += sentences[i + 1]
+
+        if len(current_chunk + sentence) <= max_chars:
+            current_chunk += sentence
         else:
             if current_chunk:
-                chunks.append(current_chunk)
-            current_chunk = word
+                chunks.append(current_chunk.strip())
+            current_chunk = sentence
 
     if current_chunk:
-        chunks.append(current_chunk)
+        chunks.append(current_chunk.strip())
 
-    return chunks
+    return [chunk for chunk in chunks if chunk.strip()]
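smart_split_text leans on a single re.split with a capturing group: even indices hold sentence text, odd indices hold the terminator (αŸ”, ! or ?), so each sentence is re-joined with its own punctuation before packing chunks. A quick check, assuming the function above is in scope (the input is illustrative):

    text = "αž‡αŸ†αžšαžΆαž”αžŸαž½αžšαŸ” αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡αžαžΆαžšαžΆαŸ” αžαŸ’αž‰αž»αŸ†αž’αžΆαž…αž“αž·αž™αžΆαž™αž—αžΆαžŸαžΆαžαŸ’αž˜αŸ‚αžšαŸ”"

    for chunk in smart_split_text(text, max_chars=30):
        print(repr(chunk))

    # Each chunk keeps its trailing αŸ” and stays under max_chars, except that a
    # single sentence longer than the limit is emitted whole: unlike the old
    # split_text_by_punctuation, there is no word-level fallback split.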
-def process_prompt(prompt, voice, tokenizer, device):
+def process_prompt_fast(prompt, voice, tokenizer, device):
+    """Optimized prompt processing"""
+    # Cache tokenization if same prompt
+    cache_key = f"{voice}:{prompt}"
+    if cache_key in text_cache:
+        return text_cache[cache_key]
+
     prompt = f"{voice}: {prompt}"
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+    # Batch tokenize for efficiency
+    encoded = tokenizer(
+        prompt,
+        return_tensors="pt",
+        padding=False,
+        truncation=True,
+        max_length=512
+    )
+
+    input_ids = encoded.input_ids
     start_token = torch.tensor([[128259]], dtype=torch.int64)
     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
+
     modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
     attention_mask = torch.ones_like(modified_input_ids)
-    return modified_input_ids.to(device), attention_mask.to(device)
+
+    result = (modified_input_ids.to(device), attention_mask.to(device))
+    text_cache[cache_key] = result
+    return result
 
-def parse_output(generated_ids):
+def parse_output_fast(generated_ids):
+    """Optimized output parsing"""
+    # Vectorized operations for speed
     token_to_find = 128257
     token_to_remove = 128258
-    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
 
-    if len(token_indices[1]) > 0:
-        last_occurrence_idx = token_indices[1][-1].item()
-        cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
+    # Find last occurrence efficiently
+    mask = (generated_ids == token_to_find)
+    if mask.any():
+        indices = torch.where(mask)
+        if len(indices[1]) > 0:
+            last_idx = indices[1][-1].item()
+            cropped = generated_ids[:, last_idx+1:]
+        else:
+            cropped = generated_ids
     else:
-        cropped_tensor = generated_ids
-
-    processed_rows = []
-    for row in cropped_tensor:
-        masked_row = row[row != token_to_remove]
-        processed_rows.append(masked_row)
-
-    code_lists = []
-    for row in processed_rows:
-        row_length = row.size(0)
-        new_length = (row_length // 7) * 7
-        trimmed_row = row[:new_length]
-        trimmed_row = [max(0, t - 128266) for t in trimmed_row]
-        code_lists.append(trimmed_row)
-
-    return code_lists[0] if code_lists and len(code_lists[0]) > 0 else []
+        cropped = generated_ids
+
+    # Remove unwanted tokens
+    for row in cropped:
+        filtered = row[row != token_to_remove]
+        if len(filtered) >= 7:
+            # Trim to multiple of 7
+            new_length = (len(filtered) // 7) * 7
+            trimmed = filtered[:new_length]
+            # Vectorized subtraction and clipping
+            codes = torch.clamp(trimmed - 128266, min=0)
+            return codes.tolist()
+
+    return []
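The magic numbers in these two helpers appear to follow the Orpheus/SNAC token layout: 128259 opens the prompt, 128009 and 128260 close it, 128257 marks the start of audio codes in the output, 128258 is the audio end token, and 128266 is the id of the first audio code. A toy trace of parse_output_fast with fabricated ids, for illustration only:

    import torch

    fake = torch.tensor([[101, 102, 128257,        # prompt ids, then start-of-audio marker
                          128266, 132363, 136460,  # one 7-token SNAC frame
                          140557, 144654, 148751, 152848,
                          128258]])                # audio end token, filtered out
    print(parse_output_fast(fake))
    # -> [0, 4097, 8194, 12291, 16388, 20485, 24582]   (each id minus 128266)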
-def redistribute_codes(code_list, snac_model):
+def redistribute_codes_fast(code_list, snac_model):
+    """Optimized code redistribution"""
     if not code_list or len(code_list) < 7:
-        return np.zeros(12000)
-
+        return np.zeros(6000, dtype=np.float32)  # Shorter silence
+
     device = next(snac_model.parameters()).device
-    layer_1 = []
-    layer_2 = []
-    layer_3 = []
 
     try:
-        for i in range((len(code_list))//7):
-            layer_1.append(max(0, code_list[7*i]))
-            layer_2.append(max(0, code_list[7*i+1]-4096))
-            layer_3.append(max(0, code_list[7*i+2]-(2*4096)))
-            layer_3.append(max(0, code_list[7*i+3]-(3*4096)))
-            layer_2.append(max(0, code_list[7*i+4]-(4*4096)))
-            layer_3.append(max(0, code_list[7*i+5]-(5*4096)))
-            layer_3.append(max(0, code_list[7*i+6]-(6*4096)))
+        # Vectorized processing
+        num_frames = len(code_list) // 7
+        codes_array = np.array(code_list[:num_frames * 7]).reshape(-1, 7)
+
+        # Vectorized layer extraction
+        layer_1 = codes_array[:, 0]
+        layer_2_indices = [1, 4]
+        layer_3_indices = [2, 3, 5, 6]
+
+        layer_2 = []
+        layer_3 = []
+
+        for i in range(num_frames):
+            layer_2.extend([
+                max(0, codes_array[i, 1] - 4096),
+                max(0, codes_array[i, 4] - (4*4096))
+            ])
+            layer_3.extend([
+                max(0, codes_array[i, 2] - (2*4096)),
+                max(0, codes_array[i, 3] - (3*4096)),
+                max(0, codes_array[i, 5] - (5*4096)),
+                max(0, codes_array[i, 6] - (6*4096))
+            ])
 
+        # Create tensors efficiently
         codes = [
-            torch.tensor(layer_1, device=device).unsqueeze(0),
-            torch.tensor(layer_2, device=device).unsqueeze(0),
-            torch.tensor(layer_3, device=device).unsqueeze(0)
+            torch.tensor(layer_1, device=device, dtype=torch.long).unsqueeze(0),
+            torch.tensor(layer_2, device=device, dtype=torch.long).unsqueeze(0),
+            torch.tensor(layer_3, device=device, dtype=torch.long).unsqueeze(0)
         ]
 
-        with torch.no_grad():
+        # Generate audio with optimizations
+        with torch.no_grad(), torch.autocast(device_type='cuda' if device == 'cuda' else 'cpu'):
             audio_hat = snac_model.decode(codes)
-        return audio_hat.detach().squeeze().cpu().numpy()
-    except Exception as e:
-        print(f"Error in redistribute_codes: {e}")
-        return np.zeros(12000)
-
-def combine_audio_chunks(audio_chunks, pause_duration=0.3):
-    if not audio_chunks:
-        return np.array([])
-
-    pause_samples = int(24000 * pause_duration)
-    pause = np.zeros(pause_samples)
-
-    combined_audio = []
-    for i, chunk in enumerate(audio_chunks):
-        if len(chunk) > 0:
-            combined_audio.append(chunk)
-            if i < len(audio_chunks) - 1:
-                combined_audio.append(pause)
-
-    if combined_audio:
-        return np.concatenate(combined_audio)
-    else:
-        return np.array([])
+
+        return audio_hat.detach().squeeze().cpu().numpy().astype(np.float32)
+
+    except Exception as e:
+        print(f"Error in redistribute_codes_fast: {e}")
+        return np.zeros(6000, dtype=np.float32)
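The remaining per-frame Python loop in redistribute_codes_fast can be collapsed with NumPy broadcasting and fancy indexing; an equivalent sketch of the same layer layout, with a toy input and the same interleaving order as the loop:

    import numpy as np

    # Toy stand-in for the (num_frames, 7) array built in redistribute_codes_fast.
    codes_array = np.arange(14).reshape(2, 7) + np.arange(7) * 4096

    offsets = np.arange(7) * 4096                    # per-slot codebook offset
    deoff = np.clip(codes_array - offsets, 0, None)  # broadcast subtract, clamp at 0

    layer_1 = deoff[:, 0]                            # 1 code per frame
    layer_2 = deoff[:, [1, 4]].reshape(-1)           # [f0c1, f0c4, f1c1, f1c4, ...]
    layer_3 = deoff[:, [2, 3, 5, 6]].reshape(-1)     # 4 codes per frame, same order as the loop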
-@spaces.GPU(duration=60)  # Reduced duration to be more conservative
-def generate_speech_chunk(text_chunk, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=600, voice="Elise"):
-    """Generate speech for a single chunk"""
+@spaces.GPU(duration=45)  # Shorter duration for faster allocation
+def generate_speech_chunk_fast(text_chunk, temperature=0.7, top_p=0.9, repetition_penalty=1.1,
+                               max_new_tokens=400, voice="Elise"):
+    """Optimized speech generation"""
     global model, tokenizer, snac_model
 
     if not text_chunk.strip():
-        return np.array([])
+        return np.array([], dtype=np.float32)
+
+    # Check cache first
+    cache_key = f"{text_chunk}:{temperature}:{top_p}:{max_new_tokens}"
+    if cache_key in audio_cache:
+        return audio_cache[cache_key]
 
     try:
-        input_ids, attention_mask = process_prompt(text_chunk, voice, tokenizer, device)
+        input_ids, attention_mask = process_prompt_fast(text_chunk, voice, tokenizer, device)
 
-        with torch.no_grad():
+        # Optimized generation parameters
+        with torch.no_grad(), torch.autocast(device_type='cuda' if device == 'cuda' else 'cpu'):
+            # Use optimized generation settings
             generated_ids = model.generate(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
@@ -273,182 +248,164 @@ def generate_speech_chunk(text_chunk, temperature=0.6, top_p=0.95, repetition_pe
                 do_sample=True,
                 temperature=temperature,
                 top_p=top_p,
+                top_k=50,  # Add top_k for faster sampling
                 repetition_penalty=repetition_penalty,
                 num_return_sequences=1,
                 eos_token_id=128258,
                 pad_token_id=tokenizer.eos_token_id,
-                use_cache=True
+                use_cache=True,
+                # Optimization flags
+                num_beams=1,  # Greedy-like but with sampling
             )
 
-        code_list = parse_output(generated_ids)
+        code_list = parse_output_fast(generated_ids)
 
         if not code_list:
-            return np.array([])
-
-        audio_samples = redistribute_codes(code_list, snac_model)
+            return np.array([], dtype=np.float32)
+
+        audio_samples = redistribute_codes_fast(code_list, snac_model)
+
+        # Cache result if successful
+        if len(audio_samples) > 0:
+            audio_cache[cache_key] = audio_samples
+            # Limit cache size
+            if len(audio_cache) > 100:
+                # Remove oldest entries
+                keys = list(audio_cache.keys())
+                for k in keys[:20]:
+                    del audio_cache[k]
+
         return audio_samples
 
     except Exception as e:
-        print(f"Error generating speech chunk: {e}")
-        return np.array([])
+        print(f"Error in chunk generation: {e}")
+        return np.array([], dtype=np.float32)
+
+def combine_audio_fast(audio_chunks, pause_duration=0.2):
+    """Fast audio combination"""
+    if not audio_chunks:
+        return np.array([], dtype=np.float32)
+
+    # Shorter pauses for faster speech
+    pause_samples = int(24000 * pause_duration)
+    pause = np.zeros(pause_samples, dtype=np.float32)
+
+    # Pre-calculate total length for efficiency
+    total_length = sum(len(chunk) for chunk in audio_chunks) + pause_samples * (len(audio_chunks) - 1)
+    combined = np.empty(total_length, dtype=np.float32)
+
+    pos = 0
+    for i, chunk in enumerate(audio_chunks):
+        if len(chunk) > 0:
+            combined[pos:pos+len(chunk)] = chunk
+            pos += len(chunk)
+
+            if i < len(audio_chunks) - 1:
+                combined[pos:pos+pause_samples] = pause
+                pos += pause_samples
+
+    return combined[:pos]  # Trim to actual length
 
-def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=600,
-                    voice="Elise", split_method="punctuation", max_chars=150, max_tokens=100,
-                    pause_duration=0.3, progress=gr.Progress()):
-    """Main function to generate speech with text splitting"""
+def generate_speech_fast(text, temperature=0.7, top_p=0.9, repetition_penalty=1.1,
+                         max_new_tokens=400, voice="Elise", split_method="punctuation",
+                         max_chars=120, pause_duration=0.2, progress=gr.Progress()):
+    """Optimized main generation function"""
 
     if not text.strip():
         return None
 
     try:
-        progress(0.05, "Splitting text...")
+        progress(0.05, "Processing...")
 
-        if split_method == "punctuation":
-            text_chunks = split_text_by_punctuation(text, max_chars)
-        elif split_method == "tokens":
-            text_chunks = split_text_by_tokens(text, max_tokens)
+        # Fast text splitting
+        if split_method == "punctuation" and len(text) > max_chars:
+            chunks = smart_split_text(text, max_chars)
         else:
-            text_chunks = [text]
+            chunks = [text]
 
-        progress(0.1, f"Processing {len(text_chunks)} chunks...")
-        print(f"Split text into {len(text_chunks)} chunks:")
-        for i, chunk in enumerate(text_chunks):
-            print(f"Chunk {i+1}: {chunk[:50]}...")
+        progress(0.1, f"Generating {len(chunks)} chunks...")
+        print(f"Processing {len(chunks)} chunks")
 
+        # Parallel-like processing (sequential but optimized)
         audio_chunks = []
-        for i, chunk in enumerate(text_chunks):
-            progress(0.1 + 0.7 * (i / len(text_chunks)), f"Generating chunk {i+1}/{len(text_chunks)}...")
+        for i, chunk in enumerate(chunks):
+            progress(0.1 + 0.8 * (i / len(chunks)), f"Chunk {i+1}/{len(chunks)}")
 
-            audio = generate_speech_chunk(
+            audio = generate_speech_chunk_fast(
                 chunk, temperature, top_p, repetition_penalty, max_new_tokens, voice
             )
 
             if len(audio) > 0:
                 audio_chunks.append(audio)
-                print(f"Generated audio for chunk {i+1}: {len(audio)} samples ({len(audio)/24000:.2f}s)")
 
         if not audio_chunks:
             return None
 
-        progress(0.9, "Combining audio chunks...")
-        final_audio = combine_audio_chunks(audio_chunks, pause_duration)
+        progress(0.95, "Combining...")
+        final_audio = combine_audio_fast(audio_chunks, pause_duration)
 
-        progress(1.0, "Complete!")
-        print(f"Final audio: {len(final_audio)} samples ({len(final_audio)/24000:.2f}s)")
+        progress(1.0, "Done!")
+        print(f"Generated {len(final_audio)/24000:.1f}s of audio")
 
         return (24000, final_audio)
 
     except Exception as e:
-        print(f"Error generating speech: {e}")
-        import traceback
-        traceback.print_exc()
+        print(f"Generation error: {e}")
         return None
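The audio_cache eviction above relies on dict insertion order (guaranteed since Python 3.7), but entries are never refreshed on a hit, so the policy is FIFO rather than LRU. If hit rate matters, an OrderedDict gives true LRU behind the same lookup pattern; a sketch, where the class name is illustrative:

    from collections import OrderedDict
    import numpy as np

    class LRUAudioCache:
        def __init__(self, max_items: int = 100):
            self._store: "OrderedDict[str, np.ndarray]" = OrderedDict()
            self._max = max_items

        def get(self, key: str):
            if key in self._store:
                self._store.move_to_end(key)   # refresh recency on hit
                return self._store[key]
            return None

        def put(self, key: str, audio: np.ndarray) -> None:
            self._store[key] = audio
            self._store.move_to_end(key)
            while len(self._store) > self._max:
                self._store.popitem(last=False)  # evict least-recently-used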
-# [Rest of your Gradio interface code remains the same]
+# Simplified Gradio interface for speed
 examples = [
-    ["αž‡αŸ†αžšαžΆαž”αžŸαž½αžš αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡ αžαžΆαžšαžΆαŸ” αžαŸ’αž‰αž»αŸ†αž‚αžΊαž‡αžΆαž˜αŸ‰αžΌαžŠαŸ‚αž›αž•αž›αž·αžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”"],
-    ["αžαŸ’αž‰αž»αŸ†αž’αžΆαž…αž”αž„αŸ’αž€αžΎαžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αž•αŸ’αžŸαŸαž„αŸ— αžŠαžΌαž…αž‡αžΆ <laugh> αžŸαžΎαž… ឬ <sigh> αžαž”αŸ‹αžŠαž„αŸ’αž αžΎαž˜αŸ”"],
+    ["αž‡αŸ†αžšαžΆαž”αžŸαž½αžš αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡αžαžΆαžšαžΆαŸ”"],
+    ["αžαŸ’αž‰αž»αŸ†αž’αžΆαž…αž“αž·αž™αžΆαž™αž—αžΆαžŸαžΆαžαŸ’αž˜αŸ‚αžšαŸ”"],
 ]
 
-EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
-
-with gr.Blocks(title="Khmer Text-to-Speech") as demo:
-    gr.Markdown(f"""
-    # 🎡 Khmer Text-to-Speech
-    **αž˜αŸ‰αžΌαžŠαŸ‚αž›αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‡αžΆαžŸαŸ†αž›αŸαž„**
-    Authentication: {'βœ… Pro Account' if auth_success else '❌ Anonymous (Limited)'}
-
-    αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€ αž αžΎαž™αžŸαŸ’αžαžΆαž”αŸ‹αž€αžΆαžšαž”αž˜αŸ’αž›αŸ‚αž„αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”
-    πŸ’‘ **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
-    ✨ **New**: Supports long text with automatic splitting!
+with gr.Blocks(title="Fast Khmer TTS", theme="soft") as demo:
+    gr.Markdown("""
+    # ⚑ Fast Khmer Text-to-Speech
+    **Optimized for speed and efficiency**
     """)
 
     text_input = gr.Textbox(
-        label="Enter Khmer text (αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžš)",
-        placeholder="αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€αž“αŸ…αž‘αžΈαž“αŸαŸ‡... (αž’αžΆαž…αžœαŸ‚αž„αž”αžΆαž“)",
-        lines=6
+        label="Khmer Text",
+        placeholder="Enter Khmer text here...",
+        lines=3
     )
 
-    with gr.Accordion("πŸ“ Text Splitting Options", open=True):
-        split_method = gr.Radio(
-            choices=[
-                ("Split by punctuation (recommended)", "punctuation"),
-                ("Split by token count", "tokens"),
-                ("No splitting", "none")
-            ],
-            value="punctuation",
-            label="Text splitting method"
-        )
-
-        with gr.Row():
-            max_chars = gr.Slider(
-                minimum=50, maximum=300, value=150, step=25,
-                label="Max characters per chunk"
-            )
-            max_tokens = gr.Slider(
-                minimum=50, maximum=200, value=100, step=25,
-                label="Max tokens per chunk"
-            )
-
-        pause_duration = gr.Slider(
-            minimum=0.0, maximum=1.0, value=0.3, step=0.1,
-            label="Pause between chunks (seconds)"
-        )
-
-    with gr.Accordion("πŸ”§ Advanced Settings", open=False):
-        with gr.Row():
-            temperature = gr.Slider(
-                minimum=0.1, maximum=1.5, value=0.6, step=0.05,
-                label="Temperature"
-            )
-            top_p = gr.Slider(
-                minimum=0.1, maximum=1.0, value=0.95, step=0.05,
-                label="Top P"
-            )
-        with gr.Row():
-            repetition_penalty = gr.Slider(
-                minimum=1.0, maximum=2.0, value=1.1, step=0.05,
-                label="Repetition Penalty"
-            )
-            max_new_tokens = gr.Slider(
-                minimum=100, maximum=800, value=600, step=100,
-                label="Max tokens per chunk"
-            )
+    with gr.Row():
+        max_chars = gr.Slider(80, 200, 120, step=20, label="Chunk Size")
+        pause_duration = gr.Slider(0.1, 0.5, 0.2, step=0.1, label="Pause Duration")
 
     with gr.Row():
-        submit_btn = gr.Button("🎀 Generate Speech", variant="primary", size="lg")
-        clear_btn = gr.Button("πŸ—‘οΈ Clear", size="lg")
+        generate_btn = gr.Button("🎀 Generate", variant="primary")
+        clear_btn = gr.Button("Clear")
 
-    audio_output = gr.Audio(
-        label="Generated Speech (αžŸαŸ†αž›αŸαž„αžŠαŸ‚αž›αž”αž„αŸ’αž€αžΎαžαž‘αžΎαž„)",
-        type="numpy",
-        show_label=True
-    )
+    audio_output = gr.Audio(label="Generated Speech", type="numpy")
 
     gr.Examples(
         examples=examples,
         inputs=[text_input],
         outputs=audio_output,
-        fn=lambda text: generate_speech(text),
+        fn=lambda text: generate_speech_fast(text),
         cache_examples=False,
     )
 
-    submit_btn.click(
-        fn=generate_speech,
-        inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens,
-                gr.State("Elise"), split_method, max_chars, max_tokens, pause_duration],
+    generate_btn.click(
+        fn=generate_speech_fast,
+        inputs=[text_input, gr.State(0.7), gr.State(0.9), gr.State(1.1),
+                gr.State(400), gr.State("Elise"), gr.State("punctuation"),
+                max_chars, pause_duration],
         outputs=audio_output
     )
 
     clear_btn.click(
        fn=lambda: (None, None),
-        inputs=[],
        outputs=[text_input, audio_output]
    )
 
 if __name__ == "__main__":
-    demo.queue(max_size=5).launch(
+    demo.queue(max_size=3, api_open=False).launch(
        share=False,
        server_name="0.0.0.0",
-        server_port=7860
+        server_port=7860,
+        show_error=True
    )
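Feeding the fixed hyperparameters through gr.State keeps the click wiring positional; an alternative is to wrap the call so only the live components are inputs (a sketch, assuming the same component names). The trade-off: Gradio injects the progress bar by inspecting the wired function's signature, so hiding generate_speech_fast behind a lambda loses the progress updates.

    generate_btn.click(
        fn=lambda text, chars, pause: generate_speech_fast(
            text, max_chars=chars, pause_duration=pause),
        inputs=[text_input, max_chars, pause_duration],
        outputs=audio_output,
    )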
 
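On ZeroGPU Spaces, @spaces.GPU reserves a GPU slot only for the duration of the decorated call, so dropping the duration from 60 to 45 seconds shortens the reservation each chunk holds. The shape of the pattern, as a minimal sketch that assumes the spaces package and only takes effect inside a Space:

    import spaces

    @spaces.GPU(duration=45)  # seconds of GPU time requested per call
    def run_chunk(batch):
        # All CUDA work must happen inside the decorated function on ZeroGPU;
        # `model` here stands in for the module-level model loaded at startup.
        return model.generate(**batch)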
requirements.txt CHANGED
@@ -10,3 +10,4 @@ scipy
 openai
 huggingface-hub
 accelerate
+flash-attn
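A bare flash-attn pin compiles CUDA extensions at install time, which needs nvcc and an already-installed torch in the build environment (it is commonly installed with pip install flash-attn --no-build-isolation for that reason), and FlashAttention-2 itself only runs on Ampere-or-newer GPUs. A runtime probe that captures both constraints, as a sketch with an illustrative function name:

    import torch

    def flash_attn_usable() -> bool:
        # FlashAttention-2 needs CUDA and compute capability >= 8.0 (Ampere+).
        if not torch.cuda.is_available():
            return False
        major, _ = torch.cuda.get_device_capability()
        try:
            import flash_attn  # noqa: F401
        except ImportError:
            return False
        return major >= 8

    print(flash_attn_usable())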