mrrtmob commited on
Commit
e173776
Β·
1 Parent(s): 731d214

Refactor authentication and model loading; simplify UI and reduce resource usage

Browse files
Files changed (1) hide show
  1. app.py +28 -104
app.py CHANGED
@@ -12,36 +12,19 @@ load_dotenv()
12
  # Get HF token from environment variables
13
  hf_token = os.getenv("HF_TOKEN")
14
 
15
- # Debug and authentication
16
- print("=== DEBUG INFO ===")
17
- print(f"HF_TOKEN exists: {bool(hf_token)}")
18
-
19
  if hf_token:
20
  login(token=hf_token)
21
- try:
22
- user_info = whoami(token=hf_token)
23
- print(f"Successfully logged in as: {user_info.get('name', 'Unknown')}")
24
- print(f"User type: {user_info.get('type', 'Unknown')}")
25
- print(f"User ID: {user_info.get('id', 'Unknown')}")
26
- except Exception as e:
27
- print(f"Authentication error: {e}")
28
- else:
29
- print("Warning: HF_TOKEN not found in environment variables")
30
-
31
- print("=== END DEBUG ===")
32
 
33
  # Check if CUDA is available
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
35
- print(f"Using device: {device}")
36
 
37
  print("Loading SNAC model...")
38
  snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
39
  snac_model = snac_model.to(device)
40
- print("SNAC model loaded successfully")
41
 
42
  model_name = "mrrtmob/tts-khm-kore"
43
 
44
- print(f"Downloading model files from {model_name}...")
45
  # Download only model config and safetensors with token
46
  snapshot_download(
47
  repo_id=model_name,
@@ -63,18 +46,15 @@ snapshot_download(
63
  "scheduler.pt"
64
  ]
65
  )
66
- print("Model files downloaded successfully")
67
 
68
- print("Loading main model...")
69
- # Load model and tokenizer with token (removed device_map)
70
  model = AutoModelForCausalLM.from_pretrained(
71
  model_name,
72
  torch_dtype=torch.bfloat16,
73
  token=hf_token
74
  )
75
- model = model.to(device) # Move to device manually
76
 
77
- print("Loading tokenizer...")
78
  tokenizer = AutoTokenizer.from_pretrained(
79
  model_name,
80
  token=hf_token
@@ -114,14 +94,14 @@ def parse_output(generated_ids):
114
  trimmed_row = row[:new_length]
115
  trimmed_row = [t - 128266 for t in trimmed_row]
116
  code_lists.append(trimmed_row)
117
- return code_lists[0] if code_lists else [] # Return just the first one for single sample
118
 
119
  # Redistribute codes for audio generation
120
  def redistribute_codes(code_list, snac_model):
121
  if not code_list:
122
  return None
123
 
124
- device = next(snac_model.parameters()).device # Get the device of SNAC model
125
  layer_1 = []
126
  layer_2 = []
127
  layer_3 = []
@@ -145,26 +125,22 @@ def redistribute_codes(code_list, snac_model):
145
  if not layer_1:
146
  return None
147
 
148
- # Move tensors to the same device as the SNAC model
149
  codes = [
150
  torch.tensor(layer_1, device=device).unsqueeze(0),
151
  torch.tensor(layer_2, device=device).unsqueeze(0),
152
  torch.tensor(layer_3, device=device).unsqueeze(0)
153
  ]
154
  audio_hat = snac_model.decode(codes)
155
- return audio_hat.detach().squeeze().cpu().numpy() # Always return CPU numpy array
156
 
157
- # Main generation function
158
- @spaces.GPU(duration=120)
159
- def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200, voice="Elise", progress=gr.Progress()):
160
  if not text.strip():
161
- gr.Warning("Please enter some text to generate speech.")
162
  return None
163
 
164
  try:
165
  progress(0.1, "Processing text...")
166
- print(f"Generating speech for text: {text[:50]}...")
167
-
168
  input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
169
 
170
  progress(0.3, "Generating speech tokens...")
@@ -186,26 +162,21 @@ def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, m
186
  code_list = parse_output(generated_ids)
187
 
188
  if not code_list:
189
- gr.Warning("Failed to generate valid audio codes.")
190
  return None
191
 
192
  progress(0.8, "Converting to audio...")
193
  audio_samples = redistribute_codes(code_list, snac_model)
194
 
195
  if audio_samples is None:
196
- gr.Warning("Failed to convert codes to audio.")
197
  return None
198
 
199
- print("Speech generation completed successfully")
200
- return (24000, audio_samples) # Return sample rate and audio
201
 
