Yilin0601 commited on
Commit
178dac1
·
verified ·
1 Parent(s): 2334caf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -84
app.py CHANGED
@@ -2,12 +2,24 @@ import gradio as gr
2
  import torch
3
  import numpy as np
4
  import librosa
5
- import soundfile as sf # likely needed by the pipeline or local saving
6
- from transformers import pipeline, VitsModel, AutoTokenizer
7
- from datasets import load_dataset
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # ------------------------------------------------------
10
- # 1. ASR Pipeline (English) - Wav2Vec2
11
  # ------------------------------------------------------
12
  asr = pipeline(
13
  "automatic-speech-recognition",
@@ -30,28 +42,32 @@ translation_tasks = {
30
  }
31
 
32
  # ------------------------------------------------------
33
- # 3. TTS Configuration
34
- # - Spanish: VITS-based MMS TTS
35
- # - Chinese & Japanese: Microsoft SpeechT5
36
  # ------------------------------------------------------
37
- # We'll store them as keys for convenience
38
- SPANISH_KEY = "Spanish"
39
- CHINESE_KEY = "Chinese"
40
- JAPANESE_KEY = "Japanese"
41
 
42
- # VITS config for Spanish only
43
  mms_spanish_config = {
44
- "model_id": "facebook/mms-tts-spa",
45
  "architecture": "vits"
46
  }
47
 
 
 
 
 
 
 
48
  # ------------------------------------------------------
49
- # 4. Create TTS Pipelines / Models Once (Caching)
50
  # ------------------------------------------------------
51
  translator_cache = {}
52
- vits_model_cache = None # for Spanish
53
- speech_t5_pipeline_cache = None # for Chinese/Japanese
54
- speech_t5_speaker_embedding = None
55
 
56
  def get_translator(lang):
57
  """
@@ -65,91 +81,99 @@ def get_translator(lang):
65
  translator_cache[lang] = translator
66
  return translator
67
 
 
 
 
68
  def load_spanish_vits():
69
  """
70
- Load and cache the Spanish VITS model + tokenizer (facebook/mms-tts-spa).
71
  """
72
- global vits_model_cache
73
- if vits_model_cache is not None:
74
- return vits_model_cache
75
 
76
  try:
77
- model_id = mms_spanish_config["model_id"]
78
- model = VitsModel.from_pretrained(model_id)
79
- tokenizer = AutoTokenizer.from_pretrained(model_id)
80
- vits_model_cache = (model, tokenizer)
81
  except Exception as e:
82
  raise RuntimeError(f"Failed to load Spanish TTS model {mms_spanish_config['model_id']}: {e}")
83
 
84
- return vits_model_cache
85
-
86
- def load_speech_t5_pipeline():
87
- """
88
- Load and cache the Microsoft SpeechT5 text-to-speech pipeline
89
- and a default speaker embedding.
90
- """
91
- global speech_t5_pipeline_cache, speech_t5_speaker_embedding
92
- if speech_t5_pipeline_cache is not None and speech_t5_speaker_embedding is not None:
93
- return speech_t5_pipeline_cache, speech_t5_speaker_embedding
94
-
95
- try:
96
- # Create the pipeline
97
- # The pipeline is named "text-to-speech" in Transformers >= 4.29
98
- t5_pipe = pipeline("text-to-speech", model="microsoft/speecht5_tts")
99
- except Exception as e:
100
- raise RuntimeError(f"Failed to load Microsoft SpeechT5 pipeline: {e}")
101
-
102
- # Load a default speaker embedding
103
- try:
104
- embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
105
- # Just pick an arbitrary index for speaker embedding
106
- speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
107
- except Exception as e:
108
- raise RuntimeError(f"Failed to load default speaker embedding: {e}")
109
-
110
- speech_t5_pipeline_cache = t5_pipe
111
- speech_t5_speaker_embedding = speaker_embedding
112
- return t5_pipe, speaker_embedding
113
 
114
- # ------------------------------------------------------
115
- # 5. TTS Inference Helpers
116
- # ------------------------------------------------------
117
- def run_vits_inference(text):
118
  """
119
- For Spanish TTS using MMS (facebook/mms-tts-spa).
 
120
  """
121
  model, tokenizer = load_spanish_vits()
122
  inputs = tokenizer(text, return_tensors="pt")
123
  with torch.no_grad():
124
  output = model(**inputs)
125
  if not hasattr(output, "waveform"):
126
- raise RuntimeError("VITS output does not contain 'waveform'.")
127
  waveform = output.waveform.squeeze().cpu().numpy()
128
  sample_rate = 16000
129
  return sample_rate, waveform
130
 
131
- def run_speecht5_inference(text):
 
 
 
132
  """
133
- For Chinese & Japanese TTS using Microsoft SpeechT5 pipeline.
134
  """
135
- t5_pipe, speaker_embedding = load_speech_t5_pipeline()
136
- # The pipeline returns a dict with 'audio' (numpy) and 'sampling_rate'
137
- result = t5_pipe(
138
- text,
139
- forward_params={"speaker_embeddings": speaker_embedding}
140
- )
141
- waveform = result["audio"]
142
- sample_rate = result["sampling_rate"]
143
- return sample_rate, waveform
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  # ------------------------------------------------------
146
- # 6. Main Prediction Function
147
  # ------------------------------------------------------
148
  def predict(audio, text, target_language):
149
  """
150
  1. Get English text (ASR if audio provided, else text).
151
  2. Translate to target_language.
152
- 3. TTS with the chosen approach (VITS for Spanish, SpeechT5 for Chinese/Japanese).
 
 
153
  """
154
  # Step 1: English text
155
  if text.strip():
@@ -185,25 +209,25 @@ def predict(audio, text, target_language):
185
 
186
  # Step 3: TTS
187
  try:
188
- if target_language == SPANISH_KEY:
189
- sr, waveform = run_vits_inference(translated_text)
190
  else:
191
- # Chinese or Japanese -> SpeechT5
192
- sr, waveform = run_speecht5_inference(translated_text)
193
  except Exception as e:
194
  return english_text, translated_text, f"TTS error: {e}"
195
 
196
  return english_text, translated_text, (sr, waveform)
197
 
198
  # ------------------------------------------------------
199
- # 7. Gradio Interface
200
  # ------------------------------------------------------
201
  iface = gr.Interface(
202
  fn=predict,
203
  inputs=[
204
  gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"),
205
  gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)"),
206
- gr.Dropdown(choices=["Spanish", "Chinese", "Japanese"], value="Spanish", label="Target Language")
207
  ],
208
  outputs=[
209
  gr.Textbox(label="English Transcription"),
@@ -212,14 +236,16 @@ iface = gr.Interface(
212
  ],
213
  title="Multimodal Language Learning Aid",
214
  description=(
215
- "1. Transcribes English speech using Wav2Vec2-960h (or takes English text).\n"
216
- "2. Translates to Spanish, Chinese, or Japanese.\n"
217
- "3. Provides synthetic speech:\n"
218
  " - Spanish -> facebook/mms-tts-spa (VITS)\n"
219
- " - Chinese & Japanese -> microsoft/speecht5_tts (SpeechT5)\n\n"
 
 
220
  ),
221
  allow_flagging="never"
222
  )
223
 
224
  if __name__ == "__main__":
225
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
2
  import torch
3
  import numpy as np
4
  import librosa
5
+ import soundfile as sf
6
+ import tempfile
7
+ import os
8
+
9
+ from transformers import (
10
+ pipeline,
11
+ VitsModel,
12
+ AutoTokenizer
13
+ )
14
+
15
+ # For Coqui TTS
16
+ try:
17
+ from TTS.api import TTS as CoquiTTS
18
+ except ImportError:
19
+ raise ImportError("Please install Coqui TTS via `pip install TTS`.")
20
 
21
  # ------------------------------------------------------
22
+ # 1. ASR Pipeline (English) using Wav2Vec2
23
  # ------------------------------------------------------
24
  asr = pipeline(
25
  "automatic-speech-recognition",
 
42
  }
43
 
44
  # ------------------------------------------------------
45
+ # 3. TTS Config:
46
+ # - Spanish: MMS TTS (facebook/mms-tts-spa)
47
+ # - Chinese, Japanese: Coqui XTTS-v2 (tts_models/multilingual/multi-dataset/xtts_v2)
48
  # ------------------------------------------------------
49
+ SPANISH = "Spanish"
50
+ CHINESE = "Chinese"
51
+ JAPANESE = "Japanese"
 
52
 
53
+ # For Spanish (MMS)
54
  mms_spanish_config = {
55
+ "model_id": "facebook/mms-tts-spa",
56
  "architecture": "vits"
57
  }
58
 
59
+ # We'll map Chinese/Japanese to Coqui language codes
60
+ coqui_lang_map = {
61
+ CHINESE: "zh",
62
+ JAPANESE: "ja"
63
+ }
64
+
65
  # ------------------------------------------------------
66
+ # 4. Global Caches
67
  # ------------------------------------------------------
68
  translator_cache = {}
69
+ spanish_vits_cache = None
70
+ coqui_tts_cache = None
 
71
 
72
  def get_translator(lang):
73
  """
 
81
  translator_cache[lang] = translator
82
  return translator
83
 
84
+ # ------------------------------------------------------
85
+ # 5. Spanish TTS: MMS (VITS)
86
+ # ------------------------------------------------------
87
  def load_spanish_vits():
88
  """
89
+ Load and cache the Spanish MMS TTS model (VITS).
90
  """
91
+ global spanish_vits_cache
92
+ if spanish_vits_cache is not None:
93
+ return spanish_vits_cache
94
 
95
  try:
96
+ model = VitsModel.from_pretrained(mms_spanish_config["model_id"])
97
+ tokenizer = AutoTokenizer.from_pretrained(mms_spanish_config["model_id"])
98
+ spanish_vits_cache = (model, tokenizer)
 
99
  except Exception as e:
100
  raise RuntimeError(f"Failed to load Spanish TTS model {mms_spanish_config['model_id']}: {e}")
101
 
102
+ return spanish_vits_cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ def run_spanish_tts(text):
 
 
 
105
  """
106
+ Run MMS TTS (VITS) for Spanish text.
107
+ Returns (sample_rate, waveform).
108
  """
109
  model, tokenizer = load_spanish_vits()
110
  inputs = tokenizer(text, return_tensors="pt")
111
  with torch.no_grad():
112
  output = model(**inputs)
113
  if not hasattr(output, "waveform"):
114
+ raise RuntimeError("Spanish TTS model output does not contain 'waveform'.")
115
  waveform = output.waveform.squeeze().cpu().numpy()
116
  sample_rate = 16000
117
  return sample_rate, waveform
118
 
119
+ # ------------------------------------------------------
120
+ # 6. Chinese/Japanese TTS: Coqui XTTS-v2
121
+ # ------------------------------------------------------
122
+ def load_coqui_tts():
123
  """
124
+ Load and cache the Coqui XTTS-v2 model (multilingual).
125
  """
126
+ global coqui_tts_cache
127
+ if coqui_tts_cache is not None:
128
+ return coqui_tts_cache
129
+
130
+ try:
131
+ # If you have a GPU on HF Spaces, you can set gpu=True.
132
+ # If not, set gpu=False to run on CPU (slower).
133
+ coqui_tts_cache = CoquiTTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=False)
134
+ except Exception as e:
135
+ raise RuntimeError("Failed to load Coqui XTTS-v2 TTS: %s" % e)
136
+
137
+ return coqui_tts_cache
138
+
139
+ def run_coqui_tts(text, lang):
140
+ """
141
+ Run Coqui TTS for Chinese or Japanese text.
142
+ We specify the language code from coqui_lang_map.
143
+ Returns (sample_rate, waveform).
144
+ """
145
+ coqui_tts = load_coqui_tts()
146
+ lang_code = coqui_lang_map[lang] # "zh" or "ja"
147
+
148
+ # We must output to a file, then read it back.
149
+ # Use a temporary file to store the wave.
150
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
151
+ tmp_name = tmp.name
152
+
153
+ try:
154
+ coqui_tts.tts_to_file(
155
+ text=text,
156
+ file_path=tmp_name,
157
+ language=lang_code # no speaker_wav, default voice
158
+ )
159
+ data, sr = sf.read(tmp_name)
160
+ finally:
161
+ # Cleanup the temporary file
162
+ if os.path.exists(tmp_name):
163
+ os.remove(tmp_name)
164
+
165
+ return sr, data
166
 
167
  # ------------------------------------------------------
168
+ # 7. Main Prediction Function
169
  # ------------------------------------------------------
170
  def predict(audio, text, target_language):
171
  """
172
  1. Get English text (ASR if audio provided, else text).
173
  2. Translate to target_language.
174
+ 3. TTS with the chosen approach:
175
+ - Spanish -> MMS TTS (VITS)
176
+ - Chinese/Japanese -> Coqui XTTS-v2
177
  """
178
  # Step 1: English text
179
  if text.strip():
 
209
 
210
  # Step 3: TTS
211
  try:
212
+ if target_language == SPANISH:
213
+ sr, waveform = run_spanish_tts(translated_text)
214
  else:
215
+ # Chinese or Japanese
216
+ sr, waveform = run_coqui_tts(translated_text, target_language)
217
  except Exception as e:
218
  return english_text, translated_text, f"TTS error: {e}"
219
 
220
  return english_text, translated_text, (sr, waveform)
221
 
222
  # ------------------------------------------------------
223
+ # 8. Gradio Interface
224
  # ------------------------------------------------------
225
  iface = gr.Interface(
226
  fn=predict,
227
  inputs=[
228
  gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"),
229
  gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)"),
230
+ gr.Dropdown(choices=[SPANISH, CHINESE, JAPANESE], value=SPANISH, label="Target Language")
231
  ],
232
  outputs=[
233
  gr.Textbox(label="English Transcription"),
 
236
  ],
237
  title="Multimodal Language Learning Aid",
238
  description=(
239
+ "1. Transcribes English speech using Wav2Vec2 (or takes English text).\n"
240
+ "2. Translates to Spanish, Chinese, or Japanese (via Helsinki-NLP).\n"
241
+ "3. Synthesizes speech:\n"
242
  " - Spanish -> facebook/mms-tts-spa (VITS)\n"
243
+ " - Chinese & Japanese -> Coqui XTTS-v2 (multilingual TTS)\n\n"
244
+ "Note: The Coqui model is 'tts_models/multilingual/multi-dataset/xtts_v2' and expects language codes.\n"
245
+ "If you need voice cloning, set `speaker_wav` in `tts_to_file()`. By default, it uses a single generic voice."
246
  ),
247
  allow_flagging="never"
248
  )
249
 
250
  if __name__ == "__main__":
251
+ iface.launch(server_name="0.0.0.0", server_port=7860)