mrrtmob committed on
Commit
3d2ce0a
·
1 Parent(s): 63b6422

Implement rate limiting for speech generation and enhance text validation; improve UI with character count and custom CSS

Browse files
Files changed (1) hide show
  1. app.py +149 -27
app.py CHANGED
@@ -1,18 +1,40 @@
1
  import os
 
 
2
  import spaces
3
  from snac import SNAC
4
  import torch
5
  import gradio as gr
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
- from huggingface_hub import snapshot_download, login, whoami
8
  from dotenv import load_dotenv
9
 
10
  load_dotenv()
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Get HF token from environment variables
13
  hf_token = os.getenv("HF_TOKEN")
14
 
15
- # Simplified authentication - no debug prints
16
  if hf_token:
17
  login(token=hf_token)
18
 
@@ -22,9 +44,11 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
22
  print("Loading SNAC model...")
23
  snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
24
  snac_model = snac_model.to(device)
 
25
 
26
  model_name = "mrrtmob/tts-khm-kore"
27
 
 
28
  # Download only model config and safetensors with token
29
  snapshot_download(
30
  repo_id=model_name,
@@ -46,7 +70,9 @@ snapshot_download(
46
  "scheduler.pt"
47
  ]
48
  )
 
49
 
 
50
  # Load model and tokenizer with token