202
  except Exception as e:
203
- error_msg = f"Error generating speech: {str(e)}"
204
- print(error_msg)
205
- gr.Error(error_msg)
206
  return None
207
 
208
- # Examples for the UI - Khmer text examples
209
  examples = [
210
  ["αž‡αŸ†αžšαžΆαž”αžŸαž½αžš αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡ Kiri αž αžΎαž™αžαŸ’αž‰αž»αŸ†αž‡αžΆ AI αžŠαŸ‚αž›αž’αžΆαž…αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αŸ”"],
211
  ["αžαŸ’αž‰αž»οΏ½οΏ½αž’αžΆαž…αž”αž„αŸ’αž€αžΎαžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αž•αŸ’αžŸαŸαž„αŸ— αžŠαžΌαž…αž‡αžΆ <laugh> αžŸαžΎαž…αŸ”"],
@@ -219,49 +190,17 @@ examples = [
219
  ["αž’αžšαž‚αž»αžŽαž…αŸ’αžšαžΎαž“αžŸαž˜αŸ’αžšαžΆαž”αŸ‹αž‡αŸ†αž“αž½αž™αŸ” <chuckle> αž”αžΎαž‚αŸ’αž˜αžΆαž“αž’αŸ’αž“αž€αž‘αŸ αžαŸ’αž‰αž»αŸ†αž˜αž·αž“αžŠαžΉαž„αž’αŸ’αžœαžΎαž™αŸ‰αžΆαž„αž˜αŸ‰αŸαž…αž‘αŸαŸ”"],
220
  ]
221
 
222
- # Available voices (commented out for simpler UI)
223
- # VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe", "jing", "Elise"]
224
-
225
- # Available Emotive Tags
226
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
227
 
228
- # Create custom CSS
229
- css = """
230
- .gradio-container {
231
- max-width: 1200px;
232
- margin: auto;
233
- padding-top: 1.5rem;
234
- }
235
- .main-header {
236
- text-align: center;
237
- margin-bottom: 2rem;
238
- }
239
- .generate-btn {
240
- background: linear-gradient(45deg, #FF6B6B, #4ECDC4) !important;
241
- border: none !important;
242
- color: white !important;
243
- font-weight: bold !important;
244
- }
245
- .clear-btn {
246
- background: linear-gradient(45deg, #95A5A6, #BDC3C7) !important;
247
- border: none !important;
248
- color: white !important;
249
- }
250
- """
251
-
252
  # Create Gradio interface
253
- with gr.Blocks(title="Khmer Text-to-Speech", css=css, theme=gr.themes.Soft()) as demo:
254
  gr.Markdown(f"""
255
- <div class="main-header">
256
-
257
  # 🎡 Khmer Text-to-Speech
258
  **αž˜αŸ‰αžΌαžŠαŸ‚αž›αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‡αžΆαžŸαŸ†αž›αŸαž„**
259
 
260
  αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€ αž αžΎαž™αžŸαŸ’αžαžΆαž”αŸ‹αž€αžΆαžšαž”αž˜αŸ’αž›αŸ‚αž„αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”
261
 
262
  πŸ’‘ **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
263
-
264
- </div>
265
  """)
266
 
267
  with gr.Row():
@@ -270,7 +209,7 @@ with gr.Blocks(title="Khmer Text-to-Speech", css=css, theme=gr.themes.Soft()) as
270
  label="Enter Khmer text (αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžš)",
271
  placeholder="αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€αž“αŸ…αž‘αžΈαž“αŸαŸ‡...",
272
  lines=4,
273
- max_lines=8
274
  )
275
 
276
  # Advanced Settings
@@ -278,29 +217,25 @@ with gr.Blocks(title="Khmer Text-to-Speech", css=css, theme=gr.themes.Soft()) as
278
  with gr.Row():
279
  temperature = gr.Slider(
280
  minimum=0.1, maximum=1.5, value=0.6, step=0.05,
281
- label="Temperature",
282
- info="Higher values create more expressive speech"
283
  )
284
  top_p = gr.Slider(
285
  minimum=0.1, maximum=1.0, value=0.95, step=0.05,
286
- label="Top P",
287
- info="Nucleus sampling threshold"
288
  )
289
  with gr.Row():
290
  repetition_penalty = gr.Slider(
291
  minimum=1.0, maximum=2.0, value=1.1, step=0.05,
292
- label="Repetition Penalty",
293
- info="Higher values discourage repetitive patterns"
294
  )
