Yilin0601 committed
Commit 16d930f · verified · Parent: cdf5e7f

Update app.py

Files changed (1): app.py (+25, -33)
app.py CHANGED
@@ -6,11 +6,11 @@ from transformers import pipeline, VitsModel, AutoTokenizer
 import scipy  # if needed for processing
 
 # ------------------------------------------------------
-# 1. ASR Pipeline (English)
+# 1. ASR Pipeline (English) using Whisper-small
 # ------------------------------------------------------
 asr = pipeline(
     "automatic-speech-recognition",
-    model="facebook/wav2vec2-base-960h"
+    model="openai/whisper-small"
 )
 
 # ------------------------------------------------------
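Since the app feeds raw numpy audio into this pipeline, it is worth noting the input format the transformers ASR pipeline expects. A minimal sketch, assuming 16 kHz mono float32 input as prepared later in predict() (the silent test buffer is illustrative, not from the app):

    # Minimal sketch: invoking the Whisper ASR pipeline on 16 kHz mono float32 audio.
    import numpy as np
    from transformers import pipeline

    asr = pipeline("automatic-speech-recognition", model="openai/whisper-small")
    audio_data = np.zeros(16000, dtype=np.float32)  # one second of silence as a stand-in
    result = asr({"sampling_rate": 16000, "raw": audio_data})
    print(result["text"])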
@@ -30,17 +30,20 @@ translation_tasks = {
 
 # ------------------------------------------------------
 # 3. TTS Model Configurations
-# NOTE: MMS does not provide a Mandarin TTS model,
-# so we skip TTS for Chinese.
+# For Spanish, we keep the MMS TTS model.
+# For Chinese & Japanese, use the myshell-ai MeloTTS models.
 # ------------------------------------------------------
 tts_config = {
     "Spanish": {
         "model_id": "facebook/mms-tts-spa",  # MMS Spanish
         "architecture": "vits"
     },
-    "Chinese": None,  # No MMS TTS for Chinese
+    "Chinese": {
+        "model_id": "myshell-ai/MeloTTS-Chinese",
+        "architecture": "vits"
+    },
     "Japanese": {
-        "model_id": "facebook/mms-tts-jpn",  # MMS Japanese
+        "model_id": "myshell-ai/MeloTTS-Japanese",
         "architecture": "vits"
     }
 }
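One caveat worth flagging on this change: the MMS checkpoints are native transformers VITS models, while the MeloTTS repositories are published for the standalone melo package and may not ship a transformers-compatible config, so VitsModel.from_pretrained could fail for them at runtime. A small probe over the tts_config defined above would confirm this before deploying (a sketch, not part of the app):

    # Sketch: verify each configured checkpoint actually loads as a transformers
    # VITS model. The MMS checkpoint should pass; the MeloTTS repos target the
    # `melo` package and may not load this way.
    from transformers import VitsModel, AutoTokenizer

    for lang, cfg in tts_config.items():
        try:
            VitsModel.from_pretrained(cfg["model_id"])
            AutoTokenizer.from_pretrained(cfg["model_id"])
            print(f"{lang}: loads OK")
        except Exception as e:
            print(f"{lang}: load failed ({e})")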
@@ -69,21 +72,19 @@ def get_translator(lang):
 def get_tts_model(lang):
     """
     Loads (model, tokenizer, architecture) from Hugging Face once, then caches.
-    If no config is found (e.g. for Chinese), raises ValueError.
     """
     if lang in tts_model_cache:
         return tts_model_cache[lang]
 
     config = tts_config.get(lang)
     if config is None:
-        # No TTS model for this language
         raise ValueError(f"No TTS config found for language: {lang}")
 
     model_id = config["model_id"]
     arch = config["architecture"]
 
     try:
-        # Since arch == "vits" for these examples, load VitsModel + AutoTokenizer
+        # Assuming the model follows VITS-based inference
         model = VitsModel.from_pretrained(model_id)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
     except Exception as e:
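The caching contract here is simple: the first call per language downloads weights, later calls return the stored tuple. An illustrative use, assuming the function caches (model, tokenizer, arch) in tts_model_cache after a successful load, as the hunk suggests:

    # Illustrative: first call is slow (downloads weights), second is a cache hit.
    model, tokenizer, arch = get_tts_model("Spanish")
    model_again, _, _ = get_tts_model("Spanish")
    assert model is model_again  # same object, served from tts_model_cache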
@@ -106,17 +107,14 @@ def run_tts_inference(lang, text):
     with torch.no_grad():
         output = model(**inputs)
 
-    # VitsModel output is typically `.waveform`
+    # VitsModel output is typically provided via the .waveform attribute
     if hasattr(output, "waveform"):
         waveform_tensor = output.waveform
     else:
         raise RuntimeError("TTS model output does not contain 'waveform'.")
 
-    # Convert to numpy
     waveform = waveform_tensor.squeeze().cpu().numpy()
-
-    # MMS TTS typically uses 16 kHz
-    sample_rate = 16000
+    sample_rate = 16000  # Typically used sample rate for these models
     return (sample_rate, waveform)
 
 # ------------------------------------------------------
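Rather than hardcoding 16000, the transformers VITS config exposes the model's actual output rate, which matters if the new Chinese/Japanese checkpoints synthesize at a different rate than MMS. A one-line alternative, assuming model is a loaded VitsModel:

    # Sketch: take the output rate from the model config; VitsConfig carries a
    # `sampling_rate` field (16000 for the MMS checkpoints). The getattr fallback
    # keeps the current behaviour if the attribute is missing.
    sample_rate = getattr(model.config, "sampling_rate", 16000)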
@@ -124,25 +122,25 @@ def run_tts_inference(lang, text):
 # ------------------------------------------------------
 def predict(audio, text, target_language):
     """
-    1. Obtain English text (from text input or ASR).
-    2. Translate English -> target_language.
-    3. Run VITS-based TTS for that language (if available).
+    1. Obtain English text (via ASR using Whisper-small or text input).
+    2. Translate English text to the target language.
+    3. Synthesize speech with the target language TTS model.
     """
-    # Step 1: English text
+    # Step 1: Get English text
     if text.strip():
         english_text = text.strip()
     elif audio is not None:
         sample_rate, audio_data = audio
 
-        # Convert to float32
+        # Ensure float32 data type
         if audio_data.dtype not in [np.float32, np.float64]:
             audio_data = audio_data.astype(np.float32)
 
-        # Convert stereo to mono if needed
+        # Convert stereo to mono if necessary
         if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
             audio_data = np.mean(audio_data, axis=1)
 
-        # Resample to 16k if needed
+        # Resample to 16 kHz if necessary
         if sample_rate != 16000:
             audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
 
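One subtlety the float32 cast glosses over: Gradio microphone input usually arrives as int16 PCM, and astype(np.float32) alone keeps the ±32768 range, while ASR models expect samples in roughly [-1, 1]. A sketch of a normalizing variant of that branch:

    # Sketch: normalize integer PCM to [-1.0, 1.0] instead of only casting.
    if audio_data.dtype == np.int16:
        audio_data = audio_data.astype(np.float32) / 32768.0
    elif audio_data.dtype not in (np.float32, np.float64):
        audio_data = audio_data.astype(np.float32)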
 
@@ -160,11 +158,8 @@ def predict(audio, text, target_language):
     except Exception as e:
         return english_text, f"Translation error: {e}", None
 
-    # Step 3: TTS (skip if no config for language)
+    # Step 3: TTS
     try:
-        if tts_config[target_language] is None:
-            # No TTS model for Chinese or not supported
-            return english_text, translated_text, None
         sample_rate, waveform = run_tts_inference(target_language, translated_text)
     except Exception as e:
         return english_text, translated_text, f"TTS error: {e}"
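With the None entries gone from tts_config, an unsupported language now reaches get_tts_model and comes back through this except branch as a "TTS error: ..." string. If a language should instead degrade silently to text-only output, a small guard before the try would restore the old behaviour (a sketch):

    # Sketch: return text-only results for languages without a TTS entry,
    # rather than surfacing a "TTS error: ..." message in the audio slot.
    if tts_config.get(target_language) is None:
        return english_text, translated_text, None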
@@ -184,20 +179,17 @@ iface = gr.Interface(
     outputs=[
         gr.Textbox(label="English Transcription"),
         gr.Textbox(label="Translation (Target Language)"),
-        gr.Audio(label="Synthesized Speech (if available)")
+        gr.Audio(label="Synthesized Speech")
     ],
-    title="Multimodal Language Learning Aid (MMS TTS / VITS)",
+    title="Multimodal Language Learning Aid (ASR / TTS)",
     description=(
         "This app:\n"
-        "1. Transcribes English speech (via ASR) or accepts English text.\n"
-        "2. Translates to Spanish, Chinese, or Japanese (Helsinki-NLP).\n"
-        "3. Synthesizes speech with VITS-based MMS TTS models for Spanish/Japanese.\n\n"
-        "Note: MMS does NOT currently provide a Mandarin TTS model, so TTS is skipped for Chinese."
+        "1. Transcribes English speech or accepts English text.\n"
+        "2. Translates to Spanish, Chinese, or Japanese (using Helsinki-NLP models).\n"
+        "3. Synthesizes speech with the configured TTS models.\n"
     ),
     allow_flagging="never"
 )
 
 if __name__ == "__main__":
-    # If running locally, uncomment:
-    # iface.launch()
     iface.launch(server_name="0.0.0.0", server_port=7860)
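The whole chain can be smoke-tested without the Gradio UI, which makes regressions like the TTS-loading caveat above easy to catch. A sketch, assuming this file is importable as app (hypothetical module name) and that predict returns (sample_rate, waveform) as its third element on success:

    # Sketch: exercise text -> translation -> TTS end to end, bypassing Gradio.
    from app import predict  # hypothetical import path

    english, translated, audio = predict(None, "Good morning", "Spanish")
    print(english, "->", translated)
    if isinstance(audio, tuple):  # (sample_rate, waveform) on success
        rate, waveform = audio
        print(f"{waveform.shape[0] / rate:.2f}s of audio at {rate} Hz")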
 