Update xtts.py
xtts.py
CHANGED
@@ -1,27 +1,12 @@
-import re,
-import tempfile, subprocess
+import re, os, logging, tempfile, subprocess
 import requests
 import torch
 import traceback
-import numpy as np
-import scipy
 from TTS.api import TTS
 
-import torch
 
+bLOCAL=not bool(os.environ.get('api'))
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
-from qili import upload, upload_bytes
-# def upload_bytes(bytes, ext=".wav"):
-#     return bytes
-# def upload(file):
-#     return file
-
-# if __name__ == "__main__":
-#     app = Flask(__name__)
-# else:
-#     app = Blueprint("xtts", __name__)
-
 tts=None
 model=None
 
@@ -34,6 +19,13 @@ if not os.path.exists(sample_root):
 default_sample=f'{os.path.dirname(os.path.abspath(__file__))}/sample.wav', f'{sample_root}/sample.pt'
 ffmpeg=f'{os.path.dirname(os.path.abspath(__file__))}/ffmpeg'
 
+if bLOCAL:
+    def upload(file):
+        return file
+else:
+    from qili import upload
+
+
 def predict(text, sample=None, language="zh"):
     get_tts()
     global tts
@@ -41,29 +33,14 @@ def predict(text, sample=None, language="zh"):
     try:
         text= re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)",r"\1 \2\2",text)
         output=tempfile.mktemp(suffix=".wav")
-
+        tts.tts_to_file(
             text,
             language=language if language is not None else "zh",
             speaker_wav=sample if sample is not None else default_sample[0],
             file_path=output
         )
-
         output=to_mp3(output)
-
         return upload(output)[0]
-
-        with io.BytesIO() as wav_buffer:
-            if torch.is_tensor(wav):
-                wav = wav.cpu().numpy()
-            if isinstance(wav, list):
-                wav = np.array(wav)
-            wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
-            wav_norm = wav_norm.astype(np.int16)
-            scipy.io.wavfile.write(wav_buffer, tts.synthesizer.output_sample_rate, wav_norm)
-            wav_bytes = wav_buffer.getvalue()
-            url= upload_bytes(wav_bytes, ext=".wav")
-            logging.debug(f'wav is at {url}')
-            return url
     except Exception as e:
         traceback.print_exc()
         return str(e)
@@ -131,6 +108,8 @@ def trim_sample_audio(speaker_wav):
             capture_output=False,
             text=True,
             check=True,
+            stdout=subprocess.DEVNULL,
+            stderr=subprocess.DEVNULL,
         )
         return out_filename
     except:
@@ -147,12 +126,18 @@ def to_mp3(wav):
             capture_output=False,
             text=True,
             check=True,
+            stdout=subprocess.DEVNULL,
+            stderr=subprocess.DEVNULL,
         )
         return mp3
     except:
         traceback.print_exc()
         return wav
-
+# if __name__ == "__main__":
+#     app = Flask(__name__)
+# else:
+#     app = Blueprint("xtts", __name__)
+
 
 from flask import Flask, request
 app = Flask(__name__)
@@ -161,17 +146,11 @@ def convert():
     text = request.args.get('text')
     sample = request.args.get('sample')
     language = request.args.get('language')
-    # from fastapi import FastAPI as App, Query
-    # app=App()
-    # @app.get("/url")
-    # def convert(text: str=Query(None), sample: str=Query(None), language: str=Query('zh')):
     if text is None:
         return 'text is missing', 400
 
     return predict(text, sample, language)
 
-# @app.get("/play")
-# def play(text: str=Query(None), sample: str=Query(None), language: str=Query('zh')):
 @app.route("/tts/play")
 def tts_play():
     url=convert()
@@ -183,17 +162,16 @@ def get_tts():
     global tts
     global model
     if tts is None:
-
-        model_path
-
-        vocoder_config_path=f'{model_dir}/vocab.json'
+        model_path=os.environ.get("MODEL_DIR")
+        config_path=f'{model_path}/config.json'
+        vocoder_config_path=f'{model_path}/vocab.json'
         model_name="tts_models/multilingual/multi-dataset/xtts_v2"
         logging.info(f"loading model {model_name} ...")
         tts = TTS(
-
-            model_path=model_path,
-            config_path=config_path,
-            vocoder_config_path=vocoder_config_path,
+            model_name if bLOCAL else None,
+            model_path=model_path if not bLOCAL else None,
+            config_path=config_path if not bLOCAL else None,
+            vocoder_config_path=vocoder_config_path if not bLOCAL else None,
             progress_bar=True
         )
         model=tts.synthesizer.tts_model
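
For reference, a minimal sketch of how the updated service might be exercised once it is running. The /tts/play route and its text/sample/language query parameters come from the diff above; the host, port, and example text are assumptions (Flask's development server listens on 127.0.0.1:5000 by default), so treat this as an illustration rather than the project's documented client.

import requests

# Hypothetical client call against a locally running instance of this module.
# "sample" (a speaker reference wav) is optional; predict() falls back to the
# bundled default_sample when it is omitted.
resp = requests.get(
    "http://127.0.0.1:5000/tts/play",
    params={"text": "你好，世界。", "language": "zh"},
    timeout=300,
)
print(resp.status_code)
print(resp.text[:200])  # body derives from convert()/predict(), i.e. the uploaded mp3 location

Note on the toggle introduced here: when the `api` environment variable is unset, bLOCAL is true, the model is fetched by name, and upload() simply returns the local file path, which suits local testing; when `api` is set, the model is loaded from MODEL_DIR and qili's upload pushes the mp3 out.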