mrrtmob committed
Commit 844e3a3 · 1 Parent(s): 21ac5de

Remove flash-attn from requirements

Files changed (2):
  1. app.py +277 -234
  2. requirements.txt +1 -2
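Note on the change: the old `app.py` requested `attn_implementation="flash_attention_2"` unconditionally, which raises an error whenever the `flash-attn` package is missing (it needs a CUDA toolchain and a matching PyTorch build to install), so this commit drops both the requirement and the argument. For reference, a minimal sketch of keeping FlashAttention optional instead of removing it entirely; this is an illustration, not part of the commit, and it assumes the checkpoint's architecture accepts the `attn_implementation` argument:

```python
# Hypothetical fallback (not in this commit): ask for FlashAttention-2 only when
# the flash_attn package is importable, otherwise use PyTorch's built-in SDPA.
import importlib.util

import torch
from transformers import AutoModelForCausalLM

attn_impl = "flash_attention_2" if importlib.util.find_spec("flash_attn") else "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    "mrrtmob/tts-khm-4",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation=attn_impl,  # assumes the model supports this argument
)
```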
app.py CHANGED
Old version (lines removed by this commit are prefixed with -):

@@ -4,17 +4,14 @@ import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download, login
import os
import re
import numpy as np
- from torch.nn.attention import SDPABackend, sdpa_kernel
- import torch.nn.functional as F

- # Enable optimizations
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
- torch.set_float32_matmul_precision('medium')  # or 'high' for better speed

def setup_auth():
    hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
    if hf_token:
@@ -25,14 +22,19 @@ def setup_auth():
        except Exception as e:
            print(f"⚠️ Failed to login to Hugging Face: {e}")
            return False
-     return False

auth_success = setup_auth()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

- # Global model variables
snac_model = None
model = None
tokenizer = None
@@ -43,204 +45,227 @@ def load_models():
    print("Loading SNAC model...")
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    snac_model = snac_model.to(device)
-     snac_model.eval()  # Set to eval mode
-
-     # Optimize SNAC model
-     if device == "cuda":
-         snac_model = torch.compile(snac_model, mode="reduce-overhead")

    model_name = "mrrtmob/tts-khm-4"

    print("Loading main model...")
    if device == "cuda":
        model = AutoModelForCausalLM.from_pretrained(
-             model_name,
            torch_dtype=torch.bfloat16,
-             low_cpu_mem_usage=True,
-             attn_implementation="flash_attention_2",  # Use Flash Attention if available
        )
        model = model.to(device)
-
-         # Optimize main model with torch.compile
-         model = torch.compile(model, mode="reduce-overhead")
    else:
        model = AutoModelForCausalLM.from_pretrained(
-             model_name,
            torch_dtype=torch.float32
        )

-     model.eval()
-
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

-     print(f"Models loaded and optimized")

- # Load models
load_models()

- # Optimized text processing with caching
- text_cache = {}
- audio_cache = {}
-
- def smart_split_text(text, max_chars=120):
-     """Optimized text splitting for better performance"""
-     if len(text) <= max_chars:
-         return [text]
-
-     # Use simple sentence splitting for speed
-     sentences = re.split(r'([αŸ”!?])', text)
-     chunks = []
-     current_chunk = ""

    for i in range(0, len(sentences), 2):
        sentence = sentences[i]
        if i + 1 < len(sentences):
            sentence += sentences[i + 1]

-         if len(current_chunk + sentence) <= max_chars:
-             current_chunk += sentence
        else:
            if current_chunk:
-                 chunks.append(current_chunk.strip())
-             current_chunk = sentence

    if current_chunk:
-         chunks.append(current_chunk.strip())

-     return [chunk for chunk in chunks if chunk.strip()]

- def process_prompt_fast(prompt, voice, tokenizer, device):
-     """Optimized prompt processing"""
-     # Cache tokenization if same prompt
-     cache_key = f"{voice}:{prompt}"
-     if cache_key in text_cache:
-         return text_cache[cache_key]
-
    prompt = f"{voice}: {prompt}"
-
-     # Batch tokenize for efficiency
-     encoded = tokenizer(
-         prompt,
-         return_tensors="pt",
-         padding=False,
-         truncation=True,
-         max_length=512
-     )
-
-     input_ids = encoded.input_ids
    start_token = torch.tensor([[128259]], dtype=torch.int64)
    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
-
    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
    attention_mask = torch.ones_like(modified_input_ids)
-
-     result = (modified_input_ids.to(device), attention_mask.to(device))
-     text_cache[cache_key] = result
-     return result

- def parse_output_fast(generated_ids):
-     """Optimized output parsing"""
-     # Vectorized operations for speed
    token_to_find = 128257
    token_to_remove = 128258

-     # Find last occurrence efficiently
-     mask = (generated_ids == token_to_find)
-     if mask.any():
-         indices = torch.where(mask)
-         if len(indices[1]) > 0:
-             last_idx = indices[1][-1].item()
-             cropped = generated_ids[:, last_idx+1:]
-         else:
-             cropped = generated_ids
    else:
-         cropped = generated_ids
-
-     # Remove unwanted tokens
-     for row in cropped:
-         filtered = row[row != token_to_remove]
-         if len(filtered) >= 7:
-             # Trim to multiple of 7
-             new_length = (len(filtered) // 7) * 7
-             trimmed = filtered[:new_length]
-             # Vectorized subtraction and clipping
-             codes = torch.clamp(trimmed - 128266, min=0)
-             return codes.tolist()
-
-     return []

- def redistribute_codes_fast(code_list, snac_model):
-     """Optimized code redistribution"""
    if not code_list or len(code_list) < 7:
-         return np.zeros(6000, dtype=np.float32)  # Shorter silence
-
    device = next(snac_model.parameters()).device

    try:
-         # Vectorized processing
-         num_frames = len(code_list) // 7
-         codes_array = np.array(code_list[:num_frames * 7]).reshape(-1, 7)

-         # Vectorized layer extraction
-         layer_1 = codes_array[:, 0]
-         layer_2_indices = [1, 4]
-         layer_3_indices = [2, 3, 5, 6]
-
-         layer_2 = []
-         layer_3 = []
-
-         for i in range(num_frames):
-             layer_2.extend([
-                 max(0, codes_array[i, 1] - 4096),
-                 max(0, codes_array[i, 4] - (4*4096))
-             ])
-             layer_3.extend([
-                 max(0, codes_array[i, 2] - (2*4096)),
-                 max(0, codes_array[i, 3] - (3*4096)),
-                 max(0, codes_array[i, 5] - (5*4096)),
-                 max(0, codes_array[i, 6] - (6*4096))
-             ])
-
-         # Create tensors efficiently
        codes = [
-             torch.tensor(layer_1, device=device, dtype=torch.long).unsqueeze(0),
-             torch.tensor(layer_2, device=device, dtype=torch.long).unsqueeze(0),
-             torch.tensor(layer_3, device=device, dtype=torch.long).unsqueeze(0)
        ]

-         # Generate audio with optimizations
-         with torch.no_grad(), torch.autocast(device_type='cuda' if device == 'cuda' else 'cpu'):
            audio_hat = snac_model.decode(codes)
-
-         return audio_hat.detach().squeeze().cpu().numpy().astype(np.float32)
-
    except Exception as e:
-         print(f"Error in redistribute_codes_fast: {e}")
-         return np.zeros(6000, dtype=np.float32)

- @spaces.GPU(duration=45)  # Shorter duration for faster allocation
- def generate_speech_chunk_fast(text_chunk, temperature=0.7, top_p=0.9, repetition_penalty=1.1,
-                                max_new_tokens=400, voice="Elise"):
-     """Optimized speech generation"""
    global model, tokenizer, snac_model

    if not text_chunk.strip():
-         return np.array([], dtype=np.float32)
-
-     # Check cache first
-     cache_key = f"{text_chunk}:{temperature}:{top_p}:{max_new_tokens}"
-     if cache_key in audio_cache:
-         return audio_cache[cache_key]

    try:
-         input_ids, attention_mask = process_prompt_fast(text_chunk, voice, tokenizer, device)

-         # Optimized generation parameters
-         with torch.no_grad(), torch.autocast(device_type='cuda' if device == 'cuda' else 'cpu'):
-             # Use optimized generation settings
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
@@ -248,164 +273,182 @@ def generate_speech_chunk_fast(text_chunk, temperature=0.7, top_p=0.9, repetitio
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
-                 top_k=50,  # Add top_k for faster sampling
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=128258,
                pad_token_id=tokenizer.eos_token_id,
-                 use_cache=True,
-                 # Optimization flags
-                 num_beams=1,  # Greedy-like but with sampling
            )

-         code_list = parse_output_fast(generated_ids)

        if not code_list:
-             return np.array([], dtype=np.float32)
-
-         audio_samples = redistribute_codes_fast(code_list, snac_model)
-
-         # Cache result if successful
-         if len(audio_samples) > 0:
-             audio_cache[cache_key] = audio_samples
-             # Limit cache size
-             if len(audio_cache) > 100:
-                 # Remove oldest entries
-                 keys = list(audio_cache.keys())
-                 for k in keys[:20]:
-                     del audio_cache[k]
-
        return audio_samples

    except Exception as e:
-         print(f"Error in chunk generation: {e}")
-         return np.array([], dtype=np.float32)
-
- def combine_audio_fast(audio_chunks, pause_duration=0.2):
-     """Fast audio combination"""
-     if not audio_chunks:
-         return np.array([], dtype=np.float32)
-
-     # Shorter pauses for faster speech
-     pause_samples = int(24000 * pause_duration)
-     pause = np.zeros(pause_samples, dtype=np.float32)
-
-     # Pre-calculate total length for efficiency
-     total_length = sum(len(chunk) for chunk in audio_chunks) + pause_samples * (len(audio_chunks) - 1)
-     combined = np.empty(total_length, dtype=np.float32)
-
-     pos = 0
-     for i, chunk in enumerate(audio_chunks):
-         if len(chunk) > 0:
-             combined[pos:pos+len(chunk)] = chunk
-             pos += len(chunk)
-
-             if i < len(audio_chunks) - 1:
-                 combined[pos:pos+pause_samples] = pause
-                 pos += pause_samples
-
-     return combined[:pos]  # Trim to actual length

- def generate_speech_fast(text, temperature=0.7, top_p=0.9, repetition_penalty=1.1,
-                          max_new_tokens=400, voice="Elise", split_method="punctuation",
-                          max_chars=120, pause_duration=0.2, progress=gr.Progress()):
-     """Optimized main generation function"""

    if not text.strip():
        return None

    try:
-         progress(0.05, "Processing...")

-         # Fast text splitting
-         if split_method == "punctuation" and len(text) > max_chars:
-             chunks = smart_split_text(text, max_chars)
        else:
-             chunks = [text]

-         progress(0.1, f"Generating {len(chunks)} chunks...")
-         print(f"Processing {len(chunks)} chunks")

-         # Parallel-like processing (sequential but optimized)
        audio_chunks = []
-         for i, chunk in enumerate(chunks):
-             progress(0.1 + 0.8 * (i / len(chunks)), f"Chunk {i+1}/{len(chunks)}")

-             audio = generate_speech_chunk_fast(
                chunk, temperature, top_p, repetition_penalty, max_new_tokens, voice
            )

            if len(audio) > 0:
                audio_chunks.append(audio)

        if not audio_chunks:
            return None

-         progress(0.95, "Combining...")
-         final_audio = combine_audio_fast(audio_chunks, pause_duration)

-         progress(1.0, "Done!")
-         print(f"Generated {len(final_audio)/24000:.1f}s of audio")

        return (24000, final_audio)

    except Exception as e:
-         print(f"Generation error: {e}")
        return None

- # Simplified Gradio interface for speed
examples = [
-     ["αž‡αŸ†αžšαžΆαž”αžŸαž½αžš αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡αžαžΆαžšαžΆαŸ”"],
-     ["αžαŸ’αž‰αž»αŸ†αž’αžΆαž…αž“αž·αž™αžΆαž™αž—αžΆαžŸαžΆαžαŸ’αž˜αŸ‚αžšαŸ”"],
]

- with gr.Blocks(title="Fast Khmer TTS", theme="soft") as demo:
-     gr.Markdown("""
-     # ⚡ Fast Khmer Text-to-Speech
-     **Optimized for speed and efficiency**
    """)

    text_input = gr.Textbox(
-         label="Khmer Text",
-         placeholder="Enter Khmer text here...",
-         lines=3
    )

-     with gr.Row():
-         max_chars = gr.Slider(80, 200, 120, step=20, label="Chunk Size")
-         pause_duration = gr.Slider(0.1, 0.5, 0.2, step=0.1, label="Pause Duration")

    with gr.Row():
-         generate_btn = gr.Button("🎀 Generate", variant="primary")
-         clear_btn = gr.Button("Clear")

-     audio_output = gr.Audio(label="Generated Speech", type="numpy")

    gr.Examples(
        examples=examples,
        inputs=[text_input],
        outputs=audio_output,
-         fn=lambda text: generate_speech_fast(text),
        cache_examples=False,
    )

-     generate_btn.click(
-         fn=generate_speech_fast,
-         inputs=[text_input, gr.State(0.7), gr.State(0.9), gr.State(1.1),
-                 gr.State(400), gr.State("Elise"), gr.State("punctuation"),
-                 max_chars, pause_duration],
        outputs=audio_output
    )

    clear_btn.click(
        fn=lambda: (None, None),
        outputs=[text_input, audio_output]
    )

if __name__ == "__main__":
-     demo.queue(max_size=3, api_open=False).launch(
        share=False,
        server_name="0.0.0.0",
-         server_port=7860,
-         show_error=True
    )
New version (lines added by this commit are prefixed with +):

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download, login
+ from dotenv import load_dotenv
import os
import re
import numpy as np

+ load_dotenv()

+ # Setup Hugging Face authentication
def setup_auth():
    hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
    if hf_token:

        except Exception as e:
            print(f"⚠️ Failed to login to Hugging Face: {e}")
            return False
+     else:
+         print("⚠️ No HF token found. Running as anonymous user.")
+         return False

+ # Setup authentication before anything else
auth_success = setup_auth()

+ # Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
+ print(f"Authentication status: {'✅ Logged in' if auth_success else '❌ Anonymous'}")

+ # Global variables to store models
snac_model = None
model = None
tokenizer = None

    print("Loading SNAC model...")
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    snac_model = snac_model.to(device)

    model_name = "mrrtmob/tts-khm-4"

+     print("Downloading model files...")
+     snapshot_download(
+         repo_id=model_name,
+         allow_patterns=[
+             "config.json",
+             "*.safetensors",
+             "model.safetensors.index.json",
+             "tokenizer.json",
+             "tokenizer_config.json",
+             "special_tokens_map.json",
+             "vocab.json",
+             "merges.txt"
+         ],
+         ignore_patterns=[
+             "optimizer.pt",
+             "pytorch_model.bin",
+             "training_args.bin",
+             "scheduler.pt"
+         ]
+     )
+
    print("Loading main model...")
    if device == "cuda":
        model = AutoModelForCausalLM.from_pretrained(
+             model_name,
            torch_dtype=torch.bfloat16,
+             low_cpu_mem_usage=True
        )
        model = model.to(device)
    else:
        model = AutoModelForCausalLM.from_pretrained(
+             model_name,
            torch_dtype=torch.float32
        )

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

+     print(f"Khmer TTS model loaded to {device}")

+ # Load models at startup
load_models()
+ def split_text_by_punctuation(text, max_chars=200):
+     sentence_endings = r'[αŸ”!?]'
+     clause_separators = r'[,;:]'
+
+     sentences = re.split(f'({sentence_endings})', text)
+
+     combined_sentences = []
    for i in range(0, len(sentences), 2):
        sentence = sentences[i]
        if i + 1 < len(sentences):
            sentence += sentences[i + 1]
+         if sentence.strip():
+             combined_sentences.append(sentence.strip())
+
+     if len(combined_sentences) <= 1:
+         parts = re.split(f'({clause_separators})', text)
+         combined_sentences = []
+         for i in range(0, len(parts), 2):
+             part = parts[i]
+             if i + 1 < len(parts):
+                 part += parts[i + 1]
+             if part.strip():
+                 combined_sentences.append(part.strip())
+
+     final_chunks = []
+     for sentence in combined_sentences:
+         if len(sentence) <= max_chars:
+             final_chunks.append(sentence)
+         else:
+             words = sentence.split()
+             current_chunk = ""
+
+             for word in words:
+                 test_chunk = current_chunk + " " + word if current_chunk else word
+                 if len(test_chunk) <= max_chars:
+                     current_chunk = test_chunk
+                 else:
+                     if current_chunk:
+                         final_chunks.append(current_chunk)
+                     current_chunk = word
+
+             if current_chunk:
+                 final_chunks.append(current_chunk)
+
+     return [chunk for chunk in final_chunks if chunk.strip()]
+
+ def split_text_by_tokens(text, max_tokens=150):
+     global tokenizer
+
+     tokens = tokenizer.encode(text)
+
+     if len(tokens) <= max_tokens:
+         return [text]
+
+     chunks = []
+     words = text.split()
+     current_chunk = ""
+
+     for word in words:
+         test_chunk = current_chunk + " " + word if current_chunk else word
+         test_tokens = tokenizer.encode(test_chunk)

+         if len(test_tokens) <= max_tokens:
+             current_chunk = test_chunk
        else:
            if current_chunk:
+                 chunks.append(current_chunk)
+             current_chunk = word

    if current_chunk:
+         chunks.append(current_chunk)

+     return chunks

+ def process_prompt(prompt, voice, tokenizer, device):
    prompt = f"{voice}: {prompt}"
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    start_token = torch.tensor([[128259]], dtype=torch.int64)
    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
    attention_mask = torch.ones_like(modified_input_ids)
+     return modified_input_ids.to(device), attention_mask.to(device)

+ def parse_output(generated_ids):
    token_to_find = 128257
    token_to_remove = 128258
+     token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)

+     if len(token_indices[1]) > 0:
+         last_occurrence_idx = token_indices[1][-1].item()
+         cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
    else:
+         cropped_tensor = generated_ids
+
+     processed_rows = []
+     for row in cropped_tensor:
+         masked_row = row[row != token_to_remove]
+         processed_rows.append(masked_row)
+
+     code_lists = []
+     for row in processed_rows:
+         row_length = row.size(0)
+         new_length = (row_length // 7) * 7
+         trimmed_row = row[:new_length]
+         trimmed_row = [max(0, t - 128266) for t in trimmed_row]
+         code_lists.append(trimmed_row)
+
+     return code_lists[0] if code_lists and len(code_lists[0]) > 0 else []
+ def redistribute_codes(code_list, snac_model):
    if not code_list or len(code_list) < 7:
+         return np.zeros(12000)
+
    device = next(snac_model.parameters()).device
+     layer_1 = []
+     layer_2 = []
+     layer_3 = []

    try:
+         for i in range((len(code_list))//7):
+             layer_1.append(max(0, code_list[7*i]))
+             layer_2.append(max(0, code_list[7*i+1]-4096))
+             layer_3.append(max(0, code_list[7*i+2]-(2*4096)))
+             layer_3.append(max(0, code_list[7*i+3]-(3*4096)))
+             layer_2.append(max(0, code_list[7*i+4]-(4*4096)))
+             layer_3.append(max(0, code_list[7*i+5]-(5*4096)))
+             layer_3.append(max(0, code_list[7*i+6]-(6*4096)))

        codes = [
+             torch.tensor(layer_1, device=device).unsqueeze(0),
+             torch.tensor(layer_2, device=device).unsqueeze(0),
+             torch.tensor(layer_3, device=device).unsqueeze(0)
        ]

+         with torch.no_grad():
            audio_hat = snac_model.decode(codes)
+         return audio_hat.detach().squeeze().cpu().numpy()
    except Exception as e:
+         print(f"Error in redistribute_codes: {e}")
+         return np.zeros(12000)

+ def combine_audio_chunks(audio_chunks, pause_duration=0.3):
+     if not audio_chunks:
+         return np.array([])
+
+     pause_samples = int(24000 * pause_duration)
+     pause = np.zeros(pause_samples)
+
+     combined_audio = []
+     for i, chunk in enumerate(audio_chunks):
+         if len(chunk) > 0:
+             combined_audio.append(chunk)
+             if i < len(audio_chunks) - 1:
+                 combined_audio.append(pause)
+
+     if combined_audio:
+         return np.concatenate(combined_audio)
+     else:
+         return np.array([])
+
+ @spaces.GPU(duration=60)  # Reduced duration to be more conservative
+ def generate_speech_chunk(text_chunk, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=600, voice="Elise"):
+     """Generate speech for a single chunk"""
    global model, tokenizer, snac_model

    if not text_chunk.strip():
+         return np.array([])

    try:
+         input_ids, attention_mask = process_prompt(text_chunk, voice, tokenizer, device)

+         with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,

                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=128258,
                pad_token_id=tokenizer.eos_token_id,
+                 use_cache=True
            )

+         code_list = parse_output(generated_ids)

        if not code_list:
+             return np.array([])
+
+         audio_samples = redistribute_codes(code_list, snac_model)
        return audio_samples

    except Exception as e:
+         print(f"Error generating speech chunk: {e}")
+         return np.array([])
+ def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=600,
+                     voice="Elise", split_method="punctuation", max_chars=150, max_tokens=100,
+                     pause_duration=0.3, progress=gr.Progress()):
+     """Main function to generate speech with text splitting"""

    if not text.strip():
        return None

    try:
+         progress(0.05, "Splitting text...")

+         if split_method == "punctuation":
+             text_chunks = split_text_by_punctuation(text, max_chars)
+         elif split_method == "tokens":
+             text_chunks = split_text_by_tokens(text, max_tokens)
        else:
+             text_chunks = [text]

+         progress(0.1, f"Processing {len(text_chunks)} chunks...")
+         print(f"Split text into {len(text_chunks)} chunks:")
+         for i, chunk in enumerate(text_chunks):
+             print(f"Chunk {i+1}: {chunk[:50]}...")

        audio_chunks = []
+         for i, chunk in enumerate(text_chunks):
+             progress(0.1 + 0.7 * (i / len(text_chunks)), f"Generating chunk {i+1}/{len(text_chunks)}...")

+             audio = generate_speech_chunk(
                chunk, temperature, top_p, repetition_penalty, max_new_tokens, voice
            )

            if len(audio) > 0:
                audio_chunks.append(audio)
+                 print(f"Generated audio for chunk {i+1}: {len(audio)} samples ({len(audio)/24000:.2f}s)")

        if not audio_chunks:
            return None

+         progress(0.9, "Combining audio chunks...")
+         final_audio = combine_audio_chunks(audio_chunks, pause_duration)

+         progress(1.0, "Complete!")
+         print(f"Final audio: {len(final_audio)} samples ({len(final_audio)/24000:.2f}s)")

        return (24000, final_audio)

    except Exception as e:
+         print(f"Error generating speech: {e}")
+         import traceback
+         traceback.print_exc()
        return None

+ # [Rest of your Gradio interface code remains the same]
examples = [
+     ["αž‡αŸ†αžšαžΆαž”αžŸαž½αžš αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡ αžαžΆαžšαžΆαŸ” αžαŸ’αž‰αž»αŸ†αž‚αžΊαž‡αžΆαž˜αŸ‰αžΌαžŠαŸ‚αž›αž•αž›αž·αžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”"],
+     ["αžαŸ’αž‰αž»αŸ†αž’αžΆαž…αž”αž„αŸ’αž€αžΎαžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αž•αŸ’αžŸαŸαž„αŸ— αžŠαžΌαž…αž‡αžΆ <laugh> αžŸαžΎαž… ឬ <sigh> αžαž”αŸ‹αžŠαž„αŸ’αž αžΎαž˜αŸ”"],
]

+ EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
+
+ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
+     gr.Markdown(f"""
+     # 🎡 Khmer Text-to-Speech
+     **αž˜αŸ‰αžΌαžŠαŸ‚αž›αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‡αžΆαžŸαŸ†αž›αŸαž„**
+     Authentication: {'✅ Pro Account' if auth_success else '❌ Anonymous (Limited)'}
+
+     αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€ αž αžΎαž™αžŸαŸ’αžαžΆαž”αŸ‹αž€αžΆαžšαž”αž˜αŸ’αž›αŸ‚αž„αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”
+     💡 **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
+     ✨ **New**: Supports long text with automatic splitting!
    """)

    text_input = gr.Textbox(
+         label="Enter Khmer text (αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžš)",
+         placeholder="αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€αž“αŸ…αž‘αžΈαž“αŸαŸ‡... (αž’αžΆαž…αžœαŸ‚αž„αž”αžΆαž“)",
+         lines=6
    )

+     with gr.Accordion("📝 Text Splitting Options", open=True):
+         split_method = gr.Radio(
+             choices=[
+                 ("Split by punctuation (recommended)", "punctuation"),
+                 ("Split by token count", "tokens"),
+                 ("No splitting", "none")
+             ],
+             value="punctuation",
+             label="Text splitting method"
+         )
+
+         with gr.Row():
+             max_chars = gr.Slider(
+                 minimum=50, maximum=300, value=150, step=25,
+                 label="Max characters per chunk"
+             )
+             max_tokens = gr.Slider(
+                 minimum=50, maximum=200, value=100, step=25,
+                 label="Max tokens per chunk"
+             )
+
+         pause_duration = gr.Slider(
+             minimum=0.0, maximum=1.0, value=0.3, step=0.1,
+             label="Pause between chunks (seconds)"
+         )
+
+     with gr.Accordion("🔧 Advanced Settings", open=False):
+         with gr.Row():
+             temperature = gr.Slider(
+                 minimum=0.1, maximum=1.5, value=0.6, step=0.05,
+                 label="Temperature"
+             )
+             top_p = gr.Slider(
+                 minimum=0.1, maximum=1.0, value=0.95, step=0.05,
+                 label="Top P"
+             )
+         with gr.Row():
+             repetition_penalty = gr.Slider(
+                 minimum=1.0, maximum=2.0, value=1.1, step=0.05,
+                 label="Repetition Penalty"
+             )
+             max_new_tokens = gr.Slider(
+                 minimum=100, maximum=800, value=600, step=100,
+                 label="Max tokens per chunk"
+             )

    with gr.Row():
+         submit_btn = gr.Button("🎀 Generate Speech", variant="primary", size="lg")
+         clear_btn = gr.Button("🗑️ Clear", size="lg")

+     audio_output = gr.Audio(
+         label="Generated Speech (αžŸαŸ†αž›αŸαž„αžŠαŸ‚αž›αž”αž„αŸ’αž€αžΎαžαž‘αžΎαž„)",
+         type="numpy",
+         show_label=True
+     )

    gr.Examples(
        examples=examples,
        inputs=[text_input],
        outputs=audio_output,
+         fn=lambda text: generate_speech(text),
        cache_examples=False,
    )

+     submit_btn.click(
+         fn=generate_speech,
+         inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens,
+                 gr.State("Elise"), split_method, max_chars, max_tokens, pause_duration],
        outputs=audio_output
    )

    clear_btn.click(
        fn=lambda: (None, None),
+         inputs=[],
        outputs=[text_input, audio_output]
    )

if __name__ == "__main__":
+     demo.queue(max_size=5).launch(
        share=False,
        server_name="0.0.0.0",
+         server_port=7860
    )
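Usage note (assumption, not part of the diff): the new app.py calls `load_dotenv()` before `setup_auth()`, and `setup_auth()` reads `HF_TOKEN` or `HUGGINGFACE_HUB_TOKEN` via `os.getenv()`. When running the Space locally, the token can therefore be supplied through the environment or a `.env` file next to `app.py`; the value below is a placeholder:

```python
# Local-run sketch: export the token so setup_auth() can find it, then start the app.
# Putting the same line in a .env file works too, since load_dotenv() reads it at startup.
import os

os.environ.setdefault("HF_TOKEN", "hf_xxxxxxxxxxxxxxxx")  # placeholder, not a real token
```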
requirements.txt CHANGED
@@ -9,5 +9,4 @@ gradio
scipy
openai
huggingface-hub
- accelerate
- flash-attn
+ accelerate