mrrtmob committed on
Commit
8a44aae
·
1 Parent(s): 8f1133d
Files changed (1) hide show
  1. app.py +17 -1
app.py CHANGED
@@ -6,12 +6,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
6
  from huggingface_hub import snapshot_download
7
  from dotenv import load_dotenv
8
  load_dotenv()
 
9
  # Check if CUDA is available
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  print("Loading SNAC model...")
12
  snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
13
  snac_model = snac_model.to(device)
 
14
  model_name = "mrrtmob/tts-khm-1"
 
15
  # Download only model config and safetensors
16
  snapshot_download(
17
  repo_id=model_name,
@@ -33,10 +36,12 @@ snapshot_download(
33
  "tokenizer.*"
34
  ]
35
  )
 
36
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
37
  model.to(device)
38
  tokenizer = AutoTokenizer.from_pretrained(model_name)
39
  print(f"Khmer TTS model loaded to {device}")
 
40
  # Process text prompt
41
  def process_prompt(prompt, voice, tokenizer, device):
42
  prompt = f"{voice}: {prompt}"
@@ -51,6 +56,7 @@ def process_prompt(prompt, voice, tokenizer, device):
51
  attention_mask = torch.ones_like(modified_input_ids)
52
 
53
  return modified_input_ids.to(device), attention_mask.to(device)
 
54
  # Parse output tokens to audio
55
  def parse_output(generated_ids):
56
  token_to_find = 128257
@@ -62,10 +68,12 @@ def parse_output(generated_ids):
62
  cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
63
  else:
64
  cropped_tensor = generated_ids
 
65
  processed_rows = []
66
  for row in cropped_tensor:
67
  masked_row = row[row != token_to_remove]
68
  processed_rows.append(masked_row)
 
69
  code_lists = []
70
  for row in processed_rows:
71
  row_length = row.size(0)
@@ -75,6 +83,7 @@ def parse_output(generated_ids):
75
  code_lists.append(trimmed_row)
76
 
77
  return code_lists[0] # Return just the first one for single sample
 
78
  # Redistribute codes for audio generation
79
  def redistribute_codes(code_list, snac_model):
80
  device = next(snac_model.parameters()).device # Get the device of SNAC model
@@ -100,6 +109,7 @@ def redistribute_codes(code_list, snac_model):
100
 
101
  audio_hat = snac_model.decode(codes)
102
  return audio_hat.detach().squeeze().cpu().numpy() # Always return CPU numpy array
 
103
  # Main generation function
104
  @spaces.GPU()
105
  def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
@@ -134,6 +144,7 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
134
  except Exception as e:
135
  print(f"Error generating speech: {e}")
136
  return None
 
137
  # Examples for the UI - Khmer text examples
138
  examples = [
139
  ["ជំរាបសួរ ខ្ញុំឈ្មោះ តារា ហើយខ្ញុំគឺជាម៉ូដែលផលិតសំលេងនិយាយ។", "tara", 0.6, 0.95, 1.1, 1200],
@@ -145,10 +156,13 @@ examples = [
145
  ["តើអ្នកបានឮរឿងកំប្លែងនេះយ៉ាងណា? <laugh> ខ្ញុំមិនអាចបញ្ឈប់ការសើចបាននោះទេ។", "zac", 0.7, 0.95, 1.1, 1200],
146
  ["បន្ទាប់ពីរត់ម៉ារ៉ាតុងរួច ខ្ញុំហត់ណាស់ <yawn> ហើយត្រូវការសម្រាក។", "zoe", 0.6, 0.95, 1.1, 1200]
147
  ]
 
148
  # Available voices
149
  VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe", "jing", "Elise"]
 
150
  # Available Emotive Tags
151
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
 
152
  # Create Gradio interface
153
  with gr.Blocks(title="Khmer Text-to-Speech") as demo:
154
  gr.Markdown(f"""
@@ -163,6 +177,7 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
163
  - អត្ថបទវែងជាទូទៅមានលទ្ធផលល្អជាងអត្ថបទខ្លី
164
  - Increasing `repetition_penalty` and `temperature` makes the model speak faster
165
  """)
 
166
  with gr.Row():
167
  with gr.Column(scale=3):
168
  text_input = gr.Textbox(
@@ -226,6 +241,7 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
226
  inputs=[],
227
  outputs=[text_input, audio_output]
228
  )
 
229
  # Launch the app
230
  if __name__ == "__main__":
231
- demo.queue().launch(share=False, ssr_mode=False)
 
6
  from huggingface_hub import snapshot_download
7
  from dotenv import load_dotenv
8
  load_dotenv()
9
+
10
  # Check if CUDA is available
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  print("Loading SNAC model...")
13
  snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
14
  snac_model = snac_model.to(device)
15
+
16
  model_name = "mrrtmob/tts-khm-1"
17
+
18
  # Download only model config and safetensors
19
  snapshot_download(
20
  repo_id=model_name,
 
36
  "tokenizer.*"
37
  ]
38
  )
