mrrtmob committed on
Commit
0100eae
·
1 Parent(s): 8a44aae

Update Khmer TTS model to version 2 and simplify UI components

Browse files
Files changed (1) hide show
  1. app.py +42 -81
app.py CHANGED
@@ -6,15 +6,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
6
  from huggingface_hub import snapshot_download
7
  from dotenv import load_dotenv
8
  load_dotenv()
9
-
10
  # Check if CUDA is available
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  print("Loading SNAC model...")
13
  snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
14
  snac_model = snac_model.to(device)
15
-
16
- model_name = "mrrtmob/tts-khm-1"
17
-
18
  # Download only model config and safetensors
19
  snapshot_download(
20
  repo_id=model_name,
@@ -36,12 +33,10 @@ snapshot_download(
36
  "tokenizer.*"
37
  ]
38
  )
39
-
40
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
41
  model.to(device)
42
  tokenizer = AutoTokenizer.from_pretrained(model_name)
43
  print(f"Khmer TTS model loaded to {device}")
44
-
45
  # Process text prompt
46
  def process_prompt(prompt, voice, tokenizer, device):
47
  prompt = f"{voice}: {prompt}"
@@ -56,7 +51,6 @@ def process_prompt(prompt, voice, tokenizer, device):
56
  attention_mask = torch.ones_like(modified_input_ids)
57
 
58
  return modified_input_ids.to(device), attention_mask.to(device)
59
-
60
  # Parse output tokens to audio
61
  def parse_output(generated_ids):
62
  token_to_find = 128257
@@ -83,7 +77,6 @@ def parse_output(generated_ids):
83
  code_lists.append(trimmed_row)
84
 
85
  return code_lists[0] # Return just the first one for single sample
86
-
87
  # Redistribute codes for audio generation
88
  def redistribute_codes(code_list, snac_model):
89
  device = next(snac_model.parameters()).device # Get the device of SNAC model
@@ -109,10 +102,9 @@ def redistribute_codes(code_list, snac_model):
109
 
110
  audio_hat = snac_model.decode(codes)
111
  return audio_hat.detach().squeeze().cpu().numpy() # Always return CPU numpy array
112
-
113
  # Main generation function
114
  @spaces.GPU()
115
- def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
116
  if not text.strip():
117
  return None
118
 
@@ -144,95 +136,65 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
144
  except Exception as e:
145
  print(f"Error generating speech: {e}")
146
  return None
147
-
148
- # Examples for the UI - Khmer text examples
149
  examples = [
150
- ["ជំរាបសួរ ខ្ញុំឈ្មោះ តារា ហើយខ្ញុំគឺជាម៉ូដែលផលិតសំលេងនិយាយ។", "tara", 0.6, 0.95, 1.1, 1200],
151
- ["ខ្ញុំអាចបង្កើតសំលេងនិយាយផ្សេងៗ ដូចជា <laugh> សើច ឬ <sigh> ថប់ដង្ហើម។", "dan", 0.7, 0.95, 1.1, 1200],
152
- ["ខ្ញុំរស់នៅក្នុងទីក្រុងភ្នំពេញ ហើយមានប៉ារ៉ាម៉ែត្រ <gasp> ច្រើនណាស់។", "leah", 0.6, 0.9, 1.2, 1200],
153
- ["ពេលខ្លះ ពេលខ្ញុំនិយាយច្រើនពេក ខ្ញុំត្រូវ <cough> សុំទោស។", "leo", 0.65, 0.9, 1.1, 1200],
154
- ["ការនិយាយនៅចំពោះមុខសាធារណៈ អាចមានការពិបាក។ <groan> ប៉ុន្តែបើហាត់ហាន គេអាចធ្វើបាន។", "jess", 0.7, 0.95, 1.1, 1200],
155
- ["ការឡើងភ្នំពិតជាហត់ណត់ ប៉ុន្តែទេសភាពពីលើនេះ ពិតជាស្រស់ស្អាត! <sigh> គួរឱ្យធ្វើ។", "mia", 0.65, 0.9, 1.15, 1200],
156
- ["តើអ្នកបានឮរឿងកំប្លែងនេះយ៉ាងណា? <laugh> ខ្ញុំមិនអាចបញ្ឈប់ការសើចបាននោះទេ។", "zac", 0.7, 0.95, 1.1, 1200],
157
- ["បន្ទាប់ពីរត់ម៉ារ៉ាតុងរួច ខ្ញុំហត់ណាស់ <yawn> ហើយត្រូវការសម្រាក។", "zoe", 0.6, 0.95, 1.1, 1200]
158
  ]
