mrrtmob committed
Commit 2938eff Β· 1 Parent(s): 1deb3e2

Add Hugging Face authentication and improve speech generation settings

Files changed (1)
  1. app.py +62 -70
app.py CHANGED
@@ -3,7 +3,7 @@ from snac import SNAC
  import torch
  import gradio as gr
  from transformers import AutoModelForCausalLM, AutoTokenizer
- from huggingface_hub import snapshot_download
+ from huggingface_hub import snapshot_download, login
  from dotenv import load_dotenv
  import os
  import re
@@ -11,9 +11,28 @@ import numpy as np
 
  load_dotenv()
 
+ # Setup Hugging Face authentication
+ def setup_auth():
+     hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
+     if hf_token:
+         try:
+             login(token=hf_token, add_to_git_credential=False)
+             print("βœ… Successfully logged in to Hugging Face")
+             return True
+         except Exception as e:
+             print(f"⚠️ Failed to login to Hugging Face: {e}")
+             return False
+     else:
+         print("⚠️ No HF token found. Running as anonymous user.")
+         return False
+
+ # Setup authentication before anything else
+ auth_success = setup_auth()
+
  # Check if CUDA is available
  device = "cuda" if torch.cuda.is_available() else "cpu"
  print(f"Using device: {device}")
+ print(f"Authentication status: {'βœ… Logged in' if auth_success else '❌ Anonymous'}")
 
  # Global variables to store models
  snac_model = None
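
Note: the new `setup_auth()` depends on a token reaching the process environment via `load_dotenv()`. A minimal standalone sketch of that lookup, assuming a `.env` file next to `app.py` containing a line such as `HF_TOKEN=hf_xxx` (the `whoami` sanity check is illustrative, not part of this commit):

```python
import os

from dotenv import load_dotenv
from huggingface_hub import login, whoami

load_dotenv()  # copies HF_TOKEN / HUGGINGFACE_HUB_TOKEN from .env into os.environ
token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
if token:
    login(token=token, add_to_git_credential=False)
    print(f"Logged in as: {whoami()['name']}")  # verifies the token actually works
else:
    print("No token found; Hub requests run anonymously")
```
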
@@ -29,7 +48,6 @@ def load_models():
 
      model_name = "mrrtmob/tts-khm-4"
 
-     # Download specific files
      print("Downloading model files...")
      snapshot_download(
          repo_id=model_name,
@@ -52,7 +70,6 @@ def load_models():
      )
 
      print("Loading main model...")
-     # Simplified model loading without device_map
      if device == "cuda":
          model = AutoModelForCausalLM.from_pretrained(
              model_name,
@@ -78,24 +95,19 @@ def load_models():
  load_models()
 
  def split_text_by_punctuation(text, max_chars=200):
-     """Split text by punctuation marks, keeping sentences together when possible"""
-     # Khmer and common punctuation
      sentence_endings = r'[αŸ”!?]'
      clause_separators = r'[,;:]'
 
-     # First try to split by sentence endings
      sentences = re.split(f'({sentence_endings})', text)
 
-     # Recombine sentences with their punctuation
      combined_sentences = []
      for i in range(0, len(sentences), 2):
          sentence = sentences[i]
          if i + 1 < len(sentences):
-             sentence += sentences[i + 1]  # Add the punctuation back
+             sentence += sentences[i + 1]
          if sentence.strip():
              combined_sentences.append(sentence.strip())
 
-     # If no sentence endings found, split by clauses
      if len(combined_sentences) <= 1:
          parts = re.split(f'({clause_separators})', text)
          combined_sentences = []
@@ -106,13 +118,11 @@ def split_text_by_punctuation(text, max_chars=200):
          if part.strip():
              combined_sentences.append(part.strip())
 
-     # Further split if sentences are too long
      final_chunks = []
      for sentence in combined_sentences:
          if len(sentence) <= max_chars:
              final_chunks.append(sentence)
          else:
-             # Split long sentences by words
              words = sentence.split()
              current_chunk = ""
 
@@ -131,10 +141,8 @@ def split_text_by_punctuation(text, max_chars=200):
      return [chunk for chunk in final_chunks if chunk.strip()]
 
  def split_text_by_tokens(text, max_tokens=150):
-     """Split text by token count"""
      global tokenizer
 
-     # Tokenize the entire text first
      tokens = tokenizer.encode(text)
 
      if len(tokens) <= max_tokens:
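
Note: `split_text_by_punctuation` hinges on `re.split` with a capturing group, which keeps each punctuation mark in the result list so it can be re-attached to the sentence it ends; the loop above also keeps a trailing fragment that has no closing punctuation. A standalone illustration with a hypothetical sample string (assumes the input ends with punctuation):

```python
import re

sentence_endings = r'[αŸ”!?]'  # Khmer khan plus common sentence enders, as above
parts = re.split(f'({sentence_endings})', "αž‡αŸ†αžšαžΆαž”αžŸαž½αžšαŸ” αžŸαžΌαž˜αž’αžšαž‚αž»αžŽ!")
# parts == ['αž‡αŸ†αžšαžΆαž”αžŸαž½αžš', 'αŸ”', ' αžŸαžΌαž˜αž’αžšαž‚αž»αžŽ', '!', '']
sentences = [parts[i] + parts[i + 1] for i in range(0, len(parts) - 1, 2)]
# sentences == ['αž‡αŸ†αžšαžΆαž”αžŸαž½αžšαŸ”', ' αžŸαžΌαž˜αž’αžšαž‚αž»αžŽ!']
```
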
@@ -197,7 +205,7 @@ def parse_output(generated_ids):
 
  def redistribute_codes(code_list, snac_model):
      if not code_list or len(code_list) < 7:
-         return np.zeros(12000)  # 0.5 seconds of silence
+         return np.zeros(12000)
 
      device = next(snac_model.parameters()).device
      layer_1 = []
@@ -227,8 +235,27 @@ def redistribute_codes(code_list, snac_model):
          print(f"Error in redistribute_codes: {e}")
          return np.zeros(12000)
 
- @spaces.GPU(duration=120)
- def generate_speech_chunk(text_chunk, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=800, voice="Elise"):
+ def combine_audio_chunks(audio_chunks, pause_duration=0.3):
+     if not audio_chunks:
+         return np.array([])
+
+     pause_samples = int(24000 * pause_duration)
+     pause = np.zeros(pause_samples)
+
+     combined_audio = []
+     for i, chunk in enumerate(audio_chunks):
+         if len(chunk) > 0:
+             combined_audio.append(chunk)
+             if i < len(audio_chunks) - 1:
+                 combined_audio.append(pause)
+
+     if combined_audio:
+         return np.concatenate(combined_audio)
+     else:
+         return np.array([])
+
+ @spaces.GPU(duration=60)  # Reduced duration to be more conservative
+ def generate_speech_chunk(text_chunk, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=600, voice="Elise"):
      """Generate speech for a single chunk"""
      global model, tokenizer, snac_model
 
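
Note: the relocated `combine_audio_chunks` joins chunk waveforms with `pause_duration` seconds of silence at the 24 kHz rate SNAC decodes to. A self-contained sketch of the arithmetic (the chunks here are synthetic stand-ins for generated speech):

```python
import numpy as np

SR = 24000                                   # sample rate assumed throughout app.py
pause = np.zeros(int(SR * 0.3))              # 0.3 s default gap between chunks
chunk_a = np.random.uniform(-1, 1, SR)       # 1.0 s stand-in chunk
chunk_b = np.random.uniform(-1, 1, SR // 2)  # 0.5 s stand-in chunk
combined = np.concatenate([chunk_a, pause, chunk_b])
assert combined.shape[0] == SR + int(SR * 0.3) + SR // 2  # 1.8 s total
```
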
@@ -265,30 +292,8 @@ def generate_speech_chunk(text_chunk, temperature=0.6, top_p=0.95, repetition_pe
          print(f"Error generating speech chunk: {e}")
          return np.array([])
 
- def combine_audio_chunks(audio_chunks, pause_duration=0.3):
-     """Combine audio chunks with pauses between them"""
-     if not audio_chunks:
-         return np.array([])
-
-     # Create pause (silence)
-     pause_samples = int(24000 * pause_duration)  # 24kHz sample rate
-     pause = np.zeros(pause_samples)
-
-     combined_audio = []
-     for i, chunk in enumerate(audio_chunks):
-         if len(chunk) > 0:
-             combined_audio.append(chunk)
-             # Add pause between chunks (except after the last chunk)
-             if i < len(audio_chunks) - 1:
-                 combined_audio.append(pause)
-
-     if combined_audio:
-         return np.concatenate(combined_audio)
-     else:
-         return np.array([])
-
- def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=800,
-                     voice="Elise", split_method="punctuation", max_chars=200, max_tokens=150,
+ def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=600,
+                     voice="Elise", split_method="punctuation", max_chars=150, max_tokens=100,
                      pause_duration=0.3, progress=gr.Progress()):
      """Main function to generate speech with text splitting"""
 
@@ -296,14 +301,13 @@ def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, m
          return None
 
      try:
-         # Split text based on selected method
          progress(0.05, "Splitting text...")
 
          if split_method == "punctuation":
              text_chunks = split_text_by_punctuation(text, max_chars)
          elif split_method == "tokens":
              text_chunks = split_text_by_tokens(text, max_tokens)
-         else:  # "none"
+         else:
              text_chunks = [text]
 
          progress(0.1, f"Processing {len(text_chunks)} chunks...")
@@ -311,7 +315,6 @@ def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, m
          for i, chunk in enumerate(text_chunks):
              print(f"Chunk {i+1}: {chunk[:50]}...")
 
-         # Generate audio for each chunk
          audio_chunks = []
          for i, chunk in enumerate(text_chunks):
              progress(0.1 + 0.7 * (i / len(text_chunks)), f"Generating chunk {i+1}/{len(text_chunks)}...")
@@ -327,7 +330,6 @@ def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, m
          if not audio_chunks:
              return None
 
-         # Combine all audio chunks
          progress(0.9, "Combining audio chunks...")
          final_audio = combine_audio_chunks(audio_chunks, pause_duration)
 
@@ -342,19 +344,20 @@ def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, m
          traceback.print_exc()
          return None
 
- # Examples
+ # [Rest of your Gradio interface code remains the same]
  examples = [
-     ["αž‡αŸ†αžšαžΆαž”αžŸαž½αžš αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡ Kiri αžαŸ’αž‰αž»αŸ†αž‚αžΊαž‡αžΆαž˜αŸ‰αžΌαžŠαŸ‚αž›αž•αž›αž·αžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”"],
+     ["αž‡αŸ†αžšαžΆαž”αžŸαž½αžš αžαŸ’αž‰αž»αŸ†αžˆαŸ’αž˜αŸ„αŸ‡ αžαžΆαžšαžΆαŸ” αžαŸ’αž‰αž»αŸ†αž‚αžΊαž‡αžΆαž˜αŸ‰αžΌαžŠαŸ‚αž›αž•αž›αž·αžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”"],
      ["αžαŸ’αž‰αž»αŸ†αž’αžΆαž…αž”αž„αŸ’αž€αžΎαžαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αž•αŸ’αžŸαŸαž„αŸ— αžŠαžΌαž…αž‡αžΆ <laugh> αžŸαžΎαž… ឬ <sigh> αžαž”αŸ‹αžŠαž„αŸ’αž αžΎαž˜αŸ”"],
  ]
 
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
 
- # Create Gradio interface
  with gr.Blocks(title="Khmer Text-to-Speech") as demo:
      gr.Markdown(f"""
      # 🎡 Khmer Text-to-Speech
      **αž˜αŸ‰αžΌαžŠαŸ‚αž›αž”αž˜αŸ’αž›αŸ‚αž„αž’αžαŸ’αžαž”αž‘αž‡αžΆαžŸαŸ†αž›αŸαž„**
+     Authentication: {'βœ… Pro Account' if auth_success else '❌ Anonymous (Limited)'}
+
      αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαžšαž”αžŸαŸ‹αž’αŸ’αž“αž€ αž αžΎαž™αžŸαŸ’αžαžΆαž”αŸ‹αž€αžΆαžšαž”αž˜αŸ’αž›αŸ‚αž„αž‘αŸ…αž‡αžΆαžŸαŸ†αž›αŸαž„αž“αž·αž™αžΆαž™αŸ”
      πŸ’‘ **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
      ✨ **New**: Supports long text with automatic splitting!
@@ -366,7 +369,6 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
          lines=6
      )
 
-     # Text splitting options
      with gr.Accordion("πŸ“ Text Splitting Options", open=True):
          split_method = gr.Radio(
              choices=[
@@ -375,51 +377,42 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
                  ("No splitting", "none")
              ],
              value="punctuation",
-             label="Text splitting method",
-             info="For long texts, splitting helps avoid the 15s limit"
+             label="Text splitting method"
          )
 
          with gr.Row():
              max_chars = gr.Slider(
-                 minimum=50, maximum=500, value=200, step=25,
-                 label="Max characters per chunk (punctuation mode)",
-                 info="Shorter chunks = more natural breaks but more processing time"
+                 minimum=50, maximum=300, value=150, step=25,
+                 label="Max characters per chunk"
              )
              max_tokens = gr.Slider(
-                 minimum=50, maximum=300, value=150, step=25,
-                 label="Max tokens per chunk (token mode)",
-                 info="Controls chunk size based on model tokenization"
+                 minimum=50, maximum=200, value=100, step=25,
+                 label="Max tokens per chunk"
              )
 
          pause_duration = gr.Slider(
              minimum=0.0, maximum=1.0, value=0.3, step=0.1,
-             label="Pause between chunks (seconds)",
-             info="Silence duration between text chunks"
+             label="Pause between chunks (seconds)"
          )
 
-     # Advanced Settings
      with gr.Accordion("πŸ”§ Advanced Settings", open=False):
          with gr.Row():
              temperature = gr.Slider(
                  minimum=0.1, maximum=1.5, value=0.6, step=0.05,
-                 label="Temperature",
-                 info="Higher values create more expressive speech"
+                 label="Temperature"
              )
              top_p = gr.Slider(
                  minimum=0.1, maximum=1.0, value=0.95, step=0.05,
-                 label="Top P",
-                 info="Nucleus sampling threshold"
+                 label="Top P"
              )
          with gr.Row():
              repetition_penalty = gr.Slider(
                  minimum=1.0, maximum=2.0, value=1.1, step=0.05,
-                 label="Repetition Penalty",
-                 info="Higher values discourage repetitive patterns"
+                 label="Repetition Penalty"
              )
              max_new_tokens = gr.Slider(
-                 minimum=100, maximum=1200, value=800, step=100,
-                 label="Max tokens per chunk",
-                 info="Lower values for shorter, more reliable generation"
+                 minimum=100, maximum=800, value=600, step=100,
+                 label="Max tokens per chunk"
              )
 
      with gr.Row():
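
Note: these sliders feed `generate_speech_chunk`, whose body is outside this diff. As a hedged sketch only, the new defaults map onto transformers' standard sampling arguments roughly as follows (the helper name and exact call are assumptions, not code from this commit):

```python
def sample_speech_tokens(model, input_ids):
    # Hypothetical helper mirroring the new UI defaults; the real call lives
    # inside generate_speech_chunk, which this diff does not show.
    return model.generate(
        input_ids,
        do_sample=True,
        temperature=0.6,         # slider default
        top_p=0.95,              # nucleus sampling threshold
        repetition_penalty=1.1,  # penalizes repeated token patterns
        max_new_tokens=600,      # lowered from 800 in this commit
    )
```
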
@@ -453,9 +446,8 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
          outputs=[text_input, audio_output]
      )
 
- # Launch the app
  if __name__ == "__main__":
-     demo.queue(max_size=10).launch(
+     demo.queue(max_size=5).launch(
          share=False,
          server_name="0.0.0.0",
          server_port=7860
 