295
  max_new_tokens = gr.Slider(
296
- minimum=100, maximum=2000, value=1200, step=100,
297
- label="Max Length",
298
- info="Maximum length of generated audio"
299
  )
300
 
301
  with gr.Row():
302
- submit_btn = gr.Button("🎀 Generate Speech", variant="primary", size="lg", elem_classes=["generate-btn"])
303
- clear_btn = gr.Button("πŸ—‘οΈ Clear", size="lg", elem_classes=["clear-btn"])
304
 
305
  with gr.Column(scale=1):
306
  audio_output = gr.Audio(
@@ -310,14 +245,14 @@ with gr.Blocks(title="Khmer Text-to-Speech", css=css, theme=gr.themes.Soft()) as
310
  interactive=False
311
  )
312
 
313
- # Set up examples (NO CACHE)
314
  gr.Examples(
315
  examples=examples,
316
  inputs=[text_input],
317
  outputs=audio_output,
318
  fn=lambda text: generate_speech(text),
319
- cache_examples=False,
320
- label="πŸ“ Example Texts (αž’αžαŸ’αžαž”αž‘αž‚αŸ†αžšαžΌ)"
321
  )
322
 
323
  # Set up event handlers
@@ -333,25 +268,14 @@ with gr.Blocks(title="Khmer Text-to-Speech", css=css, theme=gr.themes.Soft()) as
333
  inputs=[],
334
  outputs=[text_input, audio_output]
335
  )
336
-
337
- # Add keyboard shortcut
338
- text_input.submit(
339
- fn=generate_speech,
340
- inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens],
341
- outputs=audio_output,
342
- show_progress=True
343
- )
344
 
345
- # Launch the app
346
  if __name__ == "__main__":
347
- print("Starting Gradio interface...")
348
  demo.queue(
349
- max_size=20,
350
- default_concurrency_limit=5
351
  ).launch(
352
- server_name="0.0.0.0",
353
- server_port=7860,
354
  share=False,
355
  show_error=True,
356
- quiet=False
357
  )
 
12
  # Get HF token from environment variables
13
  hf_token = os.getenv("HF_TOKEN")
14
 
15
+ # Simplified authentication - no debug prints
 
 
 
16
  if hf_token:
17
  login(token=hf_token)
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # Check if CUDA is available
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
21
 
22
  print("Loading SNAC model...")
23
  snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
24
  snac_model = snac_model.to(device)
 
25
 
26
  model_name = "mrrtmob/tts-khm-kore"
27
 
 
28
  # Download only model config and safetensors with token
29
  snapshot_download(
30
  repo_id=model_name,
 
46
  "scheduler.pt"
47
  ]
48
  )
 
49
 
50
+ # Load model and tokenizer with token
 
51
  model = AutoModelForCausalLM.from_pretrained(
52
  model_name,
53
  torch_dtype=torch.bfloat16,
54
  token=hf_token
55
  )
56
+ model = model.to(device)
57
 
 
58
  tokenizer = AutoTokenizer.from_pretrained(
59
  model_name,
60
  token=hf_token
 
94
  trimmed_row = row[:new_length]
95
  trimmed_row = [t - 128266 for t in trimmed_row]
96
  code_lists.append(trimmed_row)
97
+ return code_lists[0] if code_lists else []
98
 
99
  # Redistribute codes for audio generation
100
  def redistribute_codes(code_list, snac_model):
101
  if not code_list:
102
  return None
103
 
104
+ device = next(snac_model.parameters()).device
105
  layer_1 = []
106
  layer_2 = []
107
  layer_3 = []
 
125
  if not layer_1:
126
  return None
127
 
 
128
  codes = [
129
  torch.tensor(layer_1, device=device).unsqueeze(0),
130
  torch.tensor(layer_2, device=device).unsqueeze(0),
131
  torch.tensor(layer_3, device=device).unsqueeze(0)
132
  ]
133
  audio_hat = snac_model.decode(codes)
134
+ return audio_hat.detach().squeeze().cpu().numpy()
135
 
136
+ # Main generation function - KEY CHANGES HERE
137
+ @spaces.GPU(duration=60) # Reduced duration
138
+ def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=800, voice="Elise", progress=gr.Progress()): # Reduced max tokens
139
  if not text.strip():
 
140
  return None
141
 
142
  try:
143
  progress(0.1, "Processing text...")
 
 
144
  input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
145
 
146
  progress(0.3, "Generating speech tokens...")
 
162
  code_list = parse_output(generated_ids)
163
 
164
  if not code_list:
 
165
  return None
166
 
167
  progress(0.8, "Converting to audio...")
