Spaces:
Running
on
Zero
Running
on
Zero
tts-khm-1
Browse files
app.py
CHANGED
@@ -6,12 +6,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
6 |
from huggingface_hub import snapshot_download
|
7 |
from dotenv import load_dotenv
|
8 |
load_dotenv()
|
|
|
9 |
# Check if CUDA is available
|
10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
print("Loading SNAC model...")
|
12 |
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
|
13 |
snac_model = snac_model.to(device)
|
|
|
14 |
model_name = "mrrtmob/tts-khm-1"
|
|
|
15 |
# Download only model config and safetensors
|
16 |
snapshot_download(
|
17 |
repo_id=model_name,
|
@@ -33,10 +36,12 @@ snapshot_download(
|
|
33 |
"tokenizer.*"
|
34 |
]
|
35 |
)
|
|
|
36 |
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
|
37 |
model.to(device)
|
38 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
39 |
print(f"Khmer TTS model loaded to {device}")
|
|
|
40 |
# Process text prompt
|
41 |
def process_prompt(prompt, voice, tokenizer, device):
|
42 |
prompt = f"{voice}: {prompt}"
|
@@ -51,6 +56,7 @@ def process_prompt(prompt, voice, tokenizer, device):
|
|
51 |
attention_mask = torch.ones_like(modified_input_ids)
|
52 |
|
53 |
return modified_input_ids.to(device), attention_mask.to(device)
|
|
|
54 |
# Parse output tokens to audio
|
55 |
def parse_output(generated_ids):
|
56 |
token_to_find = 128257
|
@@ -62,10 +68,12 @@ def parse_output(generated_ids):
|
|
62 |
cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
|
63 |
else:
|
64 |
cropped_tensor = generated_ids
|
|
|
65 |
processed_rows = []
|
66 |
for row in cropped_tensor:
|
67 |
masked_row = row[row != token_to_remove]
|
68 |
processed_rows.append(masked_row)
|
|
|
69 |
code_lists = []
|
70 |
for row in processed_rows:
|
71 |
row_length = row.size(0)
|
@@ -75,6 +83,7 @@ def parse_output(generated_ids):
|
|
75 |
code_lists.append(trimmed_row)
|
76 |
|
77 |
return code_lists[0] # Return just the first one for single sample
|
|
|
78 |
# Redistribute codes for audio generation
|
79 |
def redistribute_codes(code_list, snac_model):
|
80 |
device = next(snac_model.parameters()).device # Get the device of SNAC model
|
@@ -100,6 +109,7 @@ def redistribute_codes(code_list, snac_model):
|
|
100 |
|
101 |
audio_hat = snac_model.decode(codes)
|
102 |
return audio_hat.detach().squeeze().cpu().numpy() # Always return CPU numpy array
|
|
|
103 |
# Main generation function
|
104 |
@spaces.GPU()
|
105 |
def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
|
@@ -134,6 +144,7 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
|
|
134 |
except Exception as e:
|
135 |
print(f"Error generating speech: {e}")
|
136 |
return None
|
|
|
137 |
# Examples for the UI - Khmer text examples
|
138 |
examples = [
|
139 |
["ជំរាបសួរ ខ្ញុំឈ្មោះ តារា ហើយខ្ញុំគឺជាម៉ូដែលផលិតសំលេងនិយាយ។", "tara", 0.6, 0.95, 1.1, 1200],
|
@@ -145,10 +156,13 @@ examples = [
|
|
145 |
["តើអ្នកបានឮរឿងកំប្លែងនេះយ៉ាងណា? <laugh> ខ្ញុំមិនអាចបញ្ឈប់ការសើចបាននោះទេ។", "zac", 0.7, 0.95, 1.1, 1200],
|
146 |
["បន្ទាប់ពីរត់ម៉ារ៉ាតុងរួច ខ្ញុំហត់ណាស់ <yawn> ហើយត្រូវការសម្រាក។", "zoe", 0.6, 0.95, 1.1, 1200]
|
147 |
]
|
|
|
148 |
# Available voices
|
149 |
VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe", "jing", "Elise"]
|
|
|
150 |
# Available Emotive Tags
|
151 |
EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
|
|
|
152 |
# Create Gradio interface
|
153 |
with gr.Blocks(title="Khmer Text-to-Speech") as demo:
|
154 |
gr.Markdown(f"""
|
@@ -163,6 +177,7 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
|
|
163 |
- អត្ថបទវែងជាទូទៅមានលទ្ធផលល្អជាងអត្ថបទខ្លី
|
164 |
- Increasing `repetition_penalty` and `temperature` makes the model speak faster
|
165 |
""")
|
|
|
166 |
with gr.Row():
|
167 |
with gr.Column(scale=3):
|
168 |
text_input = gr.Textbox(
|
@@ -226,6 +241,7 @@ with gr.Blocks(title="Khmer Text-to-Speech") as demo:
|
|
226 |
inputs=[],
|
227 |
outputs=[text_input, audio_output]
|
228 |
)
|
|
|
229 |
# Launch the app
|
230 |
if __name__ == "__main__":
|
231 |
-
demo.queue().launch(share=False
|
|
|
6 |
from huggingface_hub import snapshot_download
|
7 |
from dotenv import load_dotenv
|
8 |
load_dotenv()
|
9 |
+
|
10 |
# Check if CUDA is available
|
11 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
12 |
print("Loading SNAC model...")
|
13 |
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
|
14 |
snac_model = snac_model.to(device)
|
15 |
+
|
16 |
model_name = "mrrtmob/tts-khm-1"
|
17 |
+
|
18 |
# Download only model config and safetensors
|
19 |
snapshot_download(
|
20 |
repo_id=model_name,
|
|
|
36 |
"tokenizer.*"
|
37 |
]
|
38 |
)
|
39 |
+
|
40 |
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
|
41 |
model.to(device)
|
42 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
43 |
print(f"Khmer TTS model loaded to {device}")
|
44 |
+
|
45 |
# Process text prompt
|
46 |
def process_prompt(prompt, voice, tokenizer, device):
|
47 |
prompt = f"{voice}: {prompt}"
|
|
|
56 |
attention_mask = torch.ones_like(modified_input_ids)
|
57 |
|
58 |
return modified_input_ids.to(device), attention_mask.to(device)
|
59 |
+
|
60 |
# Parse output tokens to audio
|
61 |
def parse_output(generated_ids):
|
62 |
token_to_find = 128257
|
|
|
68 |
cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
|
69 |
else:
|
70 |
cropped_tensor = generated_ids
|
71 |
+
|
72 |
processed_rows = []
|
73 |
for row in cropped_tensor:
|
74 |
masked_row = row[row != token_to_remove]
|
75 |
processed_rows.append(masked_row)
|
76 |
+
|
77 |
code_lists = []
|
78 |
for row in processed_rows:
|
79 |
row_length = row.size(0)
|
|
|
83 |
code_lists.append(trimmed_row)
|
84 |
|
85 |
return code_lists[0] # Return just the first one for single sample
|
86 |
+
|
87 |
# Redistribute codes for audio generation
|
88 |
def redistribute_codes(code_list, snac_model):
|
89 |
device = next(snac_model.parameters()).device # Get the device of SNAC model
|
|
|
109 |
|
110 |
audio_hat = snac_model.decode(codes)
|
111 |
return audio_hat.detach().squeeze().cpu().numpy() # Always return CPU numpy array
|
112 |
+
|
113 |
# Main generation function
|
114 |
@spaces.GPU()
|
115 |
def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
|
|
|
144 |
except Exception as e:
|
145 |
print(f"Error generating speech: {e}")
|
146 |
return None
|
147 |
+
|
148 |
# Examples for the UI - Khmer text examples
|
149 |
examples = [
|
150 |
["ជំរាបសួរ ខ្ញុំឈ្មោះ តារា ហើយខ្ញុំគឺជាម៉ូដែលផលិតសំលេងនិយាយ។", "tara", 0.6, 0.95, 1.1, 1200],
|
|
|
156 |
["តើអ្នកបានឮរឿងកំប្លែងនេះយ៉ាងណា? <laugh> ខ្ញុំមិនអាចបញ្ឈប់ការសើចបាននោះទេ។", "zac", 0.7, 0.95, 1.1, 1200],
|
157 |
["បន្ទាប់ពីរត់ម៉ារ៉ាតុងរួច ខ្ញុំហត់ណាស់ <yawn> ហើយត្រូវការសម្រាក។", "zoe", 0.6, 0.95, 1.1, 1200]
|
158 |
]
|
159 |
+
|
160 |
# Available voices
|
161 |
VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe", "jing", "Elise"]
|
162 |
+
|
163 |
# Available Emotive Tags
|
164 |
EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
|
165 |
+
|
166 |
# Create Gradio interface
|
167 |
with gr.Blocks(title="Khmer Text-to-Speech") as demo:
|
168 |
gr.Markdown(f"""
|
|
|
177 |
- អត្ថបទវែងជាទូទៅមានលទ្ធផលល្អជាងអត្ថបទខ្លី
|
178 |
- Increasing `repetition_penalty` and `temperature` makes the model speak faster
|
179 |
""")
|
180 |
+
|
181 |
with gr.Row():
|
182 |
with gr.Column(scale=3):
|
183 |
text_input = gr.Textbox(
|
|
|
241 |
inputs=[],
|
242 |
outputs=[text_input, audio_output]
|
243 |
)
|
244 |
+
|
245 |
# Launch the app
|
246 |
if __name__ == "__main__":
|
247 |
+
demo.queue().launch(share=False)
|