Update app.py
Browse files
app.py
CHANGED
@@ -3,47 +3,59 @@
|
|
3 |
@author:XuMing([email protected])
|
4 |
@description: Re-train by TWMAN
|
5 |
"""
|
|
|
6 |
import hashlib
|
7 |
import os
|
8 |
import ssl
|
|
|
9 |
|
10 |
import gradio as gr
|
11 |
import torch
|
12 |
from loguru import logger
|
13 |
-
import
|
14 |
|
|
|
15 |
ssl._create_default_https_context = ssl._create_unverified_context
|
16 |
-
import nltk
|
17 |
|
18 |
-
|
|
|
19 |
try:
|
20 |
-
subprocess.check_call([
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
22 |
except subprocess.CalledProcessError:
|
23 |
-
print("
|
24 |
|
25 |
-
|
26 |
-
upgrade_LangSegment()
|
27 |
|
28 |
-
#
|
29 |
# Fetch the NLTK resources whose archives are not already on disk.
nltk_data_path = os.path.expanduser('~/nltk_data')
# (resource id, archive path relative to the NLTK data directory)
for _resource, _archive in (
    ('cmudict', 'corpora/cmudict.zip'),
    ('averaged_perceptron_tagger', 'taggers/averaged_perceptron_tagger.zip'),
):
    if os.path.exists(os.path.join(nltk_data_path, _archive)):
        continue  # archive already present, nothing to download
    nltk.download(_resource, download_dir=nltk_data_path)
|
35 |
|
|
|
36 |
from parrots import TextToSpeech
|
37 |
|
38 |
-
#
|
39 |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"device: {device}")
# True only when running on CUDA (the comparison already yields a bool).
half = device == "cuda"
|
42 |
|
43 |
-
#
|
44 |
-
m = TextToSpeech(
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
-
#
|
47 |
def get_text_hash(text: str):
    """Return the hexadecimal MD5 digest of ``text`` encoded as UTF-8."""
    digest = hashlib.md5()
    digest.update(text.encode('utf-8'))
    return digest.hexdigest()
|
49 |
|
@@ -54,7 +66,7 @@ def do_tts_wav_predict(text: str, output_path: str = None):
|
|
54 |
m.predict(text, text_language="auto", output_path=output_path)
|
55 |
return output_path
|
56 |
|
57 |
-
#
|
58 |
with gr.Blocks(title="TTS WebUI") as app:
|
59 |
gr.Markdown("""
|
60 |
# 線上語音合成 (TWMAN)
|
@@ -81,22 +93,25 @@ with gr.Blocks(title="TTS WebUI") as app:
|
|
81 |
- [用PPOCRLabel來幫PaddleOCR做OCR的微調和標註](https://blog.twman.org/2023/07/wsl.html)
|
82 |
- [基於機器閱讀理解和指令微調的統一信息抽取框架之診斷書醫囑資訊擷取分析](https://blog.twman.org/2023/07/HugIE.html)
|
83 |
""")
|
84 |
-
|
85 |
-
# 設定語音合成輸入與按鈕
|
86 |
with gr.Group():
|
87 |
-
gr.Markdown("
|
88 |
with gr.Row():
|
89 |
-
text = gr.Textbox(
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
-
# 設定按鈕點擊事件
|
94 |
inference_button.click(
|
95 |
do_tts_wav_predict,
|
96 |
[text],
|
97 |
[output],
|
98 |
)
|
99 |
|
100 |
-
# Start Gradio: enable the request queue (max_size=10), then launch with a
# share link and open the UI in a browser.
app.queue(max_size=10).launch(share=True, inbrowser=True)
|
|
|
3 |
@author:XuMing([email protected])
|
4 |
@description: Re-train by TWMAN
|
5 |
"""
|
6 |
+
|
7 |
import hashlib
|
8 |
import os
|
9 |
import ssl
|
10 |
+
import subprocess
|
11 |
|
12 |
import gradio as gr
|
13 |
import torch
|
14 |
from loguru import logger
|
15 |
+
import nltk
|
16 |
|
17 |
+
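# Replace the default HTTPS context factory with the unverified one to avoid
# certificate errors. NOTE: this turns off TLS certificate verification
# process-wide for everything that uses the default context.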
|
18 |
ssl._create_default_https_context = ssl._create_unverified_context
|
|
|
19 |
|
20 |
+
# Install a compatible LangSegment release (pinned to v0.1.5).
def install_compatible_LangSegment():
    """Force-reinstall LangSegment 0.1.5 with pip in the running interpreter.

    Runs ``<python> -m pip install LangSegment==0.1.5 -i
    https://pypi.org/simple --force-reinstall`` and prints a success or
    failure message. A non-zero pip exit status is reported but not
    re-raised, so the caller continues either way.
    """
    # Local import: use the documented ``sys.executable`` rather than
    # reaching through ``os.sys``, which is not a public attribute of ``os``.
    import sys

    pip_command = [
        sys.executable, "-m", "pip",
        "install", "LangSegment==0.1.5",
        "-i", "https://pypi.org/simple",
        "--force-reinstall",
    ]
    try:
        # Only the call that can raise CalledProcessError stays in the try.
        subprocess.check_call(pip_command)
    except subprocess.CalledProcessError:
        print("❌ LangSegment 降級失敗")
    else:
        print("✅ LangSegment 降級成功")
|
32 |
|
33 |
+
install_compatible_LangSegment()
|
|
|
34 |
|
35 |
+
# Download the NLTK resources whose archives are not already on disk.
nltk_data_path = os.path.expanduser('~/nltk_data')
# (resource id, archive path relative to the NLTK data directory)
for _resource, _archive in (
    ('cmudict', 'corpora/cmudict.zip'),
    ('averaged_perceptron_tagger', 'taggers/averaged_perceptron_tagger.zip'),
):
    if os.path.exists(os.path.join(nltk_data_path, _archive)):
        continue  # archive already present, nothing to download
    nltk.download(_resource, download_dir=nltk_data_path)
|
41 |
|
42 |
+
# Import the TextToSpeech class from parrots
|
43 |
from parrots import TextToSpeech
|
44 |
|
45 |
+
# Select the compute device and precision.
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"device: {device}")
# Forwarded to TextToSpeech as ``half``; True only on CUDA (the comparison
# already yields a bool, so no conditional expression is needed).
half = device == "cuda"
|
49 |
|
50 |
+
# Initialise the TTS model.
# Keyword arguments are collected first, then unpacked into the constructor.
_tts_kwargs = {
    "speaker_model_path": "DeepLearning101/GPT-SoVITS_TWMAN",
    "speaker_name": "TWMAN",
    "device": device,
    "half": half,
}
m = TextToSpeech(**_tts_kwargs)
|
57 |
|
58 |
+
# Audio generation logic
def get_text_hash(text: str):
    """Return the hexadecimal MD5 digest of ``text`` encoded as UTF-8."""
    digest = hashlib.md5()
    digest.update(text.encode('utf-8'))
    return digest.hexdigest()
|
61 |
|
|
|
66 |
m.predict(text, text_language="auto", output_path=output_path)
|
67 |
return output_path
|
68 |
|
69 |
+
# Gradio web UI definition
|
70 |
with gr.Blocks(title="TTS WebUI") as app:
|
71 |
gr.Markdown("""
|
72 |
# 線上語音合成 (TWMAN)
|
|
|
93 |
- [用PPOCRLabel來幫PaddleOCR做OCR的微調和標註](https://blog.twman.org/2023/07/wsl.html)
|
94 |
- [基於機器閱讀理解和指令微調的統一信息抽取框架之診斷書醫囑資訊擷取分析](https://blog.twman.org/2023/07/HugIE.html)
|
95 |
""")
|
96 |
+
|
|
|
97 |
with gr.Group():
|
98 |
+
gr.Markdown("🔤 請輸入要進行語音合成的文字:")
|
99 |
with gr.Row():
|
100 |
+
text = gr.Textbox(
|
101 |
+
label="輸入文字(建議 100 字內)",
|
102 |
+
value="床前明月光,疑是地上霜。舉頭望明月,低頭思故鄉。",
|
103 |
+
placeholder="請輸入文字...",
|
104 |
+
lines=3
|
105 |
+
)
|
106 |
+
inference_button = gr.Button("🎤 語音合成", variant="primary")
|
107 |
+
output = gr.Audio(label="🔊 合成的語音")
|
108 |
|
|
|
109 |
inference_button.click(
|
110 |
do_tts_wav_predict,
|
111 |
[text],
|
112 |
[output],
|
113 |
)
|
114 |
|
115 |
+
# Start the Gradio app: enable the request queue (max_size=10), then launch
# with a share link and open the UI in a browser.
app.queue(max_size=10).launch(share=True, inbrowser=True)
|