mrrtmob committed on
Commit d1e3c74 · 1 Parent(s): 94f2fb2

No code changes made.

Files changed (1)
  1. app.py +330 -109
app.py CHANGED
@@ -5,114 +5,231 @@ import gradio as gr
  from transformers import AutoModelForCausalLM, AutoTokenizer
  from huggingface_hub import snapshot_download
  from dotenv import load_dotenv
  load_dotenv()
  # Check if CUDA is available
  device = "cuda" if torch.cuda.is_available() else "cpu"
- print("Loading SNAC model...")
- snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
- snac_model = snac_model.to(device)
- model_name = "mrrtmob/tts-khm-4"
- # Download only model config and safetensors
- snapshot_download(
-     repo_id=model_name,
-     allow_patterns=[
-         "config.json",
-         "*.safetensors",
-         "model.safetensors.index.json",
-     ],
-     ignore_patterns=[
-         "optimizer.pt",
-         "pytorch_model.bin",
-         "training_args.bin",
-         "scheduler.pt",
-         "tokenizer.json",
-         "tokenizer_config.json",
-         "special_tokens_map.json",
-         "vocab.json",
-         "merges.txt",
-         "tokenizer.*"
-     ]
- )
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
- model.to(device)
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- print(f"Khmer TTS model loaded to {device}")
- # Process text prompt
- def process_prompt(prompt, voice, tokenizer, device):
-     prompt = f"{voice}: {prompt}"
-     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-     start_token = torch.tensor([[128259]], dtype=torch.int64) # Start of human
-     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64) # End of text, End of human
-     modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1) # SOH SOT Text EOT EOH
-     # No padding needed for single input
-     attention_mask = torch.ones_like(modified_input_ids)
      return modified_input_ids.to(device), attention_mask.to(device)
- # Parse output tokens to audio
  def parse_output(generated_ids):
      token_to_find = 128257
      token_to_remove = 128258
-
      token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
      if len(token_indices[1]) > 0:
          last_occurrence_idx = token_indices[1][-1].item()
          cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
      else:
          cropped_tensor = generated_ids
-
      processed_rows = []
      for row in cropped_tensor:
          masked_row = row[row != token_to_remove]
          processed_rows.append(masked_row)
-
      code_lists = []
      for row in processed_rows:
          row_length = row.size(0)
          new_length = (row_length // 7) * 7
          trimmed_row = row[:new_length]
-         trimmed_row = [t - 128266 for t in trimmed_row]
          code_lists.append(trimmed_row)
-
-     return code_lists[0] # Return just the first one for single sample
- # Redistribute codes for audio generation
- def redistribute_codes(code_list, snac_model):
-     device = next(snac_model.parameters()).device # Get the device of SNAC model
      layer_1 = []
      layer_2 = []
      layer_3 = []
-     for i in range((len(code_list)+1)//7):
-         layer_1.append(code_list[7*i])
-         layer_2.append(code_list[7*i+1]-4096)
-         layer_3.append(code_list[7*i+2]-(2*4096))
-         layer_3.append(code_list[7*i+3]-(3*4096))
-         layer_2.append(code_list[7*i+4]-(4*4096))
-         layer_3.append(code_list[7*i+5]-(5*4096))
-         layer_3.append(code_list[7*i+6]-(6*4096))
-
-     # Move tensors to the same device as the SNAC model
-     codes = [
-         torch.tensor(layer_1, device=device).unsqueeze(0),
-         torch.tensor(layer_2, device=device).unsqueeze(0),
-         torch.tensor(layer_3, device=device).unsqueeze(0)
-     ]
-
-     audio_hat = snac_model.decode(codes)
-     return audio_hat.detach().squeeze().cpu().numpy() # Always return CPU numpy array
- # Main generation function
- @spaces.GPU()
- def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200, voice="Elise", progress=gr.Progress()):
-     if not text.strip():
-         return None
      try:
-         progress(0.1, "Processing text...")
-         input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
-         progress(0.3, "Generating speech tokens...")
          with torch.no_grad():
              generated_ids = model.generate(
                  input_ids=input_ids,
@@ -124,77 +241,177 @@ def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, m
                  repetition_penalty=repetition_penalty,
                  num_return_sequences=1,
                  eos_token_id=128258,
              )
-         progress(0.6, "Processing speech tokens...")
          code_list = parse_output(generated_ids)
-         progress(0.8, "Converting to audio...")
          audio_samples = redistribute_codes(code_list, snac_model)
-         return (24000, audio_samples) # Return sample rate and audio
      except Exception as e:
          print(f"Error generating speech: {e}")
          return None
- # Examples for the UI - Khmer text examples
  examples = [
- ["αž‡αŸ†αžšαžΆαž”αžŸαž½αžš αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡ តអរអ αž αžΎαž™αžαŸ’αž‰αž»αŸ†αž‚αžΊαž‡αžΆαž˜αŸ‰αžΌαžŠαŸ‚αž›αž•αž›αž·αžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”"],
- ["αžαŸ’αž‰αž»αŸ†αž’αžΆαž…αž”αž„αŸ’αž€αžΎαžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αž•αŸ’αžŸαŸαž„αŸ— αžŠαžΌαž…αž‡αžΆ <laugh> αžŸαžΎαž… ឬ <sigh> αžαž”αŸ‹αžŠαž„αŸ’αž αžΎαž˜αŸ”"],
- ["αžαŸ’αž‰αž»αŸ†αžšαžŸαŸ‹αž“αŸ…αž€αŸ’αž“αž»αž„αž‘αžΈαž€αŸ’αžšαž»αž„αž—αŸ’αž“αŸ†αž–αŸαž‰ αž αžΎαž™αž˜αžΆαž“αž…αžšαžΆαž…αžšαžŽαŸ <gasp> αž…αŸ’αžšαžΎαž“αžŽαžΆαžŸαŸ‹αŸ”"],
- ["αž–αŸαž›αžαŸ’αž›αŸ‡ αž–αŸαž›αžαŸ’αž‰αž»αŸ†αž“αž·αž™αžΆαž™αž…αŸ’αžšαžΎοΏ½οΏ½αž–αŸαž€ αžαŸ’αž‰αž»αŸ†αžαŸ’αžšαžΌαžœ <cough> αžŸαž»αŸ†αž‘αŸ„αžŸαŸ”"],
- ["αž€αžΆαžšαž“αž·αž™αžΆαž™αž“αŸ…αž…αŸ†αž–αŸ„αŸ‡αž˜αž»αžαžŸαžΆαž’αžΆαžšαžŽαŸˆ αž’αžΆαž…αž˜αžΆαž“αž€αžΆαžšαž–αž·αž”αžΆαž€αŸ” <groan> αž”αŸ‰αž»αž“αŸ’αžαŸ‚αž”αžΎαž αžΆαžαŸ‹ αž‚αŸαž’αžΆαž…αž’αŸ’αžœαžΎαž”αžΆαž“αŸ”"],
  ]
- # Available voices (commented out for simpler UI)
- # VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe", "jing", "Elise"]
- # Available Emotive Tags
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
  # Create Gradio interface
  with gr.Blocks(title="Khmer Text-to-Speech") as demo:
      gr.Markdown(f"""
      # 🎡 Khmer Text-to-Speech
  **αž˜αŸ‰αžΌαžŠαŸ‚αž›αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‡αžΆαžŸαŸ†αž›αŸαž„**
-
  αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€ αž αžΎαž™αžŸαŸ’αžαžΆαž”αŸ‹αž€αžΆαžšαž”αž˜αŸ’αž›αŸ‚αž„αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”
-
      πŸ’‘ **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
-     """)
      text_input = gr.Textbox(
- label="Enter Khmer text (αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžš)",
- placeholder="αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€αž“αŸ…αž‘αžΈαž“αŸαŸ‡...",
-         lines=4
      )
-     # Voice selector (commented out)
-     # voice = gr.Dropdown(
-     #     choices=VOICES,
-     #     value="tara",
- # label="Voice (αžŸαŸ†αž›αŸαž„)"
-     # )
      # Advanced Settings
      with gr.Accordion("πŸ”§ Advanced Settings", open=False):
          with gr.Row():
              temperature = gr.Slider(
                  minimum=0.1, maximum=1.5, value=0.6, step=0.05,
-                 label="Temperature",
                  info="Higher values create more expressive speech"
              )
              top_p = gr.Slider(
                  minimum=0.1, maximum=1.0, value=0.95, step=0.05,
-                 label="Top P",
                  info="Nucleus sampling threshold"
              )
          with gr.Row():
              repetition_penalty = gr.Slider(
                  minimum=1.0, maximum=2.0, value=1.1, step=0.05,
-                 label="Repetition Penalty",
                  info="Higher values discourage repetitive patterns"
              )
              max_new_tokens = gr.Slider(
-                 minimum=100, maximum=2000, value=1200, step=100,
-                 label="Max Length",
-                 info="Maximum length of generated audio"
              )
      with gr.Row():
@@ -202,12 +419,11 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
          clear_btn = gr.Button("πŸ—‘οΈ Clear", size="lg")
      audio_output = gr.Audio(
- label="Generated Speech (αžŸαŸ†αž›αŸαž„αžŠαŸ‚αž›αž”αž„αŸ’αž€αžΎαžαž‘αžΎαž„)",
          type="numpy",
          show_label=True
      )
-     # Set up examples (NO CACHE)
      gr.Examples(
          examples=examples,
          inputs=[text_input],
@@ -216,10 +432,10 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
          cache_examples=False,
      )
-     # Set up event handlers
      submit_btn.click(
          fn=generate_speech,
-         inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens],
          outputs=audio_output
      )
@@ -228,6 +444,11 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
          inputs=[],
          outputs=[text_input, audio_output]
      )
  # Launch the app
  if __name__ == "__main__":
-     demo.queue().launch(share=False)
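Everything above is the removed side of the diff: the old app.py loaded SNAC, the causal LM, and the tokenizer at import time and synthesized the whole prompt in a single model.generate call. The added side below wraps model loading in load_models(), introduces text-splitting helpers, generates audio per chunk, and concatenates the chunks with configurable pauses.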
  from transformers import AutoModelForCausalLM, AutoTokenizer
  from huggingface_hub import snapshot_download
  from dotenv import load_dotenv
+ import os
+ import re
+ import numpy as np
+
  load_dotenv()
+
  # Check if CUDA is available
  device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ # Global variables to store models
+ snac_model = None
+ model = None
+ tokenizer = None
+
+ def load_models():
+     global snac_model, model, tokenizer
+     print("Loading SNAC model...")
+     snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+     snac_model = snac_model.to(device)
+
+     model_name = "mrrtmob/tts-khm-4"
+
+     # Download specific files
+     print("Downloading model files...")
+     snapshot_download(
+         repo_id=model_name,
+         allow_patterns=[
+             "config.json",
+             "*.safetensors",
+             "model.safetensors.index.json",
+             "tokenizer.json",
+             "tokenizer_config.json",
+             "special_tokens_map.json",
+             "vocab.json",
+             "merges.txt"
+         ],
+         ignore_patterns=[
+             "optimizer.pt",
+             "pytorch_model.bin",
+             "training_args.bin",
+             "scheduler.pt"
+         ]
+     )
+
+     print("Loading main model...")
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.bfloat16,
+         device_map="auto" if device == "cuda" else None
+     )
+
+     print("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     print(f"Khmer TTS model loaded to {device}")
+
+ # Load models at startup
+ load_models()
+ def split_text_by_punctuation(text, max_chars=200):
+     """Split text by punctuation marks, keeping sentences together when possible"""
+     # Khmer and common punctuation
+     sentence_endings = r'[αŸ”!?]'
+     clause_separators = r'[,;:]'
+
+     # First try to split by sentence endings
+     sentences = re.split(f'({sentence_endings})', text)
+
+     # Recombine sentences with their punctuation
+     combined_sentences = []
+     for i in range(0, len(sentences), 2):
+         sentence = sentences[i]
+         if i + 1 < len(sentences):
+             sentence += sentences[i + 1]  # Add the punctuation back
+         if sentence.strip():
+             combined_sentences.append(sentence.strip())
+
+     # If no sentence endings found, split by clauses
+     if len(combined_sentences) <= 1:
+         parts = re.split(f'({clause_separators})', text)
+         combined_sentences = []
+         for i in range(0, len(parts), 2):
+             part = parts[i]
+             if i + 1 < len(parts):
+                 part += parts[i + 1]
+             if part.strip():
+                 combined_sentences.append(part.strip())
+
+     # Further split if sentences are too long
+     final_chunks = []
+     for sentence in combined_sentences:
+         if len(sentence) <= max_chars:
+             final_chunks.append(sentence)
+         else:
+             # Split long sentences by words
+             words = sentence.split()
+             current_chunk = ""
+
+             for word in words:
+                 test_chunk = current_chunk + " " + word if current_chunk else word
+                 if len(test_chunk) <= max_chars:
+                     current_chunk = test_chunk
+                 else:
+                     if current_chunk:
+                         final_chunks.append(current_chunk)
+                     current_chunk = word
+
+             if current_chunk:
+                 final_chunks.append(current_chunk)
+
+     return [chunk for chunk in final_chunks if chunk.strip()]
+
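A quick way to sanity-check the new splitter outside the Space (illustrative snippet, not part of the commit; the sample string and max_chars value are arbitrary, and importing app.py also runs load_models(), which downloads the model):

    # Hypothetical check of split_text_by_punctuation from app.py.
    from app import split_text_by_punctuation
    sample = "αž‡αŸ†αžšαžΆαž”αžŸαž½αžšαŸ” " * 40  # repeat a short Khmer sentence so the text is long
    chunks = split_text_by_punctuation(sample, max_chars=200)
    for i, chunk in enumerate(chunks, 1):
        print(i, len(chunk), chunk[:30])
    # Each repeated sentence comes back as its own chunk, because the splitter
    # only subdivides a sentence further when it exceeds max_chars.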
+ def split_text_by_tokens(text, max_tokens=150):
+     """Split text by token count"""
+     global tokenizer
+
+     # Tokenize the entire text first
+     tokens = tokenizer.encode(text)
+
+     if len(tokens) <= max_tokens:
+         return [text]
+
+     chunks = []
+     words = text.split()
+     current_chunk = ""
+
+     for word in words:
+         test_chunk = current_chunk + " " + word if current_chunk else word
+         test_tokens = tokenizer.encode(test_chunk)
+
+         if len(test_tokens) <= max_tokens:
+             current_chunk = test_chunk
+         else:
+             if current_chunk:
+                 chunks.append(current_chunk)
+             current_chunk = word
+
+     if current_chunk:
+         chunks.append(current_chunk)
+
+     return chunks
+
+ def process_prompt(prompt, voice, tokenizer, device):
+     prompt = f"{voice}: {prompt}"
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+     start_token = torch.tensor([[128259]], dtype=torch.int64)
+     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
+     modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
+     attention_mask = torch.ones_like(modified_input_ids)
      return modified_input_ids.to(device), attention_mask.to(device)
+
  def parse_output(generated_ids):
      token_to_find = 128257
      token_to_remove = 128258
      token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
+
      if len(token_indices[1]) > 0:
          last_occurrence_idx = token_indices[1][-1].item()
          cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
      else:
          cropped_tensor = generated_ids
+
      processed_rows = []
      for row in cropped_tensor:
          masked_row = row[row != token_to_remove]
          processed_rows.append(masked_row)
+
      code_lists = []
      for row in processed_rows:
          row_length = row.size(0)
          new_length = (row_length // 7) * 7
          trimmed_row = row[:new_length]
+         trimmed_row = [max(0, t - 128266) for t in trimmed_row]
          code_lists.append(trimmed_row)
+
+     return code_lists[0] if code_lists and len(code_lists[0]) > 0 else []
+
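For readers skimming the hunk above, the magic numbers in process_prompt and parse_output are the model's special token ids. The names below are my own shorthand, inferred from the comments in the removed code; only the numeric values appear in the commit:

    # Shorthand for the special token ids used above (labels are assumptions, values are from app.py).
    START_OF_HUMAN = 128259   # prepended before the "voice: text" prompt
    END_OF_TEXT = 128009      # appended after the prompt
    END_OF_HUMAN = 128260     # appended after END_OF_TEXT
    START_OF_AUDIO = 128257   # parse_output keeps only tokens after its last occurrence
    END_OF_AUDIO = 128258     # also used as eos_token_id; stripped from the output
    AUDIO_CODE_BASE = 128266  # subtracted to map audio tokens onto SNAC code indices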
+ def redistribute_codes(code_list, snac_model):
+     if not code_list or len(code_list) < 7:
+         return np.zeros(12000)  # 0.5 seconds of silence
+
+     device = next(snac_model.parameters()).device
      layer_1 = []
      layer_2 = []
      layer_3 = []
      try:
+         for i in range((len(code_list))//7):
+             layer_1.append(max(0, code_list[7*i]))
+             layer_2.append(max(0, code_list[7*i+1]-4096))
+             layer_3.append(max(0, code_list[7*i+2]-(2*4096)))
+             layer_3.append(max(0, code_list[7*i+3]-(3*4096)))
+             layer_2.append(max(0, code_list[7*i+4]-(4*4096)))
+             layer_3.append(max(0, code_list[7*i+5]-(5*4096)))
+             layer_3.append(max(0, code_list[7*i+6]-(6*4096)))
+
+         codes = [
+             torch.tensor(layer_1, device=device).unsqueeze(0),
+             torch.tensor(layer_2, device=device).unsqueeze(0),
+             torch.tensor(layer_3, device=device).unsqueeze(0)
+         ]
+
+         with torch.no_grad():
+             audio_hat = snac_model.decode(codes)
+         return audio_hat.detach().squeeze().cpu().numpy()
+     except Exception as e:
+         print(f"Error in redistribute_codes: {e}")
+         return np.zeros(12000)
+
+ @spaces.GPU(duration=120)
+ def generate_speech_chunk(text_chunk, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=800, voice="Elise"):
+     """Generate speech for a single chunk"""
+     global model, tokenizer, snac_model
+
+     if not text_chunk.strip():
+         return np.array([])
+
+     try:
+         input_ids, attention_mask = process_prompt(text_chunk, voice, tokenizer, device)
          with torch.no_grad():
              generated_ids = model.generate(
                  input_ids=input_ids,
                  repetition_penalty=repetition_penalty,
                  num_return_sequences=1,
                  eos_token_id=128258,
+                 pad_token_id=tokenizer.eos_token_id,
+                 use_cache=True
              )
          code_list = parse_output(generated_ids)
+         if not code_list:
+             return np.array([])
+
          audio_samples = redistribute_codes(code_list, snac_model)
+         return audio_samples
+
+     except Exception as e:
+         print(f"Error generating speech chunk: {e}")
+         return np.array([])
+
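The loop in redistribute_codes spreads every group of seven rebased codes across SNAC's three codebook layers. A minimal sketch of that layout for a single frame (illustrative only; the synthetic values are chosen so every offset cancels to zero):

    # One 7-code frame -> three SNAC layers, mirroring the loop above.
    frame = [k * 4096 for k in range(7)]  # synthetic codes, already rebased by -128266
    layer_1 = [frame[0]]                  # 1 code per frame (coarsest layer)
    layer_2 = [frame[1] - 4096, frame[4] - 4 * 4096]
    layer_3 = [frame[2] - 2 * 4096, frame[3] - 3 * 4096,
               frame[5] - 5 * 4096, frame[6] - 6 * 4096]
    assert (len(layer_1), len(layer_2), len(layer_3)) == (1, 2, 4)
    assert layer_1 + layer_2 + layer_3 == [0, 0, 0, 0, 0, 0, 0]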
+ def combine_audio_chunks(audio_chunks, pause_duration=0.3):
+     """Combine audio chunks with pauses between them"""
+     if not audio_chunks:
+         return np.array([])
+
+     # Create pause (silence)
+     pause_samples = int(24000 * pause_duration)  # 24kHz sample rate
+     pause = np.zeros(pause_samples)
+
+     combined_audio = []
+     for i, chunk in enumerate(audio_chunks):
+         if len(chunk) > 0:
+             combined_audio.append(chunk)
+             # Add pause between chunks (except after the last chunk)
+             if i < len(audio_chunks) - 1:
+                 combined_audio.append(pause)
+
+     if combined_audio:
+         return np.concatenate(combined_audio)
+     else:
+         return np.array([])
+
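combine_audio_chunks works at the app's fixed 24 kHz rate, so the default pause_duration of 0.3 s inserts 7200 zero samples between chunks; a quick check (illustrative only):

    import numpy as np
    pause = np.zeros(int(24000 * 0.3))  # same arithmetic as combine_audio_chunks
    assert pause.shape[0] == 7200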
+ def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=800,
+                     voice="Elise", split_method="punctuation", max_chars=200, max_tokens=150,
+                     pause_duration=0.3, progress=gr.Progress()):
+     """Main function to generate speech with text splitting"""
+
+     if not text.strip():
+         return None
+
+     try:
+         # Split text based on selected method
+         progress(0.05, "Splitting text...")
+
+         if split_method == "punctuation":
+             text_chunks = split_text_by_punctuation(text, max_chars)
+         elif split_method == "tokens":
+             text_chunks = split_text_by_tokens(text, max_tokens)
+         else:  # "none"
+             text_chunks = [text]
+
+         progress(0.1, f"Processing {len(text_chunks)} chunks...")
+         print(f"Split text into {len(text_chunks)} chunks:")
+         for i, chunk in enumerate(text_chunks):
+             print(f"Chunk {i+1}: {chunk[:50]}...")
+
+         # Generate audio for each chunk
+         audio_chunks = []
+         for i, chunk in enumerate(text_chunks):
+             progress(0.1 + 0.7 * (i / len(text_chunks)), f"Generating chunk {i+1}/{len(text_chunks)}...")
+
+             audio = generate_speech_chunk(
+                 chunk, temperature, top_p, repetition_penalty, max_new_tokens, voice
+             )
+
+             if len(audio) > 0:
+                 audio_chunks.append(audio)
+                 print(f"Generated audio for chunk {i+1}: {len(audio)} samples ({len(audio)/24000:.2f}s)")
+
+         if not audio_chunks:
+             return None
+
+         # Combine all audio chunks
+         progress(0.9, "Combining audio chunks...")
+         final_audio = combine_audio_chunks(audio_chunks, pause_duration)
+
+         progress(1.0, "Complete!")
+         print(f"Final audio: {len(final_audio)} samples ({len(final_audio)/24000:.2f}s)")
+
+         return (24000, final_audio)
      except Exception as e:
          print(f"Error generating speech: {e}")
+         import traceback
+         traceback.print_exc()
          return None
+
+ # Examples
  examples = [
+ ["αž‡αŸ†αžšαžΆαž”αžŸαž½αžš αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡ Kiri αŸ” αžαŸ’αž‰αž»αŸ†αž‚αžΊαž‡αžΆαž˜αŸ‰αžΌαžŠαŸ‚αž›αž•αž›αž·αžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ” αžαŸ’αž‰αž»αŸ†αž’αžΆαž…αž”αž„αŸ’αž€αžΎαžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αž•αŸ’αžŸαŸαž„αŸ— αžŠαžΌαž…αž‡αžΆ <laugh> αžŸαžΎαž… ឬ <sigh> αžαž”αŸ‹αžŠαž„αŸ’αž αžΎαž˜αŸ” αžαŸ’αž‰αž»αŸ†αžšαžŸαŸ‹αž“αŸ…αž€αŸ’αž“αž»αž„αž‘αžΈαž€αŸ’αžšαž»αž„αž—αŸ’αž“αŸ†αž–αŸαž‰ αž αžΎαž™αž˜αžΆαž“αž…αžšαžΆαž…αžšαžŽαŸ <gasp> αž…αŸ’αžšαžΎαž“αžŽαžΆαžŸαŸ‹αŸ”"],
+ ["αž€αžΆαžšαž“αž·αž™αžΆαž™αž“αŸ…αž…αŸ†αž–αŸ„αŸ‡αž˜αž»αžαžŸαžΆαž’αžΆαžšαžŽαŸˆ αž’αžΆαž…αž˜αžΆαž“αž€αžΆαžšαž–αž·αž”αžΆαž€αŸ” <groan> αž”αŸ‰αž»αž“αŸ’αžαŸ‚αž”αžΎαž αžΆαžαŸ‹ αž‚αŸαž’αžΆαž…αž’αŸ’αžœαžΎαž”αžΆαž“αŸ” αž–αŸαž›αžαŸ’αž›αŸ‡ αž–αŸαž›αžαŸ’αž‰αž»αŸ†αž“αž·αž™αžΆαž™αž…αŸ’αžšαžΎαž“αž–αŸαž€ αžαŸ’αž‰αž»αŸ†αžαŸ’αžšαžΌαžœ <cough> αžŸαž»αŸ†αž‘αŸ„αžŸαŸ” αžœαžΆαž‡αžΆαžšαžΏαž„αž’αž˜αŸ’αž˜αžαžΆαŸ”"],
  ]
+
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
+
  # Create Gradio interface
  with gr.Blocks(title="Khmer Text-to-Speech") as demo:
      gr.Markdown(f"""
      # 🎡 Khmer Text-to-Speech
  **αž˜αŸ‰αžΌαžŠαŸ‚αž›αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‡αžΆαžŸαŸ†αž›αŸαž„**
  αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€ αž αžΎαž™αžŸαŸ’αžαžΆαž”αŸ‹αž€αžΆαžšαž”αž˜αŸ’αž›αŸ‚αž„αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”
  πŸ’‘ **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
+     ✨ **New**: Supports long text with automatic splitting!
+     """)
      text_input = gr.Textbox(
+ label="Enter Khmer text (αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžš)",
+ placeholder="αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€αž“αŸ…αž‘αžΈαž“αŸαŸ‡... (αž’αžΆαž…αžœαŸ‚αž„αž”αžΆαž“)",
+         lines=6
      )
+     # Text splitting options
+     with gr.Accordion("πŸ“ Text Splitting Options", open=True):
+         split_method = gr.Radio(
+             choices=[
+                 ("Split by punctuation (recommended)", "punctuation"),
+                 ("Split by token count", "tokens"),
+                 ("No splitting", "none")
+             ],
+             value="punctuation",
+             label="Text splitting method",
+             info="For long texts, splitting helps avoid the 15s limit"
+         )
+
+         with gr.Row():
+             max_chars = gr.Slider(
+                 minimum=50, maximum=500, value=200, step=25,
+                 label="Max characters per chunk (punctuation mode)",
+                 info="Shorter chunks = more natural breaks but more processing time"
+             )
+             max_tokens = gr.Slider(
+                 minimum=50, maximum=300, value=150, step=25,
+                 label="Max tokens per chunk (token mode)",
+                 info="Controls chunk size based on model tokenization"
+             )
+
+         pause_duration = gr.Slider(
+             minimum=0.0, maximum=1.0, value=0.3, step=0.1,
+             label="Pause between chunks (seconds)",
+             info="Silence duration between text chunks"
+         )
      # Advanced Settings
      with gr.Accordion("πŸ”§ Advanced Settings", open=False):
          with gr.Row():
              temperature = gr.Slider(
                  minimum=0.1, maximum=1.5, value=0.6, step=0.05,
+                 label="Temperature",
                  info="Higher values create more expressive speech"
              )
              top_p = gr.Slider(
                  minimum=0.1, maximum=1.0, value=0.95, step=0.05,
+                 label="Top P",
                  info="Nucleus sampling threshold"
              )
          with gr.Row():
              repetition_penalty = gr.Slider(
                  minimum=1.0, maximum=2.0, value=1.1, step=0.05,
+                 label="Repetition Penalty",
                  info="Higher values discourage repetitive patterns"
              )
              max_new_tokens = gr.Slider(
+                 minimum=100, maximum=1200, value=800, step=100,
+                 label="Max tokens per chunk",
+                 info="Lower values for shorter, more reliable generation"
              )
      with gr.Row():
          clear_btn = gr.Button("πŸ—‘οΈ Clear", size="lg")
      audio_output = gr.Audio(
+ label="Generated Speech (αžŸαŸ†αž›αŸαž„αžŠαŸ‚αž›αž”αž„αŸ’αž€αžΎαžαž‘αžΎαž„)",
          type="numpy",
          show_label=True
      )
      gr.Examples(
          examples=examples,
          inputs=[text_input],
          cache_examples=False,
      )
      submit_btn.click(
          fn=generate_speech,
+         inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens,
+                 gr.State("Elise"), split_method, max_chars, max_tokens, pause_duration],
          outputs=audio_output
      )

          inputs=[],
          outputs=[text_input, audio_output]
      )
+
  # Launch the app
  if __name__ == "__main__":
+     demo.queue(max_size=10).launch(
+         share=False,
+         server_name="0.0.0.0",
+         server_port=7860
+     )
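Putting the added pieces together, the new pipeline is: split the text, synthesize each chunk, then concatenate with silence. A minimal offline driver might look like the sketch below. It is not part of the commit; it assumes app.py is importable (importing it runs load_models() and builds the Blocks UI), that the @spaces.GPU decorator acts as a pass-through outside a Space, and that soundfile is installed:

    # Hypothetical offline driver for the chunked pipeline in app.py.
    import soundfile as sf  # assumed extra dependency, not added by this commit
    from app import (split_text_by_punctuation, generate_speech_chunk,
                     combine_audio_chunks)

    text = "..."  # long Khmer input goes here
    chunks = split_text_by_punctuation(text, max_chars=200)
    audio = combine_audio_chunks(
        [generate_speech_chunk(c) for c in chunks],
        pause_duration=0.3,
    )
    sf.write("output.wav", audio, 24000)  # the app's fixed sample rate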