lalalic committed on
Commit
1056749
·
verified ·
1 Parent(s): 49be9bd

Update xtts.py

Browse files
Files changed (1) hide show
  1. xtts.py +26 -48
xtts.py CHANGED
@@ -1,27 +1,12 @@
1
- import re, io, os, stat, logging
2
- import tempfile, subprocess
3
  import requests
4
  import torch
5
  import traceback
6
- import numpy as np
7
- import scipy
8
  from TTS.api import TTS
9
 
10
- import torch
11
 
 
12
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
13
-
14
- from qili import upload, upload_bytes
15
- # def upload_bytes(bytes, ext=".wav"):
16
- # return bytes
17
- # def upload(file):
18
- # return file
19
-
20
- # if __name__ == "__main__":
21
- # app = Flask(__name__)
22
- # else:
23
- # app = Blueprint("xtts", __name__)
24
-
25
  tts=None
26
  model=None
27
 
@@ -34,6 +19,13 @@ if not os.path.exists(sample_root):
34
  default_sample=f'{os.path.dirname(os.path.abspath(__file__))}/sample.wav', f'{sample_root}/sample.pt'
35
  ffmpeg=f'{os.path.dirname(os.path.abspath(__file__))}/ffmpeg'
36
 
 
 
 
 
 
 
 
37
  def predict(text, sample=None, language="zh"):
38
  get_tts()
39
  global tts
@@ -41,29 +33,14 @@ def predict(text, sample=None, language="zh"):
41
  try:
42
  text= re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)",r"\1 \2\2",text)
43
  output=tempfile.mktemp(suffix=".wav")
44
- wav = tts.tts_to_file(
45
  text,
46
  language=language if language is not None else "zh",
47
  speaker_wav=sample if sample is not None else default_sample[0],
48
  file_path=output
49
  )
50
-
51
  output=to_mp3(output)
52
-
53
  return upload(output)[0]
54
-
55
- with io.BytesIO() as wav_buffer:
56
- if torch.is_tensor(wav):
57
- wav = wav.cpu().numpy()
58
- if isinstance(wav, list):
59
- wav = np.array(wav)
60
- wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
61
- wav_norm = wav_norm.astype(np.int16)
62
- scipy.io.wavfile.write(wav_buffer, tts.synthesizer.output_sample_rate, wav_norm)
63
- wav_bytes = wav_buffer.getvalue()
64
- url= upload_bytes(wav_bytes, ext=".wav")
65
- logging.debug(f'wav is at {url}')
66
- return url
67
  except Exception as e:
68
  traceback.print_exc()
69
  return str(e)
@@ -131,6 +108,8 @@ def trim_sample_audio(speaker_wav):
131
  capture_output=False,
132
  text=True,
133
  check=True,
 
 
134
  )
135
  return out_filename
136
  except:
@@ -147,12 +126,18 @@ def to_mp3(wav):
147
  capture_output=False,
148
  text=True,
149
  check=True,
 
 
150
  )
151
  return mp3
152
  except:
153
  traceback.print_exc()
154
  return wav
155
-
 
 
 
 
156
 
157
  from flask import Flask, request
158
  app = Flask(__name__)
@@ -161,17 +146,11 @@ def convert():
161
  text = request.args.get('text')
162
  sample = request.args.get('sample')
163
  language = request.args.get('language')
164
- # from fastapi import FastAPI as App, Query
165
- # app=App()
166
- # @app.get("/url")
167
- # def convert(text: str=Query(None), sample: str=Query(None), language: str=Query('zh')):
168
  if text is None:
169
  return 'text is missing', 400
170
 
171
  return predict(text, sample, language)
172
 
173
- # @app.get("/play")
174
- # def play(text: str=Query(None), sample: str=Query(None), language: str=Query('zh')):
175
  @app.route("/tts/play")
176
  def tts_play():
177
  url=convert()