51
  model = AutoModelForCausalLM.from_pretrained(
52
  model_name,
@@ -55,6 +81,7 @@ model = AutoModelForCausalLM.from_pretrained(
55
  )
56
  model = model.to(device)
57
 
 
58
  tokenizer = AutoTokenizer.from_pretrained(
59
  model_name,
60
  token=hf_token
@@ -133,14 +160,43 @@ def redistribute_codes(code_list, snac_model):
133
  audio_hat = snac_model.decode(codes)
134
  return audio_hat.detach().squeeze().cpu().numpy()
135
 
136
- # Main generation function - KEY CHANGES HERE
137
- @spaces.GPU(duration=60) # Reduced duration
138
- def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=800, voice="Elise", progress=gr.Progress()): # Reduced max tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  if not text.strip():
 
140
  return None
141
 
 
 
 
142
  try:
143
  progress(0.1, "Processing text...")
 
 
144
  input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
145
 
146
  progress(0.3, "Generating speech tokens...")
@@ -162,21 +218,26 @@ def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, m
162
  code_list = parse_output(generated_ids)
163
 
164
  if not code_list:
 
165
  return None
166
 
167
  progress(0.8, "Converting to audio...")
168
  audio_samples = redistribute_codes(code_list, snac_model)
169
 
170
  if audio_samples is None:
 
171
  return None
172
 
 
173
  return (24000, audio_samples)
174
 
175
  except Exception as e:
176
- print(f"Error generating speech: {e}")
 
 
177
  return None
178
 
179
- # Examples - reduced to save quota
180
  examples = [
181
  ["ជំរាបសួរ ខ្ញុំឈ្មោះ Kiri ហើយខ្ញុំជា AI ដែលអាចបម្លែងអត្ថបទទៅជាសំលេង។"],
182
  ["ខ្ញុំអាចបង្កើតសំលេងនិយាយផ្សេងៗ ដូចជា <laugh> សើច។"],
@@ -192,50 +253,92 @@ examples = [
192
 
193
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  # Create Gradio interface
196
- with gr.Blocks(title="Khmer Text-to-Speech") as demo:
197
  gr.Markdown(f"""
 
 
198
  # 🎵 Khmer Text-to-Speech
199
  **ម៉ូដែលបម្លែងអត្ថបទជាសំលេង**
200
 
201
  បញ្ចូលអត្ថបទខ្មែររបស់អ្នក ហើយស្តាប់ការបម្លែងទៅជាសំលេងនិយាយ។
202
 
203
  💡 **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
 
 
204
  """)
205
 
206
  with gr.Row():
207
  with gr.Column(scale=2):
208
  text_input = gr.Textbox(
209
- label="Enter Khmer text (បញ្ចូលអត្ថបទខ្មែរ)",
210
- placeholder="បញ្ចូលអត្ថបទខ្មែររបស់អ្នកនៅទីនេះ...",
211
  lines=4,
212
- max_lines=6 # Limited input size
 
213
  )
214
 
 
 
 
215
  # Advanced Settings
216
  with gr.Accordion("🔧 Advanced Settings", open=False):
217
  with gr.Row():
218
  temperature = gr.Slider(
219
  minimum=0.1, maximum=1.5, value=0.6, step=0.05,
220
- label="Temperature"
 
221
  )
222
  top_p = gr.Slider(
223
  minimum=0.1, maximum=1.0, value=0.95, step=0.05,
224
- label="Top P"
 
225
  )
226
  with gr.Row():
227
  repetition_penalty = gr.Slider(
228
  minimum=1.0, maximum=2.0, value=1.1, step=0.05,
229
- label="Repetition Penalty"
 
230
  )
231
  max_new_tokens = gr.Slider(
232
- minimum=100, maximum=800, value=800, step=50, # Reduced max
233
- label="Max Length"
 
234
  )
235
 
236
  with gr.Row():
237
- submit_btn = gr.Button("🎤 Generate Speech", variant="primary", size="lg")
238
- clear_btn = gr.Button("🗑️ Clear", size="lg")
239
 
240
  with gr.Column(scale=1):
241
  audio_output = gr.Audio(
@@ -245,14 +348,21 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
245
  interactive=False
246
  )
247
 
248
- # Examples with NO CACHE to save quota
249
  gr.Examples(
250
  examples=examples,
251
  inputs=[text_input],
252
  outputs=audio_output,
253
  fn=lambda text: generate_speech(text),
254
- cache_examples=False, # Important: no caching
255
- label="📝 Example Texts"
 
 
 
 
 
 
 
256
  )
257
 
258
  # Set up event handlers
@@ -264,18 +374,30 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
264
  )
265
 
266
  clear_btn.click(
267
- fn=lambda: (None, None),
268
  inputs=[],
269
- outputs=[text_input, audio_output]
 
 
 
 
 
 
 
 
270
  )
271
 
272
- # Launch with optimizations
273
  if __name__ == "__main__":
 
274
  demo.queue(
275
- max_size=10, # Reduced queue size
276
- default_concurrency_limit=2 # Reduced concurrent users
277
  ).launch(
 
 
278
  share=False,
279
  show_error=True,
280
- ssr_mode=False # Important for quota
 
281
  )
 
1
  import os
2
+ import time
3
+ from functools import wraps
4
  import spaces
5
  from snac import SNAC
6
  import torch
7
  import gradio as gr
8
  from transformers import AutoModelForCausalLM, AutoTokenizer
9
+ from huggingface_hub import snapshot_download, login
10
  from dotenv import load_dotenv
11
 
12
  load_dotenv()
13
 
14
# Rate limiting
# Maps a user key to the time.monotonic() reading taken when that user's
# last request was accepted.
last_request_time = {}
# Minimum number of seconds between two accepted requests.
REQUEST_COOLDOWN = 30


def rate_limit(func):
    """Decorate *func* so it runs at most once per REQUEST_COOLDOWN seconds.

    Every call is currently attributed to the single key ``"anonymous"``,
    so the cooldown is shared by all callers of the wrapped function.

    When a call arrives before the cooldown has elapsed, a ``gr.Warning``
    tells the user how long to wait, *func* is not invoked, and the wrapper
    returns ``None``. Otherwise the call time is recorded and the result of
    ``func(*args, **kwargs)`` is returned unchanged.
    """
    import math

    @wraps(func)
    def wrapper(*args, **kwargs):
        user_id = "anonymous"
        # A monotonic clock measures elapsed time reliably even if the
        # system wall clock is adjusted between two requests.
        current_time = time.monotonic()

        previous_time = last_request_time.get(user_id)
        if previous_time is not None:
            time_since_last = current_time - previous_time
            if time_since_last < REQUEST_COOLDOWN:
                # Round up so the message never says "wait 0 seconds"
                # while the request is still being rejected.
                remaining = math.ceil(REQUEST_COOLDOWN - time_since_last)
                gr.Warning(f"Please wait {remaining} seconds before next request.")
                return None

        # Record the timestamp before running func, so the cooldown starts
        # when the request is accepted rather than when it finishes.
        last_request_time[user_id] = current_time
        return func(*args, **kwargs)

    return wrapper
34
+
35
  # Get HF token from environment variables
36
  hf_token = os.getenv("HF_TOKEN")
37
 
 
38
  if hf_token:
39
  login(token=hf_token)
40
 
 
44
  print("Loading SNAC model...")
45
  snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
46
  snac_model = snac_model.to(device)
47
+ print("SNAC model loaded successfully")
48
 
49
  model_name = "mrrtmob/tts-khm-kore"
50
 
51
+ print(f"Downloading model files from {model_name}...")
52
  # Download only model config and safetensors with token
53
  snapshot_download(
54
  repo_id=model_name,
 
70
  "scheduler.pt"
71
  ]
72
  )
73
+ print("Model files downloaded successfully")
74
 
75
+ print("Loading main model...")
76
  # Load model and tokenizer with token
77
  model = AutoModelForCausalLM.from_pretrained(
78
  model_name,
 
81
  )
82
  model = model.to(device)
83
 
84
+ print("Loading tokenizer...")
85
  tokenizer = AutoTokenizer.from_pretrained(
86
  model_name,
87
  token=hf_token
 
160
  audio_hat = snac_model.decode(codes)
161
  return audio_hat.detach().squeeze().cpu().numpy()
162
 
163
# Maximum number of characters accepted for one speech request. Shared by
# validate_text and on_text_change so the two can never disagree.
MAX_TEXT_LENGTH = 200


# Text validation function
def validate_text(text, max_length=MAX_TEXT_LENGTH):
    """Return *text* cut down to at most *max_length* characters.

    Args:
        text: The user-supplied string. ``None`` is treated as empty text.
        max_length: Maximum number of characters to keep (default 200).

    Returns:
        ``text`` unchanged when it is short enough, otherwise its first
        ``max_length`` characters. An empty string for ``None`` or ``""``.
    """
    if not text:
        return ""
    return text[:max_length]


# Text change handler
def on_text_change(text):
    """Enforce the length limit on the textbox and build the counter label.

    Args:
        text: Current textbox content. ``None`` is treated as empty text.

    Returns:
        A ``(text, info)`` tuple: the (possibly truncated) text and a
        ``"Characters: <n>/<max>"`` string. A ``gr.Warning`` is shown
        whenever truncation actually happened.
    """
    text = text or ""
    limited = validate_text(text)

    if len(limited) < len(text):
        gr.Warning(f"Text truncated to {MAX_TEXT_LENGTH} characters")

    # Return the (potentially truncated) text and update info
    return limited, f"Characters: {len(limited)}/{MAX_TEXT_LENGTH}"
184
+
185
+ # Main generation function with rate limiting
186
+ @rate_limit
187
+ @spaces.GPU(duration=45)
188
+ def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=800, voice="Elise", progress=gr.Progress()):
189
  if not text.strip():
190
+ gr.Warning("Please enter some text to generate speech.")
191
  return None
192
 
193
+ # Validate text length
194
+ text = validate_text(text)
195
+
196
  try:
197
  progress(0.1, "Processing text...")
198
+ print(f"Generating speech for text: {text[:50]}...")
199
+
200
  input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
201
 
202
  progress(0.3, "Generating speech tokens...")
 
218
  code_list = parse_output(generated_ids)
219
 
220
  if not code_list:
221
+ gr.Warning("Failed to generate valid audio codes.")
222
  return None
223
 
224
  progress(0.8, "Converting to audio...")
225
  audio_samples = redistribute_codes(code_list, snac_model)
226
 
227
  if audio_samples is None:
228
+ gr.Warning("Failed to convert codes to audio.")
229
  return None
230
 
231
+ print("Speech generation completed successfully")
232
  return (24000, audio_samples)
233
 
234
  except Exception as e:
235
+ error_msg = f"Error generating speech: {str(e)}"
236
+ print(error_msg)
237
+ gr.Error(error_msg)
238
  return None
239
 
240
+ # Examples - reduced for quota management
241
  examples = [
242
  ["ជំរាបសួរ ខ្ញុំឈ្មោះ Kiri ហើយខ្ញុំជា AI ដែលអាចបម្លែងអត្ថបទទៅជាសំលេង។"],
243
  ["ខ្ញុំអាចបង្កើតសំលេងនិយាយផ្សេងៗ ដូចជា <laugh> សើច។"],
 
253
 
254
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
255
 
256
+ # Create custom CSS
257
+ css = """
258
+ .gradio-container {
259
+ max-width: 1200px;
260
+ margin: auto;
261
+ padding-top: 1.5rem;
262
+ }
263
+ .main-header {
264
+ text-align: center;
265
+ margin-bottom: 2rem;
266
+ }
267
+ .generate-btn {
268
+ background: linear-gradient(45deg, #FF6B6B, #4ECDC4) !important;
269
+ border: none !important;
270
+ color: white !important;
271
+ font-weight: bold !important;
272
+ }
273
+ .clear-btn {
274
+ background: linear-gradient(45deg, #95A5A6, #BDC3C7) !important;
275
+ border: none !important;
276
+ color: white !important;
277
+ }
278
+ .char-counter {
279
+ font-size: 12px;
280
+ color: #666;
281
+ text-align: right;
282
+ margin-top: 5px;
283
+ }
284
+ """
285
+
286
  # Create Gradio interface
287
+ with gr.Blocks(title="Khmer Text-to-Speech", css=css, theme=gr.themes.Soft()) as demo:
288
  gr.Markdown(f"""
289
+ <div class="main-header">
290
+
291
  # 🎵 Khmer Text-to-Speech
292
  **ម៉ូដែលបម្លែងអត្ថបទជាសំលេង**
293
 
294
  បញ្ចូលអត្ថបទខ្មែររបស់អ្នក ហើយស្តាប់ការបម្លែងទៅជាសំលេងនិយាយ។
295
 
296
  💡 **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
297
+
298
+ </div>
299
  """)
300
 
301
  with gr.Row():
302
  with gr.Column(scale=2):
303
  text_input = gr.Textbox(
304
+ label="Enter Khmer text (បញ្ចូលអត្ថបទខ្មែរ) - Max 200 characters",
305
+ placeholder="បញ្ចូលអត្ថបទខ្មែររបស់អ្នកនៅទីនេះ... (អតិបរមា ២០០ តួអក្សរ)",
306
  lines=4,
307
+ max_lines=6,
308
+ interactive=True
309
  )
310
 
311
+ # Character counter
312
+ char_info = gr.Markdown("Characters: 0/200", elem_classes=["char-counter"])
313
+
314
  # Advanced Settings
315
  with gr.Accordion("🔧 Advanced Settings", open=False):
316
  with gr.Row():
317
  temperature = gr.Slider(
318
  minimum=0.1, maximum=1.5, value=0.6, step=0.05,
319
+ label="Temperature",
320
+ info="Higher values create more expressive speech"
321
  )
322
  top_p = gr.Slider(
323
  minimum=0.1, maximum=1.0, value=0.95, step=0.05,
324
+ label="Top P",
325
+ info="Nucleus sampling threshold"
326
  )
327
  with gr.Row():
328
  repetition_penalty = gr.Slider(
329
  minimum=1.0, maximum=2.0, value=1.1, step=0.05,
330
+ label="Repetition Penalty",
331
+ info="Higher values discourage repetitive patterns"
332
  )
333
  max_new_tokens = gr.Slider(
334
+ minimum=100, maximum=600, value=600, step=50,
335
+ label="Max Length",
336
+ info="Maximum length of generated audio"
337
  )
338
 
339
  with gr.Row():
340
+ submit_btn = gr.Button("🎤 Generate Speech", variant="primary", size="lg", elem_classes=["generate-btn"])
341
+ clear_btn = gr.Button("🗑️ Clear", size="lg", elem_classes=["clear-btn"])
342
 
343
  with gr.Column(scale=1):
344
  audio_output = gr.Audio(
 
348
  interactive=False
349
  )
350
 
351
+ # Set up examples (NO CACHE to save quota)
352
  gr.Examples(
353
  examples=examples,
354
  inputs=[text_input],
355
  outputs=audio_output,
356
  fn=lambda text: generate_speech(text),
357
+ cache_examples=False,
358
+ label="📝 Example Texts (អត្ថបទគំរូ)"
359
+ )
360
+
361
+ # Text change event handler
362
+ text_input.change(
363
+ fn=on_text_change,
364
+ inputs=[text_input],
365
+ outputs=[text_input, char_info]
366
  )
367
 
368
  # Set up event handlers
 
374
  )
375
 
376
  clear_btn.click(
377
+ fn=lambda: ("", "Characters: 0/200", None),
378
  inputs=[],
379
+ outputs=[text_input, char_info, audio_output]
380
+ )
381
+
382
+ # Add keyboard shortcut
383
+ text_input.submit(
384
+ fn=generate_speech,
385
+ inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens],
386
+ outputs=audio_output,
387
+ show_progress=True
388
  )
389
 
390
# Launch with embed-friendly optimizations
if __name__ == "__main__":
    print("Starting Gradio interface...")

    # Queue settings: small queue for embeds, one user at a time.
    queue_options = {
        "max_size": 3,
        "default_concurrency_limit": 1,
    }

    # Server settings passed straight through to launch().
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "ssr_mode": False,
        "auth_message": "Login to HuggingFace recommended for better GPU quota",
    }

    demo.queue(**queue_options).launch(**launch_options)