Yilin0601 commited on
Commit
e947b77
·
verified ·
1 Parent(s): 799659c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -42
app.py CHANGED
@@ -9,11 +9,11 @@ import os
9
  from transformers import pipeline, VitsModel, AutoTokenizer
10
  from datasets import load_dataset
11
 
12
- # For Coqui TTS (XTTS-v2)
13
  try:
14
- from TTS.api import TTS as CoquiTTS
15
  except ImportError:
16
- raise ImportError("Please install Coqui TTS via `pip install TTS`.")
17
 
18
  # ------------------------------------------------------
19
  # 1. ASR Pipeline (English) using Wav2Vec2
@@ -51,7 +51,7 @@ translation_tasks = {
51
  # ------------------------------------------------------
52
  # 3. TTS Configuration
53
  # - MMS TTS (VITS) for: Spanish, Vietnamese, Indonesian, Turkish, Portuguese, Korean
54
- # - Coqui XTTS-v2 for: Chinese and Japanese
55
  # ------------------------------------------------------
56
  tts_config = {
57
  "Spanish": {"model_id": "facebook/mms-tts-spa", "architecture": "vits", "type": "mms"},
@@ -60,14 +60,8 @@ tts_config = {
60
  "Turkish": {"model_id": "facebook/mms-tts-tur", "architecture": "vits", "type": "mms"},
61
  "Portuguese": {"model_id": "facebook/mms-tts-por", "architecture": "vits", "type": "mms"},
62
  "Korean": {"model_id": "facebook/mms-tts-kor", "architecture": "vits", "type": "mms"},
63
- "Chinese": {"type": "coqui"},
64
- "Japanese": {"type": "coqui"}
65
- }
66
-
67
- # For Coqui, we map our languages to language codes expected by the model.
68
- coqui_lang_map = {
69
- "Chinese": "zh",
70
- "Japanese": "ja"
71
  }
72
 
73
  # ------------------------------------------------------
@@ -75,7 +69,7 @@ coqui_lang_map = {
75
  # ------------------------------------------------------
76
  translator_cache = {}
77
  mms_tts_cache = {} # For MMS (VITS-based) TTS models
78
- coqui_tts_cache = None # Single instance for Coqui XTTS-v2
79
 
80
  # ------------------------------------------------------
81
  # 5. Translator Helper
@@ -116,31 +110,31 @@ def run_mms_tts(text, lang):
116
  return sample_rate, waveform
117
 
118
  # ------------------------------------------------------
119
- # 7. Coqui TTS Helper for Chinese and Japanese
120
  # ------------------------------------------------------
121
- def load_coqui_tts():
122
- global coqui_tts_cache
123
- if coqui_tts_cache is not None:
124
- return coqui_tts_cache
125
- try:
126
- # Set gpu=True if a GPU is available.
127
- coqui_tts_cache = CoquiTTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=False)
128
- except Exception as e:
129
- raise RuntimeError(f"Failed to load Coqui XTTS-v2 TTS: {e}")
130
- return coqui_tts_cache
131
-
132
- def run_coqui_tts(text, lang):
133
- coqui_tts = load_coqui_tts()
134
- lang_code = coqui_lang_map[lang] # "zh" for Chinese or "ja" for Japanese
135
- # Write the output to a temporary file and then read it back.
 
 
 
 
136
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
137
  tmp_name = tmp.name
138
  try:
139
- coqui_tts.tts_to_file(
140
- text=text,
141
- file_path=tmp_name,
142
- language=lang_code # using default voice; for cloning, add speaker_wav parameter
143
- )
144
  data, sr = sf.read(tmp_name)
145
  finally:
146
  if os.path.exists(tmp_name):
@@ -153,8 +147,8 @@ def run_coqui_tts(text, lang):
153
  def predict(audio, text, target_language):
154
  """
155
  1. Obtain English text (via ASR if audio provided, else text).
156
- 2. Translate English text to target_language.
157
- 3. Generate TTS audio using either MMS TTS (VITS) or Coqui XTTS-v2.
158
  """
159
  # Step 1: Get English text.
160
  if text.strip():
@@ -186,8 +180,8 @@ def predict(audio, text, target_language):
186
  tts_type = tts_config[target_language]["type"]
187
  if tts_type == "mms":
188
  sr, waveform = run_mms_tts(translated_text, target_language)
189
- elif tts_type == "coqui":
190
- sr, waveform = run_coqui_tts(translated_text, target_language)
191
  else:
192
  raise RuntimeError("Unknown TTS type for target language.")
193
  except Exception as e:
@@ -218,12 +212,14 @@ iface = gr.Interface(
218
  description=(
219
  "This app performs the following steps:\n"
220
  "1. Transcribes English speech using Wav2Vec2 (or accepts text input).\n"
221
- "2. Translates the English text to the target language using Helsinki-NLP models.\n"
222
- "3. Provides Synthetic speech:\n"
223
- "For Spanish, Vietnamese, Indonesian, Turkish, Portuguese, and Korean."
 
 
224
  ),
225
  allow_flagging="never"
226
  )
227
 
228
  if __name__ == "__main__":
229
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
9
  from transformers import pipeline, VitsModel, AutoTokenizer
10
  from datasets import load_dataset
11
 
12
+ # For MeloTTS (Chinese and Japanese)
13
  try:
14
+ from melo.api import TTS as MeloTTS
15
  except ImportError:
16
+ raise ImportError("Please install the MeloTTS package (e.g., pip install myshell-ai/MeloTTS-Chinese)")
17
 
18
  # ------------------------------------------------------
19
  # 1. ASR Pipeline (English) using Wav2Vec2
 
51
  # ------------------------------------------------------
52
  # 3. TTS Configuration
53
  # - MMS TTS (VITS) for: Spanish, Vietnamese, Indonesian, Turkish, Portuguese, Korean
54
+ # - MeloTTS for: Chinese and Japanese
55
  # ------------------------------------------------------
56
  tts_config = {
57
  "Spanish": {"model_id": "facebook/mms-tts-spa", "architecture": "vits", "type": "mms"},
 
60
  "Turkish": {"model_id": "facebook/mms-tts-tur", "architecture": "vits", "type": "mms"},
61
  "Portuguese": {"model_id": "facebook/mms-tts-por", "architecture": "vits", "type": "mms"},
62
  "Korean": {"model_id": "facebook/mms-tts-kor", "architecture": "vits", "type": "mms"},
63
+ "Chinese": {"type": "melo"},
64
+ "Japanese": {"type": "melo"}
 
 
 
 
 
 
65
  }
66
 
67
  # ------------------------------------------------------
 
69
  # ------------------------------------------------------
70
  translator_cache = {}
71
  mms_tts_cache = {} # For MMS (VITS-based) TTS models
72
+ melo_tts_cache = {} # For MeloTTS models (Chinese/Japanese)
73
 
74
  # ------------------------------------------------------
75
  # 5. Translator Helper
 
110
  return sample_rate, waveform
111
 
112
  # ------------------------------------------------------
113
+ # 7. MeloTTS Helper for Chinese and Japanese
114
  # ------------------------------------------------------
115
+ def run_melo_tts(text, lang):
116
+ """
117
+ Uses the myshell-ai MeloTTS model.
118
+ For Chinese, use language parameter 'ZH'; for Japanese, use 'JP'.
119
+ """
120
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
121
+ lang_param = 'ZH' if lang == "Chinese" else 'JP'
122
+ if lang not in melo_tts_cache:
123
+ try:
124
+ model = MeloTTS(language=lang_param, device=device)
125
+ melo_tts_cache[lang] = model
126
+ except Exception as e:
127
+ raise RuntimeError(f"Failed to load MeloTTS model for {lang}: {e}")
128
+ else:
129
+ model = melo_tts_cache[lang]
130
+ speaker_ids = model.hps.data.spk2id
131
+ # Assume the speaker key is the same as lang_param
132
+ speaker_key = lang_param
133
+ speed = 1.0
134
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
135
  tmp_name = tmp.name
136
  try:
137
+ model.tts_to_file(text, speaker_ids[speaker_key], tmp_name, speed=speed)
 
 
 
 
138
  data, sr = sf.read(tmp_name)
139
  finally:
140
  if os.path.exists(tmp_name):
 
147
  def predict(audio, text, target_language):
148
  """
149
  1. Obtain English text (via ASR if audio provided, else text).
150
+ 2. Translate the English text to target_language.
151
+ 3. Generate TTS audio using either MMS TTS (VITS) or MeloTTS.
152
  """
153
  # Step 1: Get English text.
154
  if text.strip():
 
180
  tts_type = tts_config[target_language]["type"]
181
  if tts_type == "mms":
182
  sr, waveform = run_mms_tts(translated_text, target_language)
183
+ elif tts_type == "melo":
184
+ sr, waveform = run_melo_tts(translated_text, target_language)
185
  else:
186
  raise RuntimeError("Unknown TTS type for target language.")
187
  except Exception as e:
 
212
  description=(
213
  "This app performs the following steps:\n"
214
  "1. Transcribes English speech using Wav2Vec2 (or accepts text input).\n"
215
+ "2. Translates the English text to the target language using Helsinki-NLP MarianMT models.\n"
216
+ "3. Synthesizes speech:\n"
217
+ " - For Spanish, Vietnamese, Indonesian, Turkish, Portuguese, and Korean: uses Facebook MMS TTS (VITS-based).\n"
218
+ " - For Chinese and Japanese: uses myshell-ai MeloTTS models.\n"
219
+ "\nSelect your target language from the dropdown."
220
  ),
221
  allow_flagging="never"
222
  )
223
 
224
  if __name__ == "__main__":
225
+ iface.launch(server_name="0.0.0.0", server_port=7860)