Yilin0601 commited on
Commit
a3f86f5
·
verified ·
1 Parent(s): 9002343

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -99
app.py CHANGED
@@ -6,13 +6,10 @@ import soundfile as sf
6
  import tempfile
7
  import os
8
 
9
- from transformers import (
10
- pipeline,
11
- VitsModel,
12
- AutoTokenizer
13
- )
14
 
15
- # For Coqui TTS
16
  try:
17
  from TTS.api import TTS as CoquiTTS
18
  except ImportError:
@@ -27,52 +24,63 @@ asr = pipeline(
27
  )
28
 
29
  # ------------------------------------------------------
30
- # 2. Translation Models (3 languages)
31
  # ------------------------------------------------------
32
  translation_models = {
33
  "Spanish": "Helsinki-NLP/opus-mt-en-es",
 
 
 
 
 
34
  "Chinese": "Helsinki-NLP/opus-mt-en-zh",
35
  "Japanese": "Helsinki-NLP/opus-mt-en-ja"
36
  }
37
 
38
  translation_tasks = {
39
  "Spanish": "translation_en_to_es",
 
 
 
 
 
40
  "Chinese": "translation_en_to_zh",
41
  "Japanese": "translation_en_to_ja"
42
  }
43
 
44
  # ------------------------------------------------------
45
- # 3. TTS Config:
46
- # - Spanish: MMS TTS (facebook/mms-tts-spa)
47
- # - Chinese, Japanese: Coqui XTTS-v2 (tts_models/multilingual/multi-dataset/xtts_v2)
48
  # ------------------------------------------------------
49
- SPANISH = "Spanish"
50
- CHINESE = "Chinese"
51
- JAPANESE = "Japanese"
52
-
53
- # For Spanish (MMS)
54
- mms_spanish_config = {
55
- "model_id": "facebook/mms-tts-spa",
56
- "architecture": "vits"
 
57
  }
58
 
59
- # We'll map Chinese/Japanese to Coqui language codes
60
  coqui_lang_map = {
61
- CHINESE: "zh",
62
- JAPANESE: "ja"
63
  }
64
 
65
  # ------------------------------------------------------
66
- # 4. Global Caches
67
  # ------------------------------------------------------
68
  translator_cache = {}
69
- spanish_vits_cache = None
70
- coqui_tts_cache = None
71
 
 
 
 
72
  def get_translator(lang):
73
- """
74
- Return a cached MarianMT translator for the specified language.
75
- """
76
  if lang in translator_cache:
77
  return translator_cache[lang]
78
  model_name = translation_models[lang]
@@ -82,124 +90,90 @@ def get_translator(lang):
82
  return translator
83
 
84
  # ------------------------------------------------------
85
- # 5. Spanish TTS: MMS (VITS)
86
  # ------------------------------------------------------
87
- def load_spanish_vits():
88
- """
89
- Load and cache the Spanish MMS TTS model (VITS).
90
- """
91
- global spanish_vits_cache
92
- if spanish_vits_cache is not None:
93
- return spanish_vits_cache
94
-
95
  try:
96
- model = VitsModel.from_pretrained(mms_spanish_config["model_id"])
97
- tokenizer = AutoTokenizer.from_pretrained(mms_spanish_config["model_id"])
98
- spanish_vits_cache = (model, tokenizer)
99
  except Exception as e:
100
- raise RuntimeError(f"Failed to load Spanish TTS model {mms_spanish_config['model_id']}: {e}")
101
-
102
- return spanish_vits_cache
103
 
104
- def run_spanish_tts(text):
105
- """
106
- Run MMS TTS (VITS) for Spanish text.
107
- Returns (sample_rate, waveform).
108
- """
109
- model, tokenizer = load_spanish_vits()
110
  inputs = tokenizer(text, return_tensors="pt")
111
  with torch.no_grad():
112
  output = model(**inputs)
113
  if not hasattr(output, "waveform"):
114
- raise RuntimeError("Spanish TTS model output does not contain 'waveform'.")
115
  waveform = output.waveform.squeeze().cpu().numpy()
116
  sample_rate = 16000
117
  return sample_rate, waveform
118
 
119
  # ------------------------------------------------------
120
- # 6. Chinese/Japanese TTS: Coqui XTTS-v2
121
  # ------------------------------------------------------
122
  def load_coqui_tts():
123
- """
124
- Load and cache the Coqui XTTS-v2 model (multilingual).
125
- """
126
  global coqui_tts_cache
127
  if coqui_tts_cache is not None:
128
  return coqui_tts_cache
129
-
130
  try:
131
- # If you have a GPU on HF Spaces, you can set gpu=True.
132
- # If not, set gpu=False to run on CPU (slower).
133
  coqui_tts_cache = CoquiTTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=False)
134
  except Exception as e:
135
- raise RuntimeError("Failed to load Coqui XTTS-v2 TTS: %s" % e)
136
-
137
  return coqui_tts_cache
138
 
139
  def run_coqui_tts(text, lang):
140
- """
141
- Run Coqui TTS for Chinese or Japanese text.
142
- We specify the language code from coqui_lang_map.
143
- Returns (sample_rate, waveform).
144
- """
145
  coqui_tts = load_coqui_tts()
146
- lang_code = coqui_lang_map[lang] # "zh" or "ja"
147
-
148
- # We must output to a file, then read it back.
149
- # Use a temporary file to store the wave.
150
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
151
  tmp_name = tmp.name
152
-
153
  try:
154
  coqui_tts.tts_to_file(
155
  text=text,
156
  file_path=tmp_name,
157
- language=lang_code # no speaker_wav, default voice
158
  )
159
  data, sr = sf.read(tmp_name)
160
  finally:
161
- # Cleanup the temporary file
162
  if os.path.exists(tmp_name):
163
  os.remove(tmp_name)
164
-
165
  return sr, data
166
 
167
  # ------------------------------------------------------
168
- # 7. Main Prediction Function
169
  # ------------------------------------------------------
170
  def predict(audio, text, target_language):
171
  """
172
- 1. Get English text (ASR if audio provided, else text).
173
- 2. Translate to target_language.
174
- 3. TTS with the chosen approach:
175
- - Spanish -> MMS TTS (VITS)
176
- - Chinese/Japanese -> Coqui XTTS-v2
177
  """
178
- # Step 1: English text
179
  if text.strip():
180
  english_text = text.strip()
181
  elif audio is not None:
182
  sample_rate, audio_data = audio
183
-
184
- # Convert to float32 if needed
185
  if audio_data.dtype not in [np.float32, np.float64]:
186
  audio_data = audio_data.astype(np.float32)
187
-
188
- # Stereo -> mono
189
  if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
190
  audio_data = np.mean(audio_data, axis=1)
191
-
192
- # Resample to 16k if needed
193
  if sample_rate != 16000:
194
  audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
195
-
196
  asr_input = {"array": audio_data, "sampling_rate": 16000}
197
  asr_result = asr(asr_input)
198
  english_text = asr_result["text"]
199
  else:
200
  return "No input provided.", "", None
201
 
202
- # Step 2: Translate
203
  translator = get_translator(target_language)
204
  try:
205
  translation_result = translator(english_text)
@@ -207,27 +181,33 @@ def predict(audio, text, target_language):
207
  except Exception as e:
208
  return english_text, f"Translation error: {e}", None
209
 
210
- # Step 3: TTS
211
  try:
212
- if target_language == SPANISH:
213
- sr, waveform = run_spanish_tts(translated_text)
214
- else:
215
- # Chinese or Japanese
216
  sr, waveform = run_coqui_tts(translated_text, target_language)
 
 
217
  except Exception as e:
218
  return english_text, translated_text, f"TTS error: {e}"
219
 
220
  return english_text, translated_text, (sr, waveform)
221
 
222
  # ------------------------------------------------------
223
- # 8. Gradio Interface
224
  # ------------------------------------------------------
 
 
 
 
225
  iface = gr.Interface(
226
  fn=predict,
227
  inputs=[
228
  gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"),
229
  gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)"),
230
- gr.Dropdown(choices=[SPANISH, CHINESE, JAPANESE], value=SPANISH, label="Target Language")
231
  ],
232
  outputs=[
233
  gr.Textbox(label="English Transcription"),
@@ -236,13 +216,14 @@ iface = gr.Interface(
236
  ],
237
  title="Multimodal Language Learning Aid",
238
  description=(
239
- "1. Transcribes English speech using Wav2Vec2 (or takes English text).\n"
240
- "2. Translates to Spanish, Chinese, or Japanese (via Helsinki-NLP).\n"
 
241
  "3. Synthesizes speech:\n"
242
- " - Spanish -> facebook/mms-tts-spa (VITS)\n"
243
- " - Chinese & Japanese -> Coqui XTTS-v2 (multilingual TTS)\n\n"
244
- "Note: The Coqui model is 'tts_models/multilingual/multi-dataset/xtts_v2' and expects language codes.\n"
245
- "If you need voice cloning, set `speaker_wav` in `tts_to_file()`. By default, it uses a single generic voice."
246
  ),
247
  allow_flagging="never"
248
  )
 
6
  import tempfile
7
  import os
8
 
9
+ from transformers import pipeline, VitsModel, AutoTokenizer
10
+ from datasets import load_dataset
 
 
 
11
 
12
+ # For Coqui TTS (XTTS-v2)
13
  try:
14
  from TTS.api import TTS as CoquiTTS
15
  except ImportError:
 
24
  )
25
 
26
  # ------------------------------------------------------
27
+ # 2. Translation Models (8 languages)
28
  # ------------------------------------------------------
29
  translation_models = {
30
  "Spanish": "Helsinki-NLP/opus-mt-en-es",
31
+ "Vietnamese": "Helsinki-NLP/opus-mt-en-vi",
32
+ "Indonesian": "Helsinki-NLP/opus-mt-en-id",
33
+ "Turkish": "Helsinki-NLP/opus-mt-en-tr",
34
+ "Portuguese": "Helsinki-NLP/opus-mt-en-pt",
35
+ "Korean": "Helsinki-NLP/opus-mt-en-ko",
36
  "Chinese": "Helsinki-NLP/opus-mt-en-zh",
37
  "Japanese": "Helsinki-NLP/opus-mt-en-ja"
38
  }
39
 
40
  translation_tasks = {
41
  "Spanish": "translation_en_to_es",
42
+ "Vietnamese": "translation_en_to_vi",
43
+ "Indonesian": "translation_en_to_id",
44
+ "Turkish": "translation_en_to_tr",
45
+ "Portuguese": "translation_en_to_pt",
46
+ "Korean": "translation_en_to-ko",
47
  "Chinese": "translation_en_to_zh",
48
  "Japanese": "translation_en_to_ja"
49
  }
50
 
51
  # ------------------------------------------------------
52
+ # 3. TTS Configuration
53
+ # - MMS TTS (VITS) for: Spanish, Vietnamese, Indonesian, Turkish, Portuguese, Korean
54
+ # - Coqui XTTS-v2 for: Chinese and Japanese
55
  # ------------------------------------------------------
56
+ tts_config = {
57
+ "Spanish": {"model_id": "facebook/mms-tts-spa", "architecture": "vits", "type": "mms"},
58
+ "Vietnamese": {"model_id": "facebook/mms-tts-vie", "architecture": "vits", "type": "mms"},
59
+ "Indonesian": {"model_id": "facebook/mms-tts-ind", "architecture": "vits", "type": "mms"},
60
+ "Turkish": {"model_id": "facebook/mms-tts-tur", "architecture": "vits", "type": "mms"},
61
+ "Portuguese": {"model_id": "facebook/mms-tts-por", "architecture": "vits", "type": "mms"},
62
+ "Korean": {"model_id": "facebook/mms-tts-kor", "architecture": "vits", "type": "mms"},
63
+ "Chinese": {"type": "coqui"},
64
+ "Japanese": {"type": "coqui"}
65
  }
66
 
67
+ # For Coqui, we map our languages to language codes expected by the model.
68
  coqui_lang_map = {
69
+ "Chinese": "zh",
70
+ "Japanese": "ja"
71
  }
72
 
73
  # ------------------------------------------------------
74
+ # 4. Global Caches for Translators and TTS Models
75
  # ------------------------------------------------------
76
  translator_cache = {}
77
+ mms_tts_cache = {} # For MMS (VITS-based) TTS models
78
+ coqui_tts_cache = None # Single instance for Coqui XTTS-v2
79
 
80
+ # ------------------------------------------------------
81
+ # 5. Translator Helper
82
+ # ------------------------------------------------------
83
  def get_translator(lang):
 
 
 
84
  if lang in translator_cache:
85
  return translator_cache[lang]
86
  model_name = translation_models[lang]
 
90
  return translator
91
 
92
  # ------------------------------------------------------
93
+ # 6. MMS TTS (VITS) Helper for languages using MMS TTS
94
  # ------------------------------------------------------
95
+ def load_mms_tts(lang):
96
+ if lang in mms_tts_cache:
97
+ return mms_tts_cache[lang]
98
+ config = tts_config[lang]
 
 
 
 
99
  try:
100
+ model = VitsModel.from_pretrained(config["model_id"])
101
+ tokenizer = AutoTokenizer.from_pretrained(config["model_id"])
102
+ mms_tts_cache[lang] = (model, tokenizer)
103
  except Exception as e:
104
+ raise RuntimeError(f"Failed to load MMS TTS model for {lang} ({config['model_id']}): {e}")
105
+ return mms_tts_cache[lang]
 
106
 
107
+ def run_mms_tts(text, lang):
108
+ model, tokenizer = load_mms_tts(lang)
 
 
 
 
109
  inputs = tokenizer(text, return_tensors="pt")
110
  with torch.no_grad():
111
  output = model(**inputs)
112
  if not hasattr(output, "waveform"):
113
+ raise RuntimeError(f"MMS TTS model output for {lang} does not contain 'waveform'.")
114
  waveform = output.waveform.squeeze().cpu().numpy()
115
  sample_rate = 16000
116
  return sample_rate, waveform
117
 
118
  # ------------------------------------------------------
119
+ # 7. Coqui TTS Helper for Chinese and Japanese
120
  # ------------------------------------------------------
121
  def load_coqui_tts():
 
 
 
122
  global coqui_tts_cache
123
  if coqui_tts_cache is not None:
124
  return coqui_tts_cache
 
125
  try:
126
+ # Set gpu=True if a GPU is available.
 
127
  coqui_tts_cache = CoquiTTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=False)
128
  except Exception as e:
129
+ raise RuntimeError(f"Failed to load Coqui XTTS-v2 TTS: {e}")
 
130
  return coqui_tts_cache
131
 
132
  def run_coqui_tts(text, lang):
 
 
 
 
 
133
  coqui_tts = load_coqui_tts()
134
+ lang_code = coqui_lang_map[lang] # "zh" for Chinese or "ja" for Japanese
135
+ # Write the output to a temporary file and then read it back.
 
 
136
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
137
  tmp_name = tmp.name
 
138
  try:
139
  coqui_tts.tts_to_file(
140
  text=text,
141
  file_path=tmp_name,
142
+ language=lang_code # using default voice; for cloning, add speaker_wav parameter
143
  )
144
  data, sr = sf.read(tmp_name)
145
  finally:
 
146
  if os.path.exists(tmp_name):
147
  os.remove(tmp_name)
 
148
  return sr, data
149
 
150
  # ------------------------------------------------------
151
+ # 8. Main Prediction Function
152
  # ------------------------------------------------------
153
  def predict(audio, text, target_language):
154
  """
155
+ 1. Obtain English text (via ASR if audio provided, else text).
156
+ 2. Translate English text to target_language.
157
+ 3. Generate TTS audio using either MMS TTS (VITS) or Coqui XTTS-v2.
 
 
158
  """
159
+ # Step 1: Get English text.
160
  if text.strip():
161
  english_text = text.strip()
162
  elif audio is not None:
163
  sample_rate, audio_data = audio
 
 
164
  if audio_data.dtype not in [np.float32, np.float64]:
165
  audio_data = audio_data.astype(np.float32)
 
 
166
  if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
167
  audio_data = np.mean(audio_data, axis=1)
 
 
168
  if sample_rate != 16000:
169
  audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
 
170
  asr_input = {"array": audio_data, "sampling_rate": 16000}
171
  asr_result = asr(asr_input)
172
  english_text = asr_result["text"]
173
  else:
174
  return "No input provided.", "", None
175
 
176
+ # Step 2: Translate.
177
  translator = get_translator(target_language)
178
  try:
179
  translation_result = translator(english_text)
 
181
  except Exception as e:
182
  return english_text, f"Translation error: {e}", None
183
 
184
+ # Step 3: TTS.
185
  try:
186
+ tts_type = tts_config[target_language]["type"]
187
+ if tts_type == "mms":
188
+ sr, waveform = run_mms_tts(translated_text, target_language)
189
+ elif tts_type == "coqui":
190
  sr, waveform = run_coqui_tts(translated_text, target_language)
191
+ else:
192
+ raise RuntimeError("Unknown TTS type for target language.")
193
  except Exception as e:
194
  return english_text, translated_text, f"TTS error: {e}"
195
 
196
  return english_text, translated_text, (sr, waveform)
197
 
198
  # ------------------------------------------------------
199
+ # 9. Gradio Interface
200
  # ------------------------------------------------------
201
+ language_choices = [
202
+ "Spanish", "Vietnamese", "Indonesian", "Turkish", "Portuguese", "Korean", "Chinese", "Japanese"
203
+ ]
204
+
205
  iface = gr.Interface(
206
  fn=predict,
207
  inputs=[
208
  gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"),
209
  gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)"),
210
+ gr.Dropdown(choices=language_choices, value="Spanish", label="Target Language")
211
  ],
212
  outputs=[
213
  gr.Textbox(label="English Transcription"),
 
216
  ],
217
  title="Multimodal Language Learning Aid",
218
  description=(
219
+ "This app performs the following steps:\n"
220
+ "1. Transcribes English speech using Wav2Vec2 (or accepts text input).\n"
221
+ "2. Translates the English text to the target language using Helsinki-NLP MarianMT models.\n"
222
  "3. Synthesizes speech:\n"
223
+ " - For Spanish, Vietnamese, Indonesian, Turkish, Portuguese, and Korean: "
224
+ "uses Facebook MMS TTS (VITS-based).\n"
225
+ " - For Chinese and Japanese: uses Coqui XTTS-v2.\n"
226
+ "\nSelect your target language from the dropdown."
227
  ),
228
  allow_flagging="never"
229
  )