mrrtmob committed
Commit 5193c5e · 1 Parent(s): 25f78d4
Files changed (1):
  1. app.py +47 -38
app.py CHANGED
@@ -1,20 +1,35 @@
+import os
 import spaces
 from snac import SNAC
 import torch
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, login
 from dotenv import load_dotenv
+
 load_dotenv()
+
+# Get HF token from environment variables
+hf_token = os.getenv("HF_TOKEN")
+if hf_token:
+    login(token=hf_token)
+    print("Successfully logged in to Hugging Face")
+else:
+    print("Warning: HF_TOKEN not found in environment variables")
+
 # Check if CUDA is available
 device = "cuda" if torch.cuda.is_available() else "cpu"
+
 print("Loading SNAC model...")
 snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
 snac_model = snac_model.to(device)
+
 model_name = "mrrtmob/tts-khm-kore"
-# Download only model config and safetensors
+
+# Download only model config and safetensors with token
 snapshot_download(
     repo_id=model_name,
+    token=hf_token,  # Add token here
     allow_patterns=[
         "config.json",
         "*.safetensors",
@@ -33,41 +48,47 @@ snapshot_download(
         "tokenizer.*"
     ]
 )
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+
+# Load model and tokenizer with token
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    token=hf_token  # Add token here
+)
 model.to(device)
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+tokenizer = AutoTokenizer.from_pretrained(
+    model_name,
+    token=hf_token  # Add token here
+)
+
 print(f"Khmer TTS model loaded to {device}")
+
 # Process text prompt
 def process_prompt(prompt, voice, tokenizer, device):
     prompt = f"{voice}: {prompt}"
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-
     start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
-
     modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)  # SOH SOT Text EOT EOH
-
     # No padding needed for single input
     attention_mask = torch.ones_like(modified_input_ids)
-
     return modified_input_ids.to(device), attention_mask.to(device)
+
 # Parse output tokens to audio
 def parse_output(generated_ids):
     token_to_find = 128257
     token_to_remove = 128258
-
     token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
     if len(token_indices[1]) > 0:
         last_occurrence_idx = token_indices[1][-1].item()
         cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
     else:
         cropped_tensor = generated_ids
-
     processed_rows = []
     for row in cropped_tensor:
         masked_row = row[row != token_to_remove]
         processed_rows.append(masked_row)
-
     code_lists = []
     for row in processed_rows:
         row_length = row.size(0)
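The special-token framing in `process_prompt` (unchanged context in the hunk above) can be checked in isolation. A small sketch with a dummy `input_ids` tensor standing in for real tokenizer output; the IDs 128259, 128009, and 128260 are the ones used in the file:

```python
# Sketch of the SOH/EOT/EOH framing from process_prompt, with dummy text tokens.
import torch

input_ids = torch.tensor([[5, 6, 7]], dtype=torch.int64)           # hypothetical text tokens
start_token = torch.tensor([[128259]], dtype=torch.int64)          # Start of human
end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)   # End of text, End of human

framed = torch.cat([start_token, input_ids, end_tokens], dim=1)
attention_mask = torch.ones_like(framed)                           # single input, no padding
print(framed)  # tensor([[128259, 5, 6, 7, 128009, 128260]])
```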
@@ -75,12 +96,11 @@ def parse_output(generated_ids):
         trimmed_row = row[:new_length]
         trimmed_row = [t - 128266 for t in trimmed_row]
         code_lists.append(trimmed_row)
-
     return code_lists[0]  # Return just the first one for single sample
+
 # Redistribute codes for audio generation
 def redistribute_codes(code_list, snac_model):
     device = next(snac_model.parameters()).device  # Get the device of SNAC model
-
     layer_1 = []
     layer_2 = []
     layer_3 = []
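For reference, the cropping that `parse_output` performs can be traced on a synthetic sequence. The line computing `new_length` falls in unshown context between the hunks, so the trim-to-whole-frames step below is an assumption inferred from the 7-way indexing in `redistribute_codes`:

```python
# Synthetic walk-through of parse_output's cropping (new_length is assumed).
import torch

generated_ids = torch.tensor([[1, 2, 128257,                  # marker 128257
                               128266, 128267, 128268, 128269,
                               128270, 128271, 128272,        # 7 audio code tokens
                               128258]])                      # EOS token 128258
idx = (generated_ids == 128257).nonzero(as_tuple=True)
row = generated_ids[:, idx[1][-1].item() + 1:][0]  # keep tokens after the last marker
row = row[row != 128258]                           # drop the EOS token
new_length = (row.size(0) // 7) * 7                # assumed: whole 7-token frames only
codes = [int(t) - 128266 for t in row[:new_length]]
print(codes)  # [0, 1, 2, 3, 4, 5, 6]
```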
@@ -92,26 +112,23 @@ def redistribute_codes(code_list, snac_model):
         layer_2.append(code_list[7*i+4]-(4*4096))
         layer_3.append(code_list[7*i+5]-(5*4096))
         layer_3.append(code_list[7*i+6]-(6*4096))
-
     # Move tensors to the same device as the SNAC model
     codes = [
         torch.tensor(layer_1, device=device).unsqueeze(0),
         torch.tensor(layer_2, device=device).unsqueeze(0),
         torch.tensor(layer_3, device=device).unsqueeze(0)
    ]
-
    audio_hat = snac_model.decode(codes)
    return audio_hat.detach().squeeze().cpu().numpy()  # Always return CPU numpy array
+
 # Main generation function
 @spaces.GPU()
 def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200, voice="Elise", progress=gr.Progress()):
     if not text.strip():
         return None
-
     try:
         progress(0.1, "Processing text...")
         input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
-
         progress(0.3, "Generating speech tokens...")
         with torch.no_grad():
             generated_ids = model.generate(
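The offsets in `redistribute_codes` undo a flattened layout of 7 codes per frame, where the code in slot j of a frame is stored as code + j*4096. Only slots 4, 5, and 6 are visible in this hunk; the sketch below assumes the unshown slots follow the same pattern (slot 0 to layer 1; slots 1 and 4 to layer 2; slots 2, 3, 5, 6 to layer 3), which is consistent with the visible lines:

```python
# Hedged sketch of the assumed 7-codes-per-frame layout behind redistribute_codes.
frame = [10, 20, 30, 40, 50, 60, 70]                  # hypothetical raw SNAC codes
flat = [c + j * 4096 for j, c in enumerate(frame)]    # how one frame sits in code_list

layer_1 = [flat[0]]                                   # coarse layer: 1 code per frame
layer_2 = [flat[1] - 1 * 4096, flat[4] - 4 * 4096]    # mid layer: 2 codes per frame
layer_3 = [flat[j] - j * 4096 for j in (2, 3, 5, 6)]  # fine layer: 4 codes per frame
print(layer_1, layer_2, layer_3)  # [10] [20, 50] [30, 40, 60, 70]
```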
@@ -125,17 +142,15 @@ def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, m
                 num_return_sequences=1,
                 eos_token_id=128258,
             )
-
         progress(0.6, "Processing speech tokens...")
         code_list = parse_output(generated_ids)
-
         progress(0.8, "Converting to audio...")
         audio_samples = redistribute_codes(code_list, snac_model)
-
         return (24000, audio_samples)  # Return sample rate and audio
     except Exception as e:
         print(f"Error generating speech: {e}")
         return None
+
 # Examples for the UI - Khmer text examples
 examples = [
     ["ជំរាបសួរ ខ្ញុំឈ្មោះ Kiri ហើយខ្ញុំជា AI ដែលអាចបម្លែងអត្ថបទទៅជាសំលេង។"],  # Hello, my name is Kiri, and I am an AI that can convert text into speech.
@@ -149,69 +164,64 @@ examples = [
     ["ខ្ញុំដើរទៅទិញអីញ៉ាំ ស្រាប់តែឃើញឆ្កែធំមួយរត់មករកខ្ញុំ។ <gasp> ខ្ញុំភ័យណាស់! តែវារត់ទៅបាត់វិញ។ <sigh>"],  # I was walking to buy something when suddenly I saw a big dog running towards me. <gasp> I was so scared! But then it ran away. <sigh>
     ["អរគុណច្រើនសម្រាប់ជំនួយ។ <chuckle> បើគ្មានអ្នកទេ ខ្ញុំមិនដឹងធ្វើយ៉ាងម៉េចទេ។"],  # Thank you so much for the help. <chuckle> Without you, I wouldn't know what to do.
 ]
+
 # Available voices (commented out for simpler UI)
 # VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe", "jing", "Elise"]
+
 # Available Emotive Tags
 EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
+
 # Create Gradio interface
 with gr.Blocks(title="Khmer Text-to-Speech") as demo:
     gr.Markdown(f"""
     # 🎵 Khmer Text-to-Speech
     **ម៉ូដែលបម្លែងអត្ថបទជាសំលេង**
-
     បញ្ចូលអត្ថបទខ្មែររបស់អ្នក ហើយស្តាប់ការបម្លែងទៅជាសំលេងនិយាយ។
-
     💡 **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
     """)
-
     text_input = gr.Textbox(
-        label="Enter Khmer text (បញ្ចូលអត្ថបទខ្មែរ)",
+        label="Enter Khmer text (បញ្ចូលអត្ថបទខ្មែរ)",
         placeholder="បញ្ចូលអត្ថបទខ្មែររបស់អ្នកនៅទីនេះ...",
         lines=4
     )
-
     # Voice selector (commented out)
     # voice = gr.Dropdown(
-    #     choices=VOICES,
-    #     value="tara",
+    #     choices=VOICES,
+    #     value="tara",
     #     label="Voice (សំលេង)"
     # )
-
     # Advanced Settings
     with gr.Accordion("🔧 Advanced Settings", open=False):
         with gr.Row():
             temperature = gr.Slider(
                 minimum=0.1, maximum=1.5, value=0.6, step=0.05,
-                label="Temperature",
+                label="Temperature",
                 info="Higher values create more expressive speech"
             )
             top_p = gr.Slider(
                 minimum=0.1, maximum=1.0, value=0.95, step=0.05,
-                label="Top P",
+                label="Top P",
                 info="Nucleus sampling threshold"
             )
         with gr.Row():
             repetition_penalty = gr.Slider(
                 minimum=1.0, maximum=2.0, value=1.1, step=0.05,
-                label="Repetition Penalty",
+                label="Repetition Penalty",
                 info="Higher values discourage repetitive patterns"
             )
             max_new_tokens = gr.Slider(
                 minimum=100, maximum=2000, value=1200, step=100,
-                label="Max Length",
+                label="Max Length",
                 info="Maximum length of generated audio"
             )
-
     with gr.Row():
         submit_btn = gr.Button("🎤 Generate Speech", variant="primary", size="lg")
         clear_btn = gr.Button("🗑️ Clear", size="lg")
-
     audio_output = gr.Audio(
-        label="Generated Speech (សំលេងដែលបង្កើតឡើង)",
+        label="Generated Speech (សំលេងដែលបង្កើតឡើង)",
         type="numpy",
         show_label=True
     )
-
     # Set up examples (NO CACHE)
     gr.Examples(
         examples=examples,
@@ -220,19 +230,18 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
         fn=lambda text: generate_speech(text),
         cache_examples=False,
     )
-
     # Set up event handlers
     submit_btn.click(
         fn=generate_speech,
         inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens],
         outputs=audio_output
     )
-
     clear_btn.click(
         fn=lambda: (None, None),
         inputs=[],
         outputs=[text_input, audio_output]
     )
+
 # Launch the app
 if __name__ == "__main__":
     demo.queue().launch(share=False)
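Two details of the wiring above are easy to miss: `submit_btn.click` passes only five inputs, which bind positionally to the first five parameters of `generate_speech`, so `voice` is never exposed in the UI and always keeps its default `"Elise"`; and on success the function returns `(24000, numpy_array)`, with `None` on failure. A hedged sketch of driving it directly, assuming the module's globals (`model`, `tokenizer`, `snac_model`) are loaded, that the `@spaces.GPU()` decorator and the `gr.Progress()` default are inert outside a running Space/UI, and that `soundfile` (not a dependency shown in this diff) is available:

```python
# Hypothetical direct call to generate_speech, bypassing the Gradio UI.
import soundfile as sf  # extra dependency, not part of this commit

result = generate_speech("ជំរាបសួរ", 0.6, 0.95, 1.1, 1200)  # "Hello"; voice stays "Elise"
if result is not None:
    sample_rate, audio = result                 # (24000, float numpy array)
    sf.write("output.wav", audio, sample_rate)  # save the 24 kHz waveform
```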
 