168
  audio_samples = redistribute_codes(code_list, snac_model)
169
 
170
  if audio_samples is None:
 
171
  return None
172
 
173
+ return (24000, audio_samples)
 
174
 
175
  except Exception as e:
176
+ print(f"Error generating speech: {e}")
 
 
177
  return None
178
 
179
+ # Examples - reduced to save quota
180
  examples = [
181
  ["αž‡αŸ†αžšαžΆαž”αžŸαž½αžš αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡ Kiri αž αžΎαž™αžαŸ’αž‰αž»αŸ†αž‡αžΆ AI αžŠαŸ‚αž›αž’αžΆαž…αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αŸ”"],
182
  ["αžαŸ’αž‰αž»οΏ½οΏ½αž’αžΆαž…αž”αž„αŸ’αž€αžΎαžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αž•αŸ’αžŸαŸαž„αŸ— αžŠαžΌαž…αž‡αžΆ <laugh> αžŸαžΎαž…αŸ”"],
 
190
  ["αž’αžšαž‚αž»αžŽαž…αŸ’αžšαžΎαž“αžŸαž˜αŸ’αžšαžΆαž”αŸ‹αž‡αŸ†αž“αž½αž™αŸ” <chuckle> αž”αžΎαž‚αŸ’αž˜αžΆαž“αž’αŸ’αž“αž€αž‘αŸ αžαŸ’αž‰αž»αŸ†αž˜αž·αž“αžŠαžΉαž„αž’αŸ’αžœαžΎαž™αŸ‰αžΆαž„αž˜αŸ‰αŸαž…αž‘αŸαŸ”"],
191
  ]
192
 
 
 
 
 
193
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  # Create Gradio interface
196
+ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
197
  gr.Markdown(f"""
 
 
198
  # 🎡 Khmer Text-to-Speech
199
  **αž˜αŸ‰αžΌαžŠαŸ‚αž›αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‡αžΆαžŸαŸ†αž›αŸαž„**
200
 
201
  αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€ αž αžΎαž™αžŸαŸ’αžαžΆαž”αŸ‹αž€αžΆαžšαž”αž˜αŸ’αž›αŸ‚αž„αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”
202
 
203
  πŸ’‘ **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
 
 
204
  """)
205
 
206
  with gr.Row():
 
209
  label="Enter Khmer text (αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžš)",
210
  placeholder="αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€αž“αŸ…αž‘αžΈαž“αŸαŸ‡...",
211
  lines=4,
212
+ max_lines=6 # Limited input size
213
  )
214
 
215
  # Advanced Settings
 
217
  with gr.Row():
218
  temperature = gr.Slider(
219
  minimum=0.1, maximum=1.5, value=0.6, step=0.05,
220
+ label="Temperature"
 
221
  )
222
  top_p = gr.Slider(
223
  minimum=0.1, maximum=1.0, value=0.95, step=0.05,
224
+ label="Top P"
 
225
  )
226
  with gr.Row():
227
  repetition_penalty = gr.Slider(
228
  minimum=1.0, maximum=2.0, value=1.1, step=0.05,
229
+ label="Repetition Penalty"
 
230
  )
231
  max_new_tokens = gr.Slider(
232
+ minimum=100, maximum=800, value=800, step=50, # Reduced max
233
+ label="Max Length"
 
234
  )
235
 
236
  with gr.Row():
237
+ submit_btn = gr.Button("🎀 Generate Speech", variant="primary", size="lg")
238
+ clear_btn = gr.Button("πŸ—‘οΈ Clear", size="lg")
239
 
240
  with gr.Column(scale=1):
241
  audio_output = gr.Audio(
 
245
  interactive=False
246
  )
247
 
248
+ # Examples with NO CACHE to save quota
249
  gr.Examples(
250
  examples=examples,
251
  inputs=[text_input],
252
  outputs=audio_output,
253
  fn=lambda text: generate_speech(text),
254
+ cache_examples=False, # Important: no caching
255
+ label="πŸ“ Example Texts"
256
  )
257
 
258
  # Set up event handlers
 
268
  inputs=[],
269
  outputs=[text_input, audio_output]
270
  )
 
 
 
 
 
 
 
 
271
 
272
+ # Launch with optimizations
273
  if __name__ == "__main__":
 
274
  demo.queue(
275
+ max_size=10, # Reduced queue size
276
+ default_concurrency_limit=2 # Reduced concurrent users
277
  ).launch(
 
 
278
  share=False,
279
  show_error=True,
280
+ ssr_mode=False # Important for quota
281
  )