159
-
160
- # Available voices
161
- VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe", "jing", "Elise"]
162
-
163
  # Available Emotive Tags
164
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
165
-
166
- # Create Gradio interface
167
  with gr.Blocks(title="Khmer Text-to-Speech") as demo:
168
  gr.Markdown(f"""
169
- # 🎵 Khmer Text-to-Speech (ម៉ូដែលបម្លែងអត្ថបទជាសំលេង)
170
- Enter your Khmer text below and hear it converted to natural-sounding speech.
171
 
172
- បញ្ចូលអត្ថបទខ្មែររបស់អ្នកខាងក្រោម ហើយស្តាប់ការបម្លែងទៅជាសំលេងនិយាយធម្មជាតិ។
173
 
174
- ## Tips for better prompts (គន្លឹះសម្រាប់ការប្រើប្រាស់ដ៏ល្អ):
175
- - Add paralinguistic elements like {", ".join(EMOTIVE_TAGS)} for more human-like speech
176
- - Longer text prompts generally work better than very short phrases
177
- - អត្ថបទវែងជាទូទៅមានលទ្ធផលល្អជាងអត្ថបទខ្លី
178
- - Increasing `repetition_penalty` and `temperature` makes the model speak faster
179
  """)
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  with gr.Row():
182
- with gr.Column(scale=3):
183
- text_input = gr.Textbox(
184
- label="Text to speak (អត្ថបទដើម្បីនិយាយ)",
185
- placeholder="បញ្ចូលអត្ថបទខ្មែររបស់អ្នកនៅទីនេះ...",
186
- lines=5
187
- )
188
- voice = gr.Dropdown(
189
- choices=VOICES,
190
- value="tara",
191
- label="Voice (សំលេង)"
192
- )
193
-
194
- with gr.Accordion("Advanced Settings (ការកំណត់កម្រិតខ្ពស់)", open=False):
195
- temperature = gr.Slider(
196
- minimum=0.1, maximum=1.5, value=0.6, step=0.05,
197
- label="Temperature",
198
- info="Higher values (0.7-1.0) create more expressive but less stable speech"
199
- )
200
- top_p = gr.Slider(
201
- minimum=0.1, maximum=1.0, value=0.95, step=0.05,
202
- label="Top P",
203
- info="Nucleus sampling threshold"
204
- )
205
- repetition_penalty = gr.Slider(
206
- minimum=1.0, maximum=2.0, value=1.1, step=0.05,
207
- label="Repetition Penalty",
208
- info="Higher values discourage repetitive patterns"
209
- )
210
- max_new_tokens = gr.Slider(
211
- minimum=100, maximum=2000, value=1200, step=100,
212
- label="Max Length",
213
- info="Maximum length of generated audio (in tokens)"
214
- )
215
-
216
- with gr.Row():
217
- submit_btn = gr.Button("Generate Speech (បង្កើតសំលេង)", variant="primary")
218
- clear_btn = gr.Button("Clear (លុប)")
219
-
220
- with gr.Column(scale=2):
221
- audio_output = gr.Audio(label="Generated Speech (សំលេងដែលបង្កើតឡើង)", type="numpy")
222
-
223
- # Set up examples
224
  gr.Examples(
225
  examples=examples,
226
- inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
227
  outputs=audio_output,
228
- fn=generate_speech,
229
  cache_examples=True,
230
  )
231
 
232
  # Set up event handlers
233
  submit_btn.click(
234
- fn=generate_speech,
235
- inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
236
  outputs=audio_output
237
  )
238
 
@@ -241,7 +203,6 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
241
  inputs=[],
242
  outputs=[text_input, audio_output]
243
  )
244
-
245
  # Launch the app
246
  if __name__ == "__main__":
247
  demo.queue().launch(share=False)
 
6
  from huggingface_hub import snapshot_download
7
  from dotenv import load_dotenv
8
  load_dotenv()
 
9
  # Check if CUDA is available
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  print("Loading SNAC model...")
12
  snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
13
  snac_model = snac_model.to(device)
14
+ model_name = "mrrtmob/tts-khm-2"
 
 
15
  # Download only model config and safetensors
16
  snapshot_download(
17
  repo_id=model_name,
 
33
  "tokenizer.*"
34
  ]
35
  )
 
36
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
37
  model.to(device)
38
  tokenizer = AutoTokenizer.from_pretrained(model_name)
39
  print(f"Khmer TTS model loaded to {device}")
 
40
  # Process text prompt
41
  def process_prompt(prompt, voice, tokenizer, device):
42
  prompt = f"{voice}: {prompt}"
 
51
  attention_mask = torch.ones_like(modified_input_ids)
52
 
53
  return modified_input_ids.to(device), attention_mask.to(device)
 
54
  # Parse output tokens to audio
55
  def parse_output(generated_ids):
56
  token_to_find = 128257
 
77
  code_lists.append(trimmed_row)
78
 
79
  return code_lists[0] # Return just the first one for single sample
 
80
  # Redistribute codes for audio generation
81
  def redistribute_codes(code_list, snac_model):
82
  device = next(snac_model.parameters()).device # Get the device of SNAC model
 
102
 
103
  audio_hat = snac_model.decode(codes)
104
  return audio_hat.detach().squeeze().cpu().numpy() # Always return CPU numpy array
 
105
  # Main generation function
106
  @spaces.GPU()
107
+ def generate_speech(text, voice="tara", temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200, progress=gr.Progress()):
108
  if not text.strip():
109
  return None
110
 
 
136
  except Exception as e:
137
  print(f"Error generating speech: {e}")
138
  return None
139
+ # Examples for the UI - Khmer text examples (simplified)
 
140
  examples = [
141
+ ["ជំរាបសួរ ខ្ញុំឈ្មោះ តារា ហើយខ្ញុំគឺជាម៉ូដែលផលិតសំលេងនិយាយ។"],
142
+ ["ខ្ញុំអាចបង្កើតសំលេងនិយាយផ្សេងៗ ដូចជា <laugh> សើច ឬ <sigh> ថប់ដង្ហើម។"],
143
+ ["ខ្ញុំរស់នៅក្នុងទីក្រុងភ្នំពេញ ហើយមានប៉ារ៉ាម៉ែត្រ <gasp> ច្រើនណាស់។"],
144
+ ["ពេលខ្លះ ពេលខ្ញុំនិយាយច្រើនពេក ខ្ញុំត្រូវ <cough> សុំទោស។"],
145
+ ["ការនិយាយនៅចំពោះមុខសាធារណៈ អាចមានការពិបាក។ <groan> ប៉ុន្តែបើហាត់ហាន គេអាចធ្វើបាន។"],
 
 
 
146
  ]
147
+ # Available voices (commented out for simpler UI)
148
+ # VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe", "jing", "Elise"]
 
 
149
  # Available Emotive Tags
150
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
151
+ # Create Gradio interface (simplified)
 
152
  with gr.Blocks(title="Khmer Text-to-Speech") as demo:
153
  gr.Markdown(f"""
154
+ # 🎵 Khmer Text-to-Speech
155
+ **ម៉ូដែលបម្លែងអត្ថបទជាសំលេង**
156
 
157
+ បញ្ចូលអត្ថបទខ្មែររបស់អ្នក ហើយស្តាប់ការបម្លែងទៅជាសំលេងនិយាយ។
158
 
159
+ 💡 **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
 
 
 
 
160
  """)
161
 
162
+ text_input = gr.Textbox(
163
+ label="Enter Khmer text (បញ្ចូលអត្ថបទខ្មែរ)",
164
+ placeholder="បញ្ចូលអត្ថបទខ្មែររបស់អ្នកនៅទីនេះ...",
165
+ lines=4
166
+ )
167
+
168
+ # Voice selector (commented out)
169
+ # voice = gr.Dropdown(
170
+ # choices=VOICES,
171
+ # value="tara",
172
+ # label="Voice (សំលេង)"
173
+ # )
174
+
175
  with gr.Row():
176
+ submit_btn = gr.Button("🎤 Generate Speech", variant="primary", size="lg")
177
+ clear_btn = gr.Button("🗑️ Clear", size="lg")
178
+
179
+ audio_output = gr.Audio(
180
+ label="Generated Speech (សំលេងដែលបង្កើតឡើង)",
181
+ type="numpy",
182
+ show_label=True
183
+ )
184
+
185
+ # Set up examples (simplified)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  gr.Examples(
187
  examples=examples,
188
+ inputs=[text_input],
189
  outputs=audio_output,
190
+ fn=lambda text: generate_speech(text),
191
  cache_examples=True,
192
  )
193
 
194
  # Set up event handlers
195
  submit_btn.click(
196
+ fn=lambda text: generate_speech(text),
197
+ inputs=[text_input],
198
  outputs=audio_output
199
  )
200
 
 
203
  inputs=[],
204
  outputs=[text_input, audio_output]
205
  )
 
206
  # Launch the app
207
  if __name__ == "__main__":
208
  demo.queue().launch(share=False)