TDN-M committed
Commit b8fd884 · verified · 1 Parent(s): 08179b5

Update app.py

Files changed (1)
  1. app.py +26 -37
app.py CHANGED
@@ -5,7 +5,6 @@ import re
 import time
 import uuid
 from io import StringIO
-
 import gradio as gr
 import spaces
 import torch
@@ -14,10 +13,10 @@ from huggingface_hub import HfApi, hf_hub_download, snapshot_download
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
 from vinorm import TTSnorm
+from content_generation import create_content  # import the create_content function from content_generation.py
 
 # download for mecab
 os.system("python -m unidic download")
-
 HF_TOKEN = os.environ.get("HF_TOKEN")
 api = HfApi(token=HF_TOKEN)
 
@@ -26,9 +25,7 @@ print("Downloading if not downloaded viXTTS")
 checkpoint_dir = "model/"
 repo_id = "capleaf/viXTTS"
 use_deepspeed = False
-
 os.makedirs(checkpoint_dir, exist_ok=True)
-
 required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
 files_in_dir = os.listdir(checkpoint_dir)
 if not all(file in files_in_dir for file in required_files):
@@ -42,7 +39,6 @@ if not all(file in files_in_dir for file in required_files):
         filename="speakers_xtts.pth",
         local_dir=checkpoint_dir,
     )
-
 xtts_config = os.path.join(checkpoint_dir, "config.json")
 config = XttsConfig()
 config.load_json(xtts_config)
@@ -52,12 +48,10 @@ MODEL.load_checkpoint(
 )
 if torch.cuda.is_available():
     MODEL.cuda()
-
 supported_languages = config.languages
 if not "vi" in supported_languages:
     supported_languages.append("vi")
 
-
 def normalize_vietnamese_text(text):
     text = (
         TTSnorm(text, unknown=False, lower=False, rule=True)
@@ -70,59 +64,52 @@ def normalize_vietnamese_text(text):
         .replace("'", "")
         .replace("AI", "Ây Ai")
         .replace("A.I", "Ây Ai")
-        .replace("%"), "phần trăm"
+        .replace("%", "phần trăm")
     )
     return text
 
-
 def calculate_keep_len(text, lang):
     """Simple hack for short sentences"""
     if lang in ["ja", "zh-cn"]:
         return -1
-
     word_count = len(text.split())
     num_punct = text.count(".") + text.count("!") + text.count("?") + text.count(",")
-
     if word_count < 5:
         return 15000 * word_count + 2000 * num_punct
     elif word_count < 10:
         return 13000 * word_count + 2000 * num_punct
     return -1
 
-
 @spaces.GPU
 def predict(
     prompt,
     language,
     audio_file_pth,
     normalize_text=True,
+    use_llm=False,  # option to generate the text with an LLM
+    content_type="Theo yêu cầu",  # content type (e.g. "triết lý sống" or "Theo yêu cầu")
 ):
+    if use_llm:
+        # If the LLM is enabled, generate the text content from the input
+        print("I: Generating text with LLM...")
+        generated_text = create_content(prompt, content_type, language)
+        print(f"Generated text: {generated_text}")
+        prompt = generated_text  # use the LLM-generated text as the prompt
+
     if language not in supported_languages:
         metrics_text = gr.Warning(
-            f"Language you put {language} in is not in is not in our Supported Languages, please choose from dropdown"
+            f"Language you put {language} in is not in our Supported Languages, please choose from dropdown"
        )
-
         return (None, metrics_text)
 
     speaker_wav = audio_file_pth
-
     if len(prompt) < 2:
         metrics_text = gr.Warning("Please give a longer prompt text")
         return (None, metrics_text)
 
-    # if len(prompt) > 250:
-    #     metrics_text = gr.Warning(
-    #         str(len(prompt))
-    #         + " characters.\n"
-    #         + "Your prompt is too long, please keep it under 250 characters\n"
-    #         + "Văn bản quá dài, vui lòng giữ dưới 250 ký tự."
-    #     )
-    #     return (None, metrics_text)
-
     try:
         metrics_text = ""
         t_latent = time.time()
-
         try:
             (
                 gpt_cond_latent,
@@ -133,7 +120,6 @@ def predict(
                 gpt_cond_chunk_len=4,
                 max_ref_length=60,
             )
-
         except Exception as e:
             print("Speaker encoding error", str(e))
             metrics_text = gr.Warning(
@@ -142,10 +128,8 @@ def predict(
             return (None, metrics_text)
 
         prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
-
         if normalize_text and language == "vi":
             prompt = normalize_vietnamese_text(prompt)
-
         print("I: Generating new audio...")
         t0 = time.time()
         out = MODEL.inference(
@@ -169,9 +153,7 @@ def predict(
         # Temporary hack for short sentences
         keep_len = calculate_keep_len(prompt, language)
         out["wav"] = out["wav"][:keep_len]
-
         torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
-
     except RuntimeError as e:
         if "device-side assert" in str(e):
             # cannot do anything on cuda device side error, need to restart
@@ -181,7 +163,6 @@ def predict(
             )
             gr.Warning("Unhandled Exception encounter, please retry in a minute")
             print("Cuda device-assert Runtime encountered need restart")
-
             error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
             error_data = [
                 error_time,
@@ -195,7 +176,6 @@ def predict(
             write_io = StringIO()
             csv.writer(write_io).writerows([error_data])
             csv_upload = write_io.getvalue().encode()
-
             filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
             print("Writing error csv")
             error_api = HfApi()
@@ -205,7 +185,6 @@ def predict(
                 repo_id="coqui/xtts-flagged-dataset",
                 repo_type="dataset",
             )
-
             # speaker_wav
             print("Writing error reference audio")
             speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
@@ -216,19 +195,17 @@ def predict(
                 repo_id="coqui/xtts-flagged-dataset",
                 repo_type="dataset",
             )
-
             # HF Space specific.. This error is unrecoverable need to restart space
             space = api.get_space_runtime(repo_id=repo_id)
             if space.stage != "BUILDING":
                 api.restart_space(repo_id=repo_id)
             else:
                 print("TRIED TO RESTART but space is building")
-
         else:
             if "Failed to decode" in str(e):
                 print("Speaker encoding error", str(e))
                 metrics_text = gr.Warning(
-                    metrics_text="It appears something wrong with reference, did you unmute your microphone?"
+                    "It appears something wrong with reference, did you unmute your microphone?"
                 )
             else:
                 print("RuntimeError: non device-side assert error:", str(e))
@@ -238,7 +215,7 @@ def predict(
         return (None, metrics_text)
     return ("output.wav", metrics_text)
 
-
+# Update the Gradio interface
 with gr.Blocks(analytics_enabled=False) as demo:
     with gr.Row():
         with gr.Column():
@@ -288,6 +265,16 @@ with gr.Blocks(analytics_enabled=False) as demo:
                 info="Normalize Vietnamese text",
                 value=True,
             )
+            use_llm_checkbox = gr.Checkbox(
+                label="Sử dụng LLM để tạo nội dung",
+                info="Use LLM to generate content",
+                value=False,
+            )
+            content_type_dropdown = gr.Dropdown(
+                label="Loại nội dung",
+                choices=["triết lý sống", "Theo yêu cầu"],
+                value="Theo yêu cầu",
+            )
             ref_gr = gr.Audio(
                 label="Reference Audio (Giọng mẫu)",
                 type="filepath",
@@ -311,6 +298,8 @@ with gr.Blocks(analytics_enabled=False) as demo:
                 language_gr,
                 ref_gr,
                 normalize_text,
+                use_llm_checkbox,  # checkbox to enable/disable the LLM
+                content_type_dropdown,  # dropdown to select the content type
             ],
             outputs=[audio_gr, out_text_gr],
             api_name="predict",
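The new code path depends on content_generation.create_content, which is not part of this commit, so its implementation is not shown here. Below is a minimal sketch of what such a module could look like; only the create_content(prompt, content_type, language) signature is taken from the call site in app.py, while the use of huggingface_hub.InferenceClient, the model name, and the prompt wording are assumptions for illustration.

# content_generation.py -- hypothetical sketch, not the module shipped with the Space.
# Only the function signature matches the call in app.py; the backend is an assumption.
from huggingface_hub import InferenceClient

# Assumed chat model served through the HF Inference API; the Space may use another backend.
_client = InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")


def create_content(prompt: str, content_type: str, language: str) -> str:
    """Turn a user request into TTS-ready text."""
    if content_type == "Theo yêu cầu":  # "on request": treat the prompt as the writing brief
        instruction = f"Write a short spoken-style passage in {language} based on this request: {prompt}"
    else:  # e.g. "triết lý sống" (life philosophy)
        instruction = f"Write a short spoken-style passage in {language} about '{content_type}', inspired by: {prompt}"
    response = _client.chat_completion(
        messages=[{"role": "user", "content": instruction}],
        max_tokens=300,
    )
    return response.choices[0].message.content.strip()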
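Once the Space rebuilds with this change, the predict endpoint takes two extra inputs after normalize_text. A hedged client-side example with gradio_client follows; the Space id is a placeholder, the first argument is assumed to be the prompt textbox as in the upstream demo, and whether the reference audio must be wrapped with handle_file depends on the installed gradio_client version.

# Sketch of a client call against the updated endpoint (placeholder Space id).
from gradio_client import Client, handle_file

client = Client("your-username/your-viXTTS-space")  # hypothetical Space id
audio_path, status = client.predict(
    "Hãy viết về ý nghĩa của sự kiên nhẫn",  # prompt, or the LLM request when use_llm is on
    "vi",                                     # language
    handle_file("reference.wav"),             # reference speaker audio (Giọng mẫu)
    True,                                     # normalize_text
    True,                                     # use_llm_checkbox: let the LLM write the text first
    "Theo yêu cầu",                           # content_type_dropdown
    api_name="/predict",
)
print(audio_path, status)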
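As an aside on the unchanged calculate_keep_len helper visible in the diff: its return value is used to slice out["wav"], so it is a sample count at the 24 kHz rate passed to torchaudio.save. A quick check of the short-sentence branch:

# For a 4-word prompt ending in a single period (word_count=4, num_punct=1):
keep_len = 15000 * 4 + 2000 * 1   # 62000 samples
print(keep_len / 24000)           # ~2.58 seconds of audio kept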