TDN-M commited on
Commit
d9a3d58
·
verified ·
1 Parent(s): afbe477

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -40
app.py CHANGED
@@ -7,6 +7,7 @@ import uuid
7
  from io import StringIO
8
 
9
  import gradio as gr
 
10
  import torch
11
  import torchaudio
12
  from huggingface_hub import HfApi, hf_hub_download, snapshot_download
@@ -14,12 +15,14 @@ from TTS.tts.configs.xtts_config import XttsConfig
14
  from TTS.tts.models.xtts import Xtts
15
  from vinorm import TTSnorm
16
 
17
- # Initialize Hugging Face API
 
 
18
  HF_TOKEN = os.environ.get("HF_TOKEN")
19
  api = HfApi(token=HF_TOKEN)
20
 
21
- # Download model files if not already downloaded
22
- print("Downloading viXTTS model files if not already present...")
23
  checkpoint_dir = "model/"
24
  repo_id = "capleaf/viXTTS"
25
  use_deepspeed = False
@@ -40,7 +43,6 @@ if not all(file in files_in_dir for file in required_files):
40
  local_dir=checkpoint_dir,
41
  )
42
 
43
- # Load model configuration and initialize model
44
  xtts_config = os.path.join(checkpoint_dir, "config.json")
45
  config = XttsConfig()
46
  config.load_json(xtts_config)
@@ -51,9 +53,8 @@ MODEL.load_checkpoint(
51
  if torch.cuda.is_available():
52
  MODEL.cuda()
53
 
54
- # Supported languages
55
  supported_languages = config.languages
56
- if "vi" not in supported_languages:
57
  supported_languages.append("vi")
58
 
59
 
@@ -74,6 +75,7 @@ def normalize_vietnamese_text(text):
74
 
75
 
76
  def calculate_keep_len(text, lang):
 
77
  if lang in ["ja", "zh-cn"]:
78
  return -1
79
 
@@ -87,39 +89,63 @@ def calculate_keep_len(text, lang):
87
  return -1
88
 
89
 
90
- def predict(prompt, language, audio_file_pth, normalize_text=True):
 
 
 
 
 
 
91
  if language not in supported_languages:
92
  metrics_text = gr.Warning(
93
- f"Language {language} is not supported. Please choose from the dropdown."
94
  )
95
- return None, metrics_text
 
 
 
96
 
97
  if len(prompt) < 2:
98
- metrics_text = gr.Warning("Please provide a longer prompt text.")
99
- return None, metrics_text
 
 
 
 
 
 
 
 
 
100
 
101
  try:
102
  metrics_text = ""
103
  t_latent = time.time()
104
 
105
  try:
106
- gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
107
- audio_path=audio_file_pth,
 
 
 
108
  gpt_cond_len=30,
109
  gpt_cond_chunk_len=4,
110
  max_ref_length=60,
111
  )
 
112
  except Exception as e:
113
- print("Speaker encoding error:", str(e))
114
- metrics_text = gr.Warning("Error with reference audio.")
115
- return None, metrics_text
 
 
116
 
117
- prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2", prompt)
118
 
119
  if normalize_text and language == "vi":
120
  prompt = normalize_vietnamese_text(prompt)
121
 
122
- print("Generating new audio...")
123
  t0 = time.time()
124
  out = MODEL.inference(
125
  prompt,
@@ -131,68 +157,164 @@ def predict(prompt, language, audio_file_pth, normalize_text=True):
131
  enable_text_splitting=True,
132
  )
133
  inference_time = time.time() - t0
134
- metrics_text += f"Time to generate audio: {round(inference_time * 1000)} ms\n"
 
 
 
135
  real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
 
136
  metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"
137
 
 
138
  keep_len = calculate_keep_len(prompt, language)
139
  out["wav"] = out["wav"][:keep_len]
140
 
141
  torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
142
 
143
  except RuntimeError as e:
144
- print("RuntimeError:", str(e))
145
- metrics_text = gr.Warning("An error occurred during processing.")
146
- return None, metrics_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
- return "output.wav", metrics_text
 
 
 
 
 
 
 
 
 
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- title = "Phòng Thu VMC"
152
 
153
  with gr.Blocks(analytics_enabled=False) as demo:
154
  with gr.Row():
155
  with gr.Column():
156
- gr.Markdown("## VMC LAB")
 
 
 
 
 
157
  with gr.Column():
 
158
  pass
159
 
160
  with gr.Row():
161
  with gr.Column():
162
  input_text_gr = gr.Textbox(
163
- label="Text Prompt",
164
- info="One or two sentences at a time is better. Up to 200 text characters.",
165
- value="Xin chào, hãy nhập nội dung cần thu âm vào đây",
166
  )
167
  language_gr = gr.Dropdown(
168
- label="Language",
169
- info="Select an output language for the synthesised speech",
170
- choices=supported_languages,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  value="vi",
172
  )
173
  normalize_text = gr.Checkbox(
174
- label="Normalize Vietnamese Text",
175
- info="Normalize Vietnamese Text",
176
  value=True,
177
  )
178
  ref_gr = gr.Audio(
179
- label="Reference Audio",
180
- info="Click on the ✎ button to upload your own target speaker audio",
181
  type="filepath",
182
- value="PG 2-2.wav",
 
 
 
 
 
 
183
  )
184
- tts_button = gr.Button("Send", elem_id="send-btn", visible=True)
185
 
186
  with gr.Column():
187
  audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
188
- out_text_gr = gr.Textbox(label="Metrics")
189
 
190
  tts_button.click(
191
  predict,
192
- [input_text_gr, language_gr, ref_gr, normalize_text],
 
 
 
 
 
193
  outputs=[audio_gr, out_text_gr],
194
  api_name="predict",
195
  )
196
 
197
  demo.queue()
198
- demo.launch(debug=True, show_api=True)
 
7
  from io import StringIO
8
 
9
  import gradio as gr
10
+ import spaces
11
  import torch
12
  import torchaudio
13
  from huggingface_hub import HfApi, hf_hub_download, snapshot_download
 
15
  from TTS.tts.models.xtts import Xtts
16
  from vinorm import TTSnorm
17
 
18
+ # download for mecab
19
+ os.system("python -m unidic download")
20
+
21
  HF_TOKEN = os.environ.get("HF_TOKEN")
22
  api = HfApi(token=HF_TOKEN)
23
 
24
+ # This will trigger downloading model
25
+ print("Downloading if not downloaded viXTTS")
26
  checkpoint_dir = "model/"
27
  repo_id = "capleaf/viXTTS"
28
  use_deepspeed = False
 
43
  local_dir=checkpoint_dir,
44
  )
45
 
 
46
  xtts_config = os.path.join(checkpoint_dir, "config.json")
47
  config = XttsConfig()
48
  config.load_json(xtts_config)
 
53
  if torch.cuda.is_available():
54
  MODEL.cuda()
55
 
 
56
  supported_languages = config.languages
57
+ if not "vi" in supported_languages:
58
  supported_languages.append("vi")
59
 
60
 
 
75
 
76
 
77
  def calculate_keep_len(text, lang):
78
+ """Simple hack for short sentences"""
79
  if lang in ["ja", "zh-cn"]:
80
  return -1
81
 
 
89
  return -1
90
 
91
 
92
+ @spaces.GPU
93
+ def predict(
94
+ prompt,
95
+ language,
96
+ audio_file_pth,
97
+ normalize_text=True,
98
+ ):
99
  if language not in supported_languages:
100
  metrics_text = gr.Warning(
101
+ f"Language you put {language} in is not in is not in our Supported Languages, please choose from dropdown"
102
  )
103
+
104
+ return (None, metrics_text)
105
+
106
+ speaker_wav = audio_file_pth
107
 
108
  if len(prompt) < 2:
109
+ metrics_text = gr.Warning("Please give a longer prompt text")
110
+ return (None, metrics_text)
111
+
112
+ # if len(prompt) > 250:
113
+ # metrics_text = gr.Warning(
114
+ # str(len(prompt))
115
+ # + " characters.\n"
116
+ # + "Your prompt is too long, please keep it under 250 characters\n"
117
+ # + "Văn bản quá dài, vui lòng giữ dưới 250 ký tự."
118
+ # )
119
+ # return (None, metrics_text)
120
 
121
  try:
122
  metrics_text = ""
123
  t_latent = time.time()
124
 
125
  try:
126
+ (
127
+ gpt_cond_latent,
128
+ speaker_embedding,
129
+ ) = MODEL.get_conditioning_latents(
130
+ audio_path=speaker_wav,
131
  gpt_cond_len=30,
132
  gpt_cond_chunk_len=4,
133
  max_ref_length=60,
134
  )
135
+
136
  except Exception as e:
137
+ print("Speaker encoding error", str(e))
138
+ metrics_text = gr.Warning(
139
+ "It appears something wrong with reference, did you unmute your microphone?"
140
+ )
141
+ return (None, metrics_text)
142
 
143
+ prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
144
 
145
  if normalize_text and language == "vi":
146
  prompt = normalize_vietnamese_text(prompt)
147
 
148
+ print("I: Generating new audio...")
149
  t0 = time.time()
150
  out = MODEL.inference(
151
  prompt,
 
157
  enable_text_splitting=True,
158
  )
159
  inference_time = time.time() - t0
160
+ print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
161
+ metrics_text += (
162
+ f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
163
+ )
164
  real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
165
+ print(f"Real-time factor (RTF): {real_time_factor}")
166
  metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"
167
 
168
+ # Temporary hack for short sentences
169
  keep_len = calculate_keep_len(prompt, language)
170
  out["wav"] = out["wav"][:keep_len]
171
 
172
  torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
173
 
174
  except RuntimeError as e:
175
+ if "device-side assert" in str(e):
176
+ # cannot do anything on cuda device side error, need to restart
177
+ print(
178
+ f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
179
+ flush=True,
180
+ )
181
+ gr.Warning("Unhandled Exception encounter, please retry in a minute")
182
+ print("Cuda device-assert Runtime encountered need restart")
183
+
184
+ error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
185
+ error_data = [
186
+ error_time,
187
+ prompt,
188
+ language,
189
+ audio_file_pth,
190
+ ]
191
+ error_data = [str(e) if type(e) != str else e for e in error_data]
192
+ print(error_data)
193
+ print(speaker_wav)
194
+ write_io = StringIO()
195
+ csv.writer(write_io).writerows([error_data])
196
+ csv_upload = write_io.getvalue().encode()
197
+
198
+ filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
199
+ print("Writing error csv")
200
+ error_api = HfApi()
201
+ error_api.upload_file(
202
+ path_or_fileobj=csv_upload,
203
+ path_in_repo=filename,
204
+ repo_id="coqui/xtts-flagged-dataset",
205
+ repo_type="dataset",
206
+ )
207
 
208
+ # speaker_wav
209
+ print("Writing error reference audio")
210
+ speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
211
+ error_api = HfApi()
212
+ error_api.upload_file(
213
+ path_or_fileobj=speaker_wav,
214
+ path_in_repo=speaker_filename,
215
+ repo_id="coqui/xtts-flagged-dataset",
216
+ repo_type="dataset",
217
+ )
218
 
219
+ # HF Space specific.. This error is unrecoverable need to restart space
220
+ space = api.get_space_runtime(repo_id=repo_id)
221
+ if space.stage != "BUILDING":
222
+ api.restart_space(repo_id=repo_id)
223
+ else:
224
+ print("TRIED TO RESTART but space is building")
225
+
226
+ else:
227
+ if "Failed to decode" in str(e):
228
+ print("Speaker encoding error", str(e))
229
+ metrics_text = gr.Warning(
230
+ metrics_text="It appears something wrong with reference, did you unmute your microphone?"
231
+ )
232
+ else:
233
+ print("RuntimeError: non device-side assert error:", str(e))
234
+ metrics_text = gr.Warning(
235
+ "Something unexpected happened please retry again."
236
+ )
237
+ return (None, metrics_text)
238
+ return ("output.wav", metrics_text)
239
 
 
240
 
241
  with gr.Blocks(analytics_enabled=False) as demo:
242
  with gr.Row():
243
  with gr.Column():
244
+ gr.Markdown(
245
+ """
246
+ # viXTTS Demo ✨
247
+ - Github: https://github.com/thinhlpg/vixtts-demo/
248
+ """
249
+ )
250
  with gr.Column():
251
+ # placeholder to align the image
252
  pass
253
 
254
  with gr.Row():
255
  with gr.Column():
256
  input_text_gr = gr.Textbox(
257
+ label="Text Prompt (Văn bản cần đọc)",
258
+ info="Mỗi câu nên từ 10 từ trở lên.",
259
+ value="Xin chào, tôi một hình chuyển đổi văn bản thành giọng nói tiếng Việt.",
260
  )
261
  language_gr = gr.Dropdown(
262
+ label="Language (Ngôn ngữ)",
263
+ choices=[
264
+ "vi",
265
+ "en",
266
+ "es",
267
+ "fr",
268
+ "de",
269
+ "it",
270
+ "pt",
271
+ "pl",
272
+ "tr",
273
+ "ru",
274
+ "nl",
275
+ "cs",
276
+ "ar",
277
+ "zh-cn",
278
+ "ja",
279
+ "ko",
280
+ "hu",
281
+ "hi",
282
+ ],
283
+ max_choices=1,
284
  value="vi",
285
  )
286
  normalize_text = gr.Checkbox(
287
+ label="Chuẩn hóa văn bản tiếng Việt",
288
+ info="Normalize Vietnamese text",
289
  value=True,
290
  )
291
  ref_gr = gr.Audio(
292
+ label="Reference Audio (Giọng mẫu)",
 
293
  type="filepath",
294
+ value="model/samples/nu-luu-loat.wav",
295
+ )
296
+ tts_button = gr.Button(
297
+ "Đọc 🗣️🔥",
298
+ elem_id="send-btn",
299
+ visible=True,
300
+ variant="primary",
301
  )
 
302
 
303
  with gr.Column():
304
  audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
305
+ out_text_gr = gr.Text(label="Metrics")
306
 
307
  tts_button.click(
308
  predict,
309
+ [
310
+ input_text_gr,
311
+ language_gr,
312
+ ref_gr,
313
+ normalize_text,
314
+ ],
315
  outputs=[audio_gr, out_text_gr],
316
  api_name="predict",
317
  )
318
 
319
  demo.queue()
320
+ demo.launch(debug=True, show_api=True, share=True)