39
+
40
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
41
  model.to(device)
42
  tokenizer = AutoTokenizer.from_pretrained(model_name)
43
  print(f"Khmer TTS model loaded to {device}")
44
+
45
  # Process text prompt
46
  def process_prompt(prompt, voice, tokenizer, device):
47
  prompt = f"{voice}: {prompt}"
 
56
  attention_mask = torch.ones_like(modified_input_ids)
57
 
58
  return modified_input_ids.to(device), attention_mask.to(device)
59
+
60
  # Parse output tokens to audio
61
  def parse_output(generated_ids):
62
  token_to_find = 128257
 
68
  cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
69
  else:
70
  cropped_tensor = generated_ids
71
+
72
  processed_rows = []
73
  for row in cropped_tensor:
74
  masked_row = row[row != token_to_remove]
75
  processed_rows.append(masked_row)
76
+
77
  code_lists = []
78
  for row in processed_rows:
79
  row_length = row.size(0)
 
83
  code_lists.append(trimmed_row)
84
 
85
  return code_lists[0] # Return just the first one for single sample
86
+
87
  # Redistribute codes for audio generation
88
  def redistribute_codes(code_list, snac_model):
89
  device = next(snac_model.parameters()).device # Get the device of SNAC model
 
109
 
110
  audio_hat = snac_model.decode(codes)
111
  return audio_hat.detach().squeeze().cpu().numpy() # Always return CPU numpy array
112
+
113
  # Main generation function
114
  @spaces.GPU()
115
  def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
 
144
  except Exception as e:
145
  print(f"Error generating speech: {e}")
146
  return None
147
+
148
  # Examples for the UI - Khmer text examples
149
  examples = [
150
  ["ជំរាបសួរ ខ្ញុំឈ្មោះ តារា ហើយខ្ញុំគឺជាម៉ូដែលផលិតសំលេងនិយាយ។", "tara", 0.6, 0.95, 1.1, 1200],
 
156
  ["តើអ្នកបានឮរឿងកំប្លែងនេះយ៉ាងណា? <laugh> ខ្ញុំមិនអាចបញ្ឈប់ការសើចបាននោះទេ។", "zac", 0.7, 0.95, 1.1, 1200],
157
  ["បន្ទាប់ពីរត់ម៉ារ៉ាតុងរួច ខ្ញុំហត់ណាស់ <yawn> ហើយត្រូវការសម្រាក។", "zoe", 0.6, 0.95, 1.1, 1200]
158
  ]
159
+
160
  # Available voices
161
  VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe", "jing", "Elise"]
162
+
163
  # Available Emotive Tags
164
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
165
+
166
  # Create Gradio interface
167
  with gr.Blocks(title="Khmer Text-to-Speech") as demo:
168
  gr.Markdown(f"""
 
177
  - អត្ថបទវែងជាទូទៅមានលទ្ធផលល្អជាងអត្ថបទខ្លី
178
  - Increasing `repetition_penalty` and `temperature` makes the model speak faster
179
  """)
180
+
181
  with gr.Row():
182
  with gr.Column(scale=3):
183
  text_input = gr.Textbox(
 
241
  inputs=[],
242
  outputs=[text_input, audio_output]
243
  )
244
+
245
  # Launch the app
246
  if __name__ == "__main__":
247
+ demo.queue().launch(share=False)