Yilin0601 committed on
Commit
cc8959c
·
verified ·
1 Parent(s): 5fb2e7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -69
app.py CHANGED
@@ -2,11 +2,12 @@ import gradio as gr
2
  import torch
3
  import numpy as np
4
  import librosa
 
5
  from transformers import pipeline, VitsModel, AutoTokenizer
6
- import scipy # if needed for processing
7
 
8
  # ------------------------------------------------------
9
- # 1. ASR Pipeline (English) using Wav2Vec2
10
  # ------------------------------------------------------
11
  asr = pipeline(
12
  "automatic-speech-recognition",
@@ -29,36 +30,33 @@ translation_tasks = {
29
  }
30
 
31
  # ------------------------------------------------------
32
- # 3. TTS Model Configurations
33
- # - Spanish: facebook/mms-tts-spa
34
- # - Chinese: myshell-ai/MeloTTS-Chinese
35
- # - Japanese: myshell-ai/MeloTTS-Japanese
36
  # ------------------------------------------------------
37
- tts_config = {
38
- "Spanish": {
39
- "model_id": "facebook/mms-tts-spa",
40
- "architecture": "vits"
41
- },
42
- "Chinese": {
43
- "model_id": "myshell-ai/MeloTTS-Chinese",
44
- "architecture": "vits"
45
- },
46
- "Japanese": {
47
- "model_id": "myshell-ai/MeloTTS-Japanese",
48
- "architecture": "vits"
49
- }
50
  }
51
 
52
  # ------------------------------------------------------
53
- # 4. Caches
54
  # ------------------------------------------------------
55
  translator_cache = {}
56
- tts_model_cache = {} # store (model, tokenizer, architecture)
 
 
57
 
58
- # ------------------------------------------------------
59
- # 5. Translator Helper
60
- # ------------------------------------------------------
61
  def get_translator(lang):
 
 
 
62
  if lang in translator_cache:
63
  return translator_cache[lang]
64
  model_name = translation_models[lang]
@@ -67,66 +65,91 @@ def get_translator(lang):
67
  translator_cache[lang] = translator
68
  return translator
69
 
70
- # ------------------------------------------------------
71
- # 6. TTS Loading Helper
72
- # ------------------------------------------------------
73
- def get_tts_model(lang):
74
  """
75
- Loads (model, tokenizer, architecture) from Hugging Face once, then caches.
76
  """
77
- if lang in tts_model_cache:
78
- return tts_model_cache[lang]
79
-
80
- config = tts_config.get(lang)
81
- if config is None:
82
- raise ValueError(f"No TTS config found for language: {lang}")
83
-
84
- model_id = config["model_id"]
85
- arch = config["architecture"]
86
 
87
  try:
88
- # Attempt VITS-based loading
89
  model = VitsModel.from_pretrained(model_id)
90
  tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  except Exception as e:
92
- raise RuntimeError(f"Failed to load TTS model {model_id}: {e}")
93
 
94
- tts_model_cache[lang] = (model, tokenizer, arch)
95
- return tts_model_cache[lang]
 
 
 
 
 
 
 
 
 
96
 
97
  # ------------------------------------------------------
98
- # 7. TTS Inference Helper
99
  # ------------------------------------------------------
100
- def run_tts_inference(lang, text):
101
  """
102
- Generates waveform using the loaded TTS model and tokenizer.
103
- Returns (sample_rate, np_array).
104
  """
105
- model, tokenizer, arch = get_tts_model(lang)
106
  inputs = tokenizer(text, return_tensors="pt")
107
-
108
  with torch.no_grad():
109
  output = model(**inputs)
110
-
111
- # VitsModel output is typically `.waveform`
112
  if not hasattr(output, "waveform"):
113
- raise RuntimeError("TTS model output does not contain 'waveform' attribute.")
114
-
115
- waveform_tensor = output.waveform
116
- waveform = waveform_tensor.squeeze().cpu().numpy()
117
-
118
- # Typically 16 kHz for these VITS models
119
  sample_rate = 16000
120
- return (sample_rate, waveform)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  # ------------------------------------------------------
123
- # 8. Prediction Function
124
  # ------------------------------------------------------
125
  def predict(audio, text, target_language):
126
  """
127
- 1. Obtain English text (ASR with Wav2Vec2 or text input).
128
- 2. Translate English -> target_language.
129
- 3. TTS for that language (using configured models).
130
  """
131
  # Step 1: English text
132
  if text.strip():
@@ -138,7 +161,7 @@ def predict(audio, text, target_language):
138
  if audio_data.dtype not in [np.float32, np.float64]:
139
  audio_data = audio_data.astype(np.float32)
140
 
141
- # Stereo -> mono if needed
142
  if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
143
  audio_data = np.mean(audio_data, axis=1)
144
 
@@ -162,15 +185,18 @@ def predict(audio, text, target_language):
162
 
163
  # Step 3: TTS
164
  try:
165
- sample_rate, waveform = run_tts_inference(target_language, translated_text)
 
 
 
 
166
  except Exception as e:
167
- # Return error info in place of audio
168
  return english_text, translated_text, f"TTS error: {e}"
169
 
170
- return english_text, translated_text, (sample_rate, waveform)
171
 
172
  # ------------------------------------------------------
173
- # 9. Gradio Interface
174
  # ------------------------------------------------------
175
  iface = gr.Interface(
176
  fn=predict,
@@ -187,11 +213,15 @@ iface = gr.Interface(
187
  title="Multimodal Language Learning Aid",
188
  description=(
189
  "1. Transcribes English speech using Wav2Vec2 (or takes English text).\n"
190
- "2. Translates to Spanish, Chinese, or Japanese (Helsinki-NLP models).\n"
191
- "3. Provides synthetic speech with TTS models.\n"
 
 
 
 
192
  ),
193
  allow_flagging="never"
194
  )
195
 
196
  if __name__ == "__main__":
197
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
2
  import torch
3
  import numpy as np
4
  import librosa
5
+ import soundfile as sf # likely needed by the pipeline or local saving
6
  from transformers import pipeline, VitsModel, AutoTokenizer
7
+ from datasets import load_dataset
8
 
9
  # ------------------------------------------------------
10
+ # 1. ASR Pipeline (English) - Wav2Vec2
11
  # ------------------------------------------------------
12
  asr = pipeline(
13
  "automatic-speech-recognition",
 
30
  }
31
 
32
  # ------------------------------------------------------
33
+ # 3. TTS Configuration
34
+ # - Spanish: VITS-based MMS TTS
35
+ # - Chinese & Japanese: Microsoft SpeechT5
 
36
  # ------------------------------------------------------
37
+ # We'll store them as keys for convenience
38
+ SPANISH_KEY = "Spanish"
39
+ CHINESE_KEY = "Chinese"
40
+ JAPANESE_KEY = "Japanese"
41
+
42
+ # VITS config for Spanish only
43
+ mms_spanish_config = {
44
+ "model_id": "facebook/mms-tts-spa",
45
+ "architecture": "vits"
 
 
 
 
46
  }
47
 
48
  # ------------------------------------------------------
49
+ # 4. Create TTS Pipelines / Models Once (Caching)
50
  # ------------------------------------------------------
51
  translator_cache = {}
52
+ vits_model_cache = None # for Spanish
53
+ speech_t5_pipeline_cache = None # for Chinese/Japanese
54
+ speech_t5_speaker_embedding = None
55
 
 
 
 
56
  def get_translator(lang):
57
+ """
58
+ Return a cached MarianMT translator for the specified language.
59
+ """
60
  if lang in translator_cache:
61
  return translator_cache[lang]
62
  model_name = translation_models[lang]
 
65
  translator_cache[lang] = translator
66
  return translator
67
 
68
+ def load_spanish_vits():
 
 
 
69
  """
70
+ Load and cache the Spanish VITS model + tokenizer (facebook/mms-tts-spa).
71
  """
72
+ global vits_model_cache
73
+ if vits_model_cache is not None:
74
+ return vits_model_cache
 
 
 
 
 
 
75
 
76
  try:
77
+ model_id = mms_spanish_config["model_id"]
78
  model = VitsModel.from_pretrained(model_id)
79
  tokenizer = AutoTokenizer.from_pretrained(model_id)
80
+ vits_model_cache = (model, tokenizer)
81
+ except Exception as e:
82
+ raise RuntimeError(f"Failed to load Spanish TTS model {mms_spanish_config['model_id']}: {e}")
83
+
84
+ return vits_model_cache
85
+
86
+ def load_speech_t5_pipeline():
87
+ """
88
+ Load and cache the Microsoft SpeechT5 text-to-speech pipeline
89
+ and a default speaker embedding.
90
+ """
91
+ global speech_t5_pipeline_cache, speech_t5_speaker_embedding
92
+ if speech_t5_pipeline_cache is not None and speech_t5_speaker_embedding is not None:
93
+ return speech_t5_pipeline_cache, speech_t5_speaker_embedding
94
+
95
+ try:
96
+ # Create the pipeline
97
+ # The "text-to-speech" pipeline task is available in Transformers >= 4.32
98
+ t5_pipe = pipeline("text-to-speech", model="microsoft/speecht5_tts")
99
  except Exception as e:
100
+ raise RuntimeError(f"Failed to load Microsoft SpeechT5 pipeline: {e}")
101
 
102
+ # Load a default speaker embedding
103
+ try:
104
+ embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
105
+ # Just pick an arbitrary index for speaker embedding
106
+ speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
107
+ except Exception as e:
108
+ raise RuntimeError(f"Failed to load default speaker embedding: {e}")
109
+
110
+ speech_t5_pipeline_cache = t5_pipe
111
+ speech_t5_speaker_embedding = speaker_embedding
112
+ return t5_pipe, speaker_embedding
113
 
114
  # ------------------------------------------------------
115
+ # 5. TTS Inference Helpers
116
  # ------------------------------------------------------
117
+ def run_vits_inference(text):
118
  """
119
+ For Spanish TTS using MMS (facebook/mms-tts-spa).
 
120
  """
121
+ model, tokenizer = load_spanish_vits()
122
  inputs = tokenizer(text, return_tensors="pt")
 
123
  with torch.no_grad():
124
  output = model(**inputs)
 
 
125
  if not hasattr(output, "waveform"):
126
+ raise RuntimeError("VITS output does not contain 'waveform'.")
127
+ waveform = output.waveform.squeeze().cpu().numpy()
 
 
 
 
128
  sample_rate = 16000
129
+ return sample_rate, waveform
130
+
131
+ def run_speecht5_inference(text):
132
+ """
133
+ For Chinese & Japanese TTS using Microsoft SpeechT5 pipeline.
134
+ """
135
+ t5_pipe, speaker_embedding = load_speech_t5_pipeline()
136
+ # The pipeline returns a dict with 'audio' (numpy) and 'sampling_rate'
137
+ result = t5_pipe(
138
+ text,
139
+ forward_params={"speaker_embeddings": speaker_embedding}
140
+ )
141
+ waveform = result["audio"]
142
+ sample_rate = result["sampling_rate"]
143
+ return sample_rate, waveform
144
 
145
  # ------------------------------------------------------
146
+ # 6. Main Prediction Function
147
  # ------------------------------------------------------
148
  def predict(audio, text, target_language):
149
  """
150
+ 1. Get English text (ASR if audio provided, else text).
151
+ 2. Translate to target_language.
152
+ 3. TTS with the chosen approach (VITS for Spanish, SpeechT5 for Chinese/Japanese).
153
  """
154
  # Step 1: English text
155
  if text.strip():
 
161
  if audio_data.dtype not in [np.float32, np.float64]:
162
  audio_data = audio_data.astype(np.float32)
163
 
164
+ # Stereo -> mono
165
  if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
166
  audio_data = np.mean(audio_data, axis=1)
167
 
 
185
 
186
  # Step 3: TTS
187
  try:
188
+ if target_language == SPANISH_KEY:
189
+ sr, waveform = run_vits_inference(translated_text)
190
+ else:
191
+ # Chinese or Japanese -> SpeechT5
192
+ sr, waveform = run_speecht5_inference(translated_text)
193
  except Exception as e:
 
194
  return english_text, translated_text, f"TTS error: {e}"
195
 
196
+ return english_text, translated_text, (sr, waveform)
197
 
198
  # ------------------------------------------------------
199
+ # 7. Gradio Interface
200
  # ------------------------------------------------------
201
  iface = gr.Interface(
202
  fn=predict,
 
213
  title="Multimodal Language Learning Aid",
214
  description=(
215
  "1. Transcribes English speech using Wav2Vec2 (or takes English text).\n"
216
+ "2. Translates to Spanish, Chinese, or Japanese (via Helsinki-NLP models).\n"
217
+ "3. Synthesizes speech:\n"
218
+ " - Spanish -> facebook/mms-tts-spa (VITS)\n"
219
+ " - Chinese & Japanese -> microsoft/speecht5_tts (SpeechT5)\n\n"
220
+ "Note: SpeechT5 is not officially trained for Japanese, so results may vary.\n"
221
+ "You can also try inputting short, clear audio for best ASR results."
222
  ),
223
  allow_flagging="never"
224
  )
225
 
226
  if __name__ == "__main__":
227
+ iface.launch(server_name="0.0.0.0", server_port=7860)