@@ -183,17 +162,16 @@ def get_tts():
183
  global tts
184
  global model
185
  if tts is None:
186
- model_dir=os.environ.get("MODEL_DIR")
187
- model_path=model_dir
188
- config_path=f'{model_dir}/config.json'
189
- vocoder_config_path=f'{model_dir}/vocab.json'
190
  model_name="tts_models/multilingual/multi-dataset/xtts_v2"
191
  logging.info(f"loading model {model_name} ...")
192
  tts = TTS(
193
- # model_name,
194
- model_path=model_path,
195
- config_path=config_path,
196
- vocoder_config_path=vocoder_config_path,
197
  progress_bar=True
198
  )
199
  model=tts.synthesizer.tts_model
 
1
+ import re, os, logging, tempfile, subprocess
 
2
  import requests
3
  import torch
4
  import traceback
 
 
5
  from TTS.api import TTS
6
 
 
7
 
8
+ bLOCAL=not bool(os.environ.get('api'))
9
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 
 
 
 
 
 
 
 
 
 
 
10
  tts=None
11
  model=None
12
 
 
19
  default_sample=f'{os.path.dirname(os.path.abspath(__file__))}/sample.wav', f'{sample_root}/sample.pt'
20
  ffmpeg=f'{os.path.dirname(os.path.abspath(__file__))}/ffmpeg'
21
 
22
+ if bLOCAL:
23
+ def upload(file):
24
+ return file
25
+ else:
26
+ from qili import upload
27
+
28
+
29
  def predict(text, sample=None, language="zh"):
30
  get_tts()
31
  global tts
 
33
  try:
34
  text= re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)",r"\1 \2\2",text)
35
  output=tempfile.mktemp(suffix=".wav")
36
+ tts.tts_to_file(
37
  text,
38
  language=language if language is not None else "zh",
39
  speaker_wav=sample if sample is not None else default_sample[0],
40
  file_path=output
41
  )
 
42
  output=to_mp3(output)
 
43
  return upload(output)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  except Exception as e:
45
  traceback.print_exc()
46
  return str(e)
 
108
  capture_output=False,
109
  text=True,
110
  check=True,
111
+ stdout=subprocess.DEVNULL,
112
+ stderr=subprocess.DEVNULL,
113
  )
114
  return out_filename
115
  except:
 
126
  capture_output=False,
127
  text=True,
128
  check=True,
129
+ stdout=subprocess.DEVNULL,
130
+ stderr=subprocess.DEVNULL,
131
  )
132
  return mp3
133
  except:
134
  traceback.print_exc()
135
  return wav
136
+ # if __name__ == "__main__":
137
+ # app = Flask(__name__)
138
+ # else:
139
+ # app = Blueprint("xtts", __name__)
140
+
141
 
142
  from flask import Flask, request
143
  app = Flask(__name__)
 
146
  text = request.args.get('text')
147
  sample = request.args.get('sample')
148
  language = request.args.get('language')
 
 
 
 
149
  if text is None:
150
  return 'text is missing', 400
151
 
152
  return predict(text, sample, language)
153
 
 
 
154
  @app.route("/tts/play")
155
  def tts_play():
156
  url=convert()
 
162
  global tts
163
  global model
164
  if tts is None:
165
+ model_path=os.environ.get("MODEL_DIR")
166
+ config_path=f'{model_path}/config.json'
167
+ vocoder_config_path=f'{model_path}/vocab.json'
 
168
  model_name="tts_models/multilingual/multi-dataset/xtts_v2"
169
  logging.info(f"loading model {model_name} ...")
170
  tts = TTS(
171
+ model_name if bLOCAL else None,
172
+ model_path=model_path if not bLOCAL else None,
173
+ config_path=config_path if not bLOCAL else None,
174
+ vocoder_config_path=vocoder_config_path if not bLOCAL else None,
175
  progress_bar=True
176
  )
177
  model=tts.synthesizer.tts_model