TDN-M commited on
Commit
37fb699
·
verified ·
1 Parent(s): 8554efb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -43
app.py CHANGED
@@ -2,9 +2,10 @@ import csv
2
  import datetime
3
  import os
4
  import re
 
5
  import time
6
  import uuid
7
- from io import StringIO
8
  import gradio as gr
9
  import spaces
10
  import torch
@@ -14,13 +15,20 @@ from TTS.tts.configs.xtts_config import XttsConfig
14
  from TTS.tts.models.xtts import Xtts
15
  from vinorm import TTSnorm
16
  from content_generation import create_content # Nhập hàm create_content từ file content_generation.py
 
 
 
 
 
17
 
18
- # download for mecab
19
  os.system("python -m unidic download")
 
 
20
  HF_TOKEN = os.environ.get("HF_TOKEN")
21
  api = HfApi(token=HF_TOKEN)
22
 
23
- # This will trigger downloading model
24
  print("Downloading if not downloaded viXTTS")
25
  checkpoint_dir = "model/"
26
  repo_id = "capleaf/viXTTS"
@@ -39,6 +47,7 @@ if not all(file in files_in_dir for file in required_files):
39
  filename="speakers_xtts.pth",
40
  local_dir=checkpoint_dir,
41
  )
 
42
  xtts_config = os.path.join(checkpoint_dir, "config.json")
43
  config = XttsConfig()
44
  config.load_json(xtts_config)
@@ -48,10 +57,12 @@ MODEL.load_checkpoint(
48
  )
49
  if torch.cuda.is_available():
50
  MODEL.cuda()
 
51
  supported_languages = config.languages
52
- if not "vi" in supported_languages:
53
  supported_languages.append("vi")
54
 
 
55
  def normalize_vietnamese_text(text):
56
  text = (
57
  TTSnorm(text, unknown=False, lower=False, rule=True)
@@ -68,6 +79,7 @@ def normalize_vietnamese_text(text):
68
  )
69
  return text
70
 
 
71
  def calculate_keep_len(text, lang):
72
  """Simple hack for short sentences"""
73
  if lang in ["ja", "zh-cn"]:
@@ -80,33 +92,166 @@ def calculate_keep_len(text, lang):
80
  return 13000 * word_count + 2000 * num_punct
81
  return -1
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  @spaces.GPU
84
  def predict(
85
  prompt,
86
  language,
87
  audio_file_pth,
88
  normalize_text=True,
89
- use_llm=False, # Thêm tùy chọn sử dụng LLM
90
- content_type="Theo yêu cầu", # Loại nội dung (ví dụ: "triết lý sống" hoặc "Theo yêu cầu")
91
  ):
92
  if use_llm:
93
- # Nếu sử dụng LLM, tạo nội dung văn bản từ đầu vào
94
  print("I: Generating text with LLM...")
95
  generated_text = create_content(prompt, content_type, language)
96
  print(f"Generated text: {generated_text}")
97
- prompt = generated_text # Gán văn bản được tạo bởi LLM vào biến prompt
98
-
99
  if language not in supported_languages:
100
  metrics_text = gr.Warning(
101
  f"Language you put {language} in is not in our Supported Languages, please choose from dropdown"
102
  )
103
  return (None, metrics_text)
104
-
105
  speaker_wav = audio_file_pth
106
  if len(prompt) < 2:
107
  metrics_text = gr.Warning("Please give a longer prompt text")
108
  return (None, metrics_text)
109
-
110
  try:
111
  metrics_text = ""
112
  t_latent = time.time()
@@ -126,7 +271,6 @@ def predict(
126
  "It appears something wrong with reference, did you unmute your microphone?"
127
  )
128
  return (None, metrics_text)
129
-
130
  prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
131
  if normalize_text and language == "vi":
132
  prompt = normalize_vietnamese_text(prompt)
@@ -149,14 +293,11 @@ def predict(
149
  real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
150
  print(f"Real-time factor (RTF): {real_time_factor}")
151
  metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"
152
-
153
- # Temporary hack for short sentences
154
  keep_len = calculate_keep_len(prompt, language)
155
  out["wav"] = out["wav"][:keep_len]
156
  torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
157
  except RuntimeError as e:
158
  if "device-side assert" in str(e):
159
- # cannot do anything on cuda device side error, need to restart
160
  print(
161
  f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
162
  flush=True,
@@ -185,8 +326,6 @@ def predict(
185
  repo_id="coqui/xtts-flagged-dataset",
186
  repo_type="dataset",
187
  )
188
- # speaker_wav
189
- print("Writing error reference audio")
190
  speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
191
  error_api = HfApi()
192
  error_api.upload_file(
@@ -195,7 +334,6 @@ def predict(
195
  repo_id="coqui/xtts-flagged-dataset",
196
  repo_type="dataset",
197
  )
198
- # HF Space specific.. This error is unrecoverable need to restart space
199
  space = api.get_space_runtime(repo_id=repo_id)
200
  if space.stage != "BUILDING":
201
  api.restart_space(repo_id=repo_id)
@@ -215,7 +353,7 @@ def predict(
215
  return (None, metrics_text)
216
  return ("output.wav", metrics_text)
217
 
218
- # Cập nhật giao diện Gradio
219
  with gr.Blocks(analytics_enabled=False) as demo:
220
  with gr.Row():
221
  with gr.Column():
@@ -225,9 +363,8 @@ with gr.Blocks(analytics_enabled=False) as demo:
225
  """
226
  )
227
  with gr.Column():
228
- # placeholder to align the image
229
  pass
230
-
231
  with gr.Row():
232
  with gr.Column():
233
  input_text_gr = gr.Textbox(
@@ -238,24 +375,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
238
  language_gr = gr.Dropdown(
239
  label="Language (Ngôn ngữ)",
240
  choices=[
241
- "vi",
242
- "en",
243
- "es",
244
- "fr",
245
- "de",
246
- "it",
247
- "pt",
248
- "pl",
249
- "tr",
250
- "ru",
251
- "nl",
252
- "cs",
253
- "ar",
254
- "zh-cn",
255
- "ja",
256
- "ko",
257
- "hu",
258
- "hi",
259
  ],
260
  max_choices=1,
261
  value="vi",
@@ -286,11 +406,14 @@ with gr.Blocks(analytics_enabled=False) as demo:
286
  visible=True,
287
  variant="primary",
288
  )
289
-
290
  with gr.Column():
291
  audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
292
  out_text_gr = gr.Text(label="Metrics")
293
-
 
 
 
294
  tts_button.click(
295
  predict,
296
  [
@@ -298,11 +421,20 @@ with gr.Blocks(analytics_enabled=False) as demo:
298
  language_gr,
299
  ref_gr,
300
  normalize_text,
301
- use_llm_checkbox, # Thêm checkbox để bật/tắt LLM
302
- content_type_dropdown, # Thêm dropdown để chọn loại nội dung
303
  ],
304
  outputs=[audio_gr, out_text_gr],
305
  api_name="predict",
 
 
 
 
 
 
 
 
 
306
  )
307
 
308
  demo.queue()
 
2
  import datetime
3
  import os
4
  import re
5
+ import subprocess
6
  import time
7
  import uuid
8
+ from io import BytesIO, StringIO
9
  import gradio as gr
10
  import spaces
11
  import torch
 
15
  from TTS.tts.models.xtts import Xtts
16
  from vinorm import TTSnorm
17
  from content_generation import create_content # Nhập hàm create_content từ file content_generation.py
18
+ from PIL import Image
19
+ from pathlib import Path
20
+ import requests
21
+ import json
22
+ import hashlib
23
 
24
+ # Download for mecab
25
  os.system("python -m unidic download")
26
+
27
+ # Cấu hình API và mô hình
28
  HF_TOKEN = os.environ.get("HF_TOKEN")
29
  api = HfApi(token=HF_TOKEN)
30
 
31
+ # Tải hình viXTTS
32
  print("Downloading if not downloaded viXTTS")
33
  checkpoint_dir = "model/"
34
  repo_id = "capleaf/viXTTS"
 
47
  filename="speakers_xtts.pth",
48
  local_dir=checkpoint_dir,
49
  )
50
+
51
  xtts_config = os.path.join(checkpoint_dir, "config.json")
52
  config = XttsConfig()
53
  config.load_json(xtts_config)
 
57
  )
58
  if torch.cuda.is_available():
59
  MODEL.cuda()
60
+
61
  supported_languages = config.languages
62
+ if "vi" not in supported_languages:
63
  supported_languages.append("vi")
64
 
65
+ # Hàm chuẩn hóa văn bản tiếng Việt
66
  def normalize_vietnamese_text(text):
67
  text = (
68
  TTSnorm(text, unknown=False, lower=False, rule=True)
 
79
  )
80
  return text
81
 
82
+ # Hàm tính toán độ dài giữ lại cho audio ngắn
83
  def calculate_keep_len(text, lang):
84
  """Simple hack for short sentences"""
85
  if lang in ["ja", "zh-cn"]:
 
92
  return 13000 * word_count + 2000 * num_punct
93
  return -1
94
 
95
+ # Hàm tạo mô tả ảnh từ nội dung audio
96
+ def generate_image_description(prompt):
97
+ return f"A visual representation of: {prompt}"
98
+
99
+ # Hàm gọi API tạo ảnh
100
+ def txt2img(prompt, width, height):
101
+ model_id = "770694094415489962" # Model ID cố định
102
+ vae_id = "sdxl-vae-fp16-fix.safetensors" # VAE cố định
103
+ lora_items = [
104
+ {"loraModel": "766419665653268679", "weight": 0.7},
105
+ {"loraModel": "777630084346589138", "weight": 0.7},
106
+ {"loraModel": "776587863287492519", "weight": 0.7}
107
+ ]
108
+ txt2img_data = {
109
+ "request_id": hashlib.md5(str(int(time.time())).encode()).hexdigest(),
110
+ "stages": [
111
+ {
112
+ "type": "INPUT_INITIALIZE",
113
+ "inputInitialize": {
114
+ "seed": -1,
115
+ "count": 1
116
+ }
117
+ },
118
+ {
119
+ "type": "DIFFUSION",
120
+ "diffusion": {
121
+ "width": width,
122
+ "height": height,
123
+ "prompts": [
124
+ {
125
+ "text": prompt
126
+ }
127
+ ],
128
+ "negativePrompts": [
129
+ {
130
+ "text": "nsfw"
131
+ }
132
+ ],
133
+ "sdModel": model_id,
134
+ "sdVae": vae_id,
135
+ "sampler": "Euler a",
136
+ "steps": 20,
137
+ "cfgScale": 3,
138
+ "clipSkip": 1,
139
+ "etaNoiseSeedDelta": 31337,
140
+ "lora": {
141
+ "items": lora_items
142
+ }
143
+ }
144
+ }
145
+ ]
146
+ }
147
+ body = json.dumps(txt2img_data)
148
+ headers = {
149
+ 'Content-Type': 'application/json',
150
+ 'Accept': 'application/json',
151
+ 'Authorization': f'Bearer {os.getenv("api_key_token")}'
152
+ }
153
+ response = requests.post(f"https://ap-east-1.tensorart.cloud/v1/jobs", json=txt2img_data, headers=headers)
154
+ if response.status_code != 200:
155
+ return f"Error: {response.status_code} - {response.text}"
156
+ response_data = response.json()
157
+ job_id = response_data['job']['id']
158
+ print(f"Job created. ID: {job_id}")
159
+ start_time = time.time()
160
+ timeout = 300 # Giới hạn thời gian chờ là 300 giây (5 phút)
161
+ while True:
162
+ time.sleep(10)
163
+ elapsed_time = time.time() - start_time
164
+ if elapsed_time > timeout:
165
+ return f"Error: Job timed out after {timeout} seconds."
166
+ response = requests.get(f"https://ap-east-1.tensorart.cloud/v1/jobs/{job_id}", headers=headers)
167
+ if response.status_code != 200:
168
+ return f"Error: {response.status_code} - {response.text}"
169
+ get_job_response_data = response.json()
170
+ job_status = get_job_response_data['job']['status']
171
+ print(f"Job status: {job_status}")
172
+ if job_status == 'SUCCESS':
173
+ if 'successInfo' in get_job_response_data['job']:
174
+ image_url = get_job_response_data['job']['successInfo']['images'][0]['url']
175
+ print(f"Job succeeded. Image URL: {image_url}")
176
+ response_image = requests.get(image_url)
177
+ img = Image.open(BytesIO(response_image.content))
178
+ return img
179
+ else:
180
+ return "Error: Output is missing in the job response."
181
+ elif job_status == 'FAILED':
182
+ return "Error: Job failed. Please try again with different settings."
183
+
184
+ # Hàm tạo video từ ảnh và audio
185
+ def create_video(image_path, audio_path, output_path):
186
+ command = [
187
+ "ffmpeg",
188
+ "-i", image_path,
189
+ "-i", audio_path,
190
+ "-filter_complex",
191
+ "[1:a]aformat=channel_layouts=mono,showwaves=s=800x250:mode=line:[email protected][w];[0:v][w]overlay=(W-w)/2:(H-h)/2",
192
+ "-c:v", "libx264",
193
+ "-c:a", "aac",
194
+ "-y", output_path
195
+ ]
196
+ subprocess.run(command, check=True)
197
+
198
+ # Hàm xử lý sự kiện khi nhấn nút "Tạo Video"
199
+ def generate_video(audio_file, prompt):
200
+ if not os.path.exists(audio_file):
201
+ return None, "Audio file not found. Please generate audio first."
202
+
203
+ # Bước 1: Tạo mô tả ảnh
204
+ image_description = generate_image_description(prompt)
205
+
206
+ # Bước 2: Gọi API tạo ảnh
207
+ try:
208
+ image = txt2img(image_description, width=800, height=600)
209
+ if isinstance(image, str): # Nếu có lỗi từ API
210
+ return None, image
211
+
212
+ # Lưu ảnh vào thư mục
213
+ image_path = os.path.join(SAVE_DIR, "generated_image.png")
214
+ image.save(image_path)
215
+ except Exception as e:
216
+ return None, f"Error generating image: {str(e)}"
217
+
218
+ # Bước 3: Tạo video từ ảnh và audio
219
+ video_output_path = os.path.join(SAVE_DIR, "output_video.mp4")
220
+ try:
221
+ create_video(image_path, audio_file, video_output_path)
222
+ except Exception as e:
223
+ return None, f"Error creating video: {str(e)}"
224
+
225
+ return video_output_path, "Video created successfully!"
226
+
227
+ # Thư mục lưu trữ ảnh và video
228
+ SAVE_DIR = "generated_images"
229
+ Path(SAVE_DIR).mkdir(exist_ok=True)
230
+
231
+ # Hàm dự đoán và tạo audio
232
  @spaces.GPU
233
  def predict(
234
  prompt,
235
  language,
236
  audio_file_pth,
237
  normalize_text=True,
238
+ use_llm=False,
239
+ content_type="Theo yêu cầu",
240
  ):
241
  if use_llm:
 
242
  print("I: Generating text with LLM...")
243
  generated_text = create_content(prompt, content_type, language)
244
  print(f"Generated text: {generated_text}")
245
+ prompt = generated_text
 
246
  if language not in supported_languages:
247
  metrics_text = gr.Warning(
248
  f"Language you put {language} in is not in our Supported Languages, please choose from dropdown"
249
  )
250
  return (None, metrics_text)
 
251
  speaker_wav = audio_file_pth
252
  if len(prompt) < 2:
253
  metrics_text = gr.Warning("Please give a longer prompt text")
254
  return (None, metrics_text)
 
255
  try:
256
  metrics_text = ""
257
  t_latent = time.time()
 
271
  "It appears something wrong with reference, did you unmute your microphone?"
272
  )
273
  return (None, metrics_text)
 
274
  prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
275
  if normalize_text and language == "vi":
276
  prompt = normalize_vietnamese_text(prompt)
 
293
  real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
294
  print(f"Real-time factor (RTF): {real_time_factor}")
295
  metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"
 
 
296
  keep_len = calculate_keep_len(prompt, language)
297
  out["wav"] = out["wav"][:keep_len]
298
  torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
299
  except RuntimeError as e:
300
  if "device-side assert" in str(e):
 
301
  print(
302
  f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
303
  flush=True,
 
326
  repo_id="coqui/xtts-flagged-dataset",
327
  repo_type="dataset",
328
  )
 
 
329
  speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
330
  error_api = HfApi()
331
  error_api.upload_file(
 
334
  repo_id="coqui/xtts-flagged-dataset",
335
  repo_type="dataset",
336
  )
 
337
  space = api.get_space_runtime(repo_id=repo_id)
338
  if space.stage != "BUILDING":
339
  api.restart_space(repo_id=repo_id)
 
353
  return (None, metrics_text)
354
  return ("output.wav", metrics_text)
355
 
356
+ # Giao diện Gradio
357
  with gr.Blocks(analytics_enabled=False) as demo:
358
  with gr.Row():
359
  with gr.Column():
 
363
  """
364
  )
365
  with gr.Column():
 
366
  pass
367
+
368
  with gr.Row():
369
  with gr.Column():
370
  input_text_gr = gr.Textbox(
 
375
  language_gr = gr.Dropdown(
376
  label="Language (Ngôn ngữ)",
377
  choices=[
378
+ "vi", "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja", "ko", "hu", "hi",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  ],
380
  max_choices=1,
381
  value="vi",
 
406
  visible=True,
407
  variant="primary",
408
  )
409
+
410
  with gr.Column():
411
  audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
412
  out_text_gr = gr.Text(label="Metrics")
413
+ video_button = gr.Button("Tạo Video 🎥", visible=False)
414
+ video_output = gr.Video(label="Generated Video", visible=False)
415
+ video_status = gr.Text(label="Video Status")
416
+
417
  tts_button.click(
418
  predict,
419
  [
 
421
  language_gr,
422
  ref_gr,
423
  normalize_text,
424
+ use_llm_checkbox,
425
+ content_type_dropdown,
426
  ],
427
  outputs=[audio_gr, out_text_gr],
428
  api_name="predict",
429
+ ).then(
430
+ lambda: [gr.update(visible=True)],
431
+ outputs=[video_button]
432
+ )
433
+
434
+ video_button.click(
435
+ generate_video,
436
+ inputs=[audio_gr, input_text_gr],
437
+ outputs=[video_output, video_status],
438
  )
439
 
440
  demo.queue()