Yilin0601 committed · verified
Commit 5fb2e7c · 1 Parent(s): 25763d0

Update app.py

Files changed (1): app.py (+27 -25)
app.py CHANGED
@@ -6,11 +6,11 @@ from transformers import pipeline, VitsModel, AutoTokenizer
 import scipy # if needed for processing
 
 # ------------------------------------------------------
-# 1. ASR Pipeline (English) using Whisper-small
+# 1. ASR Pipeline (English) using Wav2Vec2
 # ------------------------------------------------------
 asr = pipeline(
     "automatic-speech-recognition",
-    model="openai/whisper-small"
+    model="facebook/wav2vec2-base-960h"
 )
 
 # ------------------------------------------------------
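A quick standalone sketch of what the swapped-in checkpoint does (the WAV filename here is hypothetical). Unlike Whisper, facebook/wav2vec2-base-960h is a CTC model trained on 16 kHz LibriSpeech audio, so its transcriptions come back uppercase and unpunctuated:

```python
from transformers import pipeline

# Same checkpoint the commit switches to.
asr = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")

# Hypothetical 16 kHz mono recording; wav2vec2 expects 16 kHz input.
result = asr("sample_16k.wav")
print(result["text"])  # e.g. "HELLO WORLD" -- uppercase, no punctuation
```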
@@ -30,12 +30,13 @@ translation_tasks = {
 
 # ------------------------------------------------------
 # 3. TTS Model Configurations
-# For Spanish, we keep the MMS TTS.
-# For Chinese & Japanese, use myshell-ai/MeloTTS-Chinese.
+# - Spanish: facebook/mms-tts-spa
+# - Chinese: myshell-ai/MeloTTS-Chinese
+# - Japanese: myshell-ai/MeloTTS-Japanese
 # ------------------------------------------------------
 tts_config = {
     "Spanish": {
-        "model_id": "facebook/mms-tts-spa", # MMS Spanish
+        "model_id": "facebook/mms-tts-spa",
         "architecture": "vits"
     },
     "Chinese": {
@@ -84,7 +85,7 @@ def get_tts_model(lang):
     arch = config["architecture"]
 
     try:
-        # Assuming the model follows VITS-based inference
+        # Attempt VITS-based loading
         model = VitsModel.from_pretrained(model_id)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
     except Exception as e:
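The previous hunk cuts off right after the "Chinese" entry opens. Judging from the updated comment block, the full tts_config presumably looks like this sketch (the Chinese and Japanese entries, including their "architecture" values, are inferred rather than shown in the diff):

```python
tts_config = {
    "Spanish": {
        "model_id": "facebook/mms-tts-spa",
        "architecture": "vits"
    },
    "Chinese": {
        "model_id": "myshell-ai/MeloTTS-Chinese",  # assumed entry
        "architecture": "vits"
    },
    "Japanese": {
        "model_id": "myshell-ai/MeloTTS-Japanese",  # assumed entry
        "architecture": "vits"
    }
}
```

Note that the MeloTTS repositories are not packaged in the transformers VITS format, so VitsModel.from_pretrained may fail for them, which is presumably why loading is wrapped in try/except. The surrounding get_tts_model is mostly outside the hunk; a minimal cached-loader sketch consistent with the visible lines (the _tts_cache name and the error message are illustrative):

```python
from transformers import VitsModel, AutoTokenizer

_tts_cache = {}

def get_tts_model(lang):
    """Load and cache the TTS model + tokenizer for a language."""
    if lang not in _tts_cache:
        config = tts_config[lang]
        model_id = config["model_id"]
        arch = config["architecture"]
        try:
            # Attempt VITS-based loading, as in the diff.
            model = VitsModel.from_pretrained(model_id)
            tokenizer = AutoTokenizer.from_pretrained(model_id)
        except Exception as e:
            raise RuntimeError(f"Could not load TTS model '{model_id}': {e}")
        _tts_cache[lang] = (model, tokenizer, arch)
    return _tts_cache[lang]
```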
@@ -107,14 +108,15 @@ def run_tts_inference(lang, text):
     with torch.no_grad():
         output = model(**inputs)
 
-    # VitsModel output is typically provided via .waveform attribute
-    if hasattr(output, "waveform"):
-        waveform_tensor = output.waveform
-    else:
-        raise RuntimeError("TTS model output does not contain 'waveform'.")
+    # VitsModel output is typically `.waveform`
+    if not hasattr(output, "waveform"):
+        raise RuntimeError("TTS model output does not contain 'waveform' attribute.")
 
+    waveform_tensor = output.waveform
     waveform = waveform_tensor.squeeze().cpu().numpy()
-    sample_rate = 16000 # Typically used sample rate for these models
+
+    # Typically 16 kHz for these VITS models
+    sample_rate = 16000
     return (sample_rate, waveform)
 
 # ------------------------------------------------------
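One nit on the hardcoded rate this hunk keeps: transformers VITS checkpoints (including facebook/mms-tts-spa) carry their output rate in the model config, so reading it is safer than assuming 16 kHz:

```python
# Prefer the checkpoint's own rate; fall back to 16 kHz if absent.
sample_rate = getattr(model.config, "sampling_rate", 16000)
```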
@@ -122,25 +124,25 @@ def run_tts_inference(lang, text):
 # ------------------------------------------------------
 def predict(audio, text, target_language):
     """
-    1. Obtain English text (via ASR using Whisper-small or text input).
-    2. Translate English text to the target language.
-    3. Synthesize speech with the target language TTS model.
+    1. Obtain English text (ASR with Wav2Vec2 or text input).
+    2. Translate English -> target_language.
+    3. TTS for that language (using configured models).
     """
-    # Step 1: Get English text
+    # Step 1: English text
     if text.strip():
         english_text = text.strip()
     elif audio is not None:
         sample_rate, audio_data = audio
 
-        # Ensure float32 data type
+        # Convert to float32 if needed
         if audio_data.dtype not in [np.float32, np.float64]:
             audio_data = audio_data.astype(np.float32)
 
-        # Convert stereo to mono if necessary
+        # Stereo -> mono if needed
         if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
             audio_data = np.mean(audio_data, axis=1)
 
-        # Resample to 16kHz if necessary
+        # Resample to 16k if needed
         if sample_rate != 16000:
             audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
 
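A caveat on the float conversion kept in this hunk: Gradio delivers numpy audio as int16 PCM, so a bare astype(np.float32) leaves samples in [-32768, 32767] rather than the [-1, 1] range ASR models expect. A possible fix, not part of this commit, would scale during conversion:

```python
import numpy as np

if audio_data.dtype == np.int16:
    # Normalize int16 PCM into [-1.0, 1.0] before feeding the ASR model.
    audio_data = audio_data.astype(np.float32) / 32768.0
```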
@@ -150,7 +152,7 @@ def predict(audio, text, target_language):
     else:
         return "No input provided.", "", None
 
-    # Step 2: Translation
+    # Step 2: Translate
     translator = get_translator(target_language)
     try:
         translation_result = translator(english_text)
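get_translator and translation_tasks sit outside the hunks; given the description's mention of Helsinki-NLP models, a plausible sketch is below (the three checkpoint names are assumptions, though all exist on the Hub):

```python
from transformers import pipeline

# Assumed mapping; the diff only shows that translation_tasks exists.
translation_tasks = {
    "Spanish": "Helsinki-NLP/opus-mt-en-es",
    "Chinese": "Helsinki-NLP/opus-mt-en-zh",
    "Japanese": "Helsinki-NLP/opus-mt-en-jap",
}

_translator_cache = {}

def get_translator(lang):
    """Lazily build and cache a translation pipeline per language."""
    if lang not in _translator_cache:
        _translator_cache[lang] = pipeline("translation", model=translation_tasks[lang])
    return _translator_cache[lang]
```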
@@ -162,6 +164,7 @@
     try:
         sample_rate, waveform = run_tts_inference(target_language, translated_text)
     except Exception as e:
+        # Return error info in place of audio
         return english_text, translated_text, f"TTS error: {e}"
 
     return english_text, translated_text, (sample_rate, waveform)
@@ -181,12 +184,11 @@ iface = gr.Interface(
         gr.Textbox(label="Translation (Target Language)"),
         gr.Audio(label="Synthesized Speech")
     ],
-    title="Multimodal Language Learning Aid (ASR / TTS)",
+    title="Multimodal Language Learning Aid",
     description=(
-        "This app:\n"
-        "1. Transcribes English speech or English text.\n"
-        "2. Translates to Spanish, Chinese, or Japanese (using Helsinki-NLP models).\n"
-        "3. Provides synthetic speech with TTS models:\n"
+        "1. Transcribes English speech using Wav2Vec2 (or takes English text).\n"
+        "2. Translates to Spanish, Chinese, or Japanese (Helsinki-NLP models).\n"
+        "3. Provides synthetic speech with TTS models.\n"
     ),
     allow_flagging="never"
 )
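The inputs half of gr.Interface never appears in the diff; a minimal compatible construction, with input components guessed from predict's signature (component labels are illustrative, and parameter names vary somewhat across Gradio versions):

```python
import gradio as gr

iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Audio(type="numpy", label="English Speech"),
        gr.Textbox(label="Or type English text"),
        gr.Dropdown(["Spanish", "Chinese", "Japanese"], value="Spanish", label="Target Language"),
    ],
    outputs=[
        gr.Textbox(label="English Transcription"),
        gr.Textbox(label="Translation (Target Language)"),
        gr.Audio(label="Synthesized Speech"),
    ],
    title="Multimodal Language Learning Aid",
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()
```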
 