Yilin0601 committed (verified)
Commit 1ee4794 · Parent: 7064b79

Update app.py

Files changed (1): app.py (+25 −37)
app.py CHANGED
@@ -29,24 +29,21 @@ translation_tasks = {
 }
 
 # -----------------------------------------------
-# 3. TTS Model Configurations
-# We'll load them manually (not with pipeline("text-to-speech"))
+# 3. TTS Model Configurations (All VITS)
 # -----------------------------------------------
-# - Spanish (MMS TTS, uses VITS architecture)
-# - Chinese (MMS TTS, uses VITS architecture)
-# - Japanese (SpeechT5 or a VITS-based model—here we pick a SpeechT5 example)
+# Make sure these model IDs exist on Hugging Face.
 tts_config = {
     "Spanish": {
         "model_id": "facebook/mms-tts-spa",
-        "architecture": "vits"  # We'll use VitsModel
+        "architecture": "vits"
     },
     "Chinese": {
         "model_id": "facebook/mms-tts-che",
         "architecture": "vits"
     },
     "Japanese": {
-        "model_id": "esnya/japanese_speecht5_tts",
-        "architecture": "speecht5"  # We'll treat this differently
+        "model_id": "facebook/mms-tts-jpn",
+        "architecture": "vits"
     }
 }
 
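Editor's note: since the new comment says to make sure these model IDs exist on Hugging Face, here is a small sanity-check sketch (not part of the commit) using huggingface_hub, with the repo IDs taken from tts_config above. One caveat worth flagging: MMS names checkpoints by ISO 639-3 code, and "che" is the code for Chechen, not Chinese, so that entry deserves a second look.

# Sanity-check sketch (not part of the commit): confirm each repo ID resolves.
from huggingface_hub import model_info
from huggingface_hub.utils import RepositoryNotFoundError

for repo_id in ["facebook/mms-tts-spa", "facebook/mms-tts-che", "facebook/mms-tts-jpn"]:
    try:
        print(repo_id, "->", model_info(repo_id).pipeline_tag)
    except RepositoryNotFoundError:
        print(repo_id, "-> not found on the Hub")
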
@@ -69,7 +66,7 @@ def get_translator(lang):
     return translator
 
 # -----------------------------------------------
-# 6. TTS Helper
+# 6. TTS Loading Helper
 # -----------------------------------------------
 def get_tts_model(lang):
     """
@@ -86,25 +83,18 @@ def get_tts_model(lang):
     arch = config["architecture"]
 
     try:
-        if arch == "vits":
-            # Load a VitsModel + tokenizer
-            model = VitsModel.from_pretrained(model_id)
-            tokenizer = AutoTokenizer.from_pretrained(model_id)
-        elif arch == "speecht5":
-            # For a SpeechT5 model, we might do something else
-            # e.g., pipeline("text-to-speech", model=...) if it works
-            # or custom loading if it's also a VITS-based approach
-            # We'll attempt a similar pattern:
-            model = VitsModel.from_pretrained(model_id)
-            tokenizer = AutoTokenizer.from_pretrained(model_id)
-        else:
-            raise ValueError(f"Unknown TTS architecture: {arch}")
+        # Since arch == "vits" for all three languages, we load VitsModel + AutoTokenizer
+        model = VitsModel.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
     except Exception as e:
         raise RuntimeError(f"Failed to load TTS model {model_id}: {e}")
 
     tts_model_cache[lang] = (model, tokenizer, arch)
     return tts_model_cache[lang]
 
+# -----------------------------------------------
+# 7. TTS Inference Helper
+# -----------------------------------------------
 def run_tts_inference(lang, text):
     """
     Generates waveform using the loaded TTS model and tokenizer.
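Editor's note: for reference, a minimal standalone sketch of the VITS path this hunk converges on. It assumes the facebook/mms-tts-spa checkpoint (the other MMS IDs follow the same pattern) and reads the sampling rate from the model config rather than assuming 16 kHz.

import torch
from transformers import VitsModel, AutoTokenizer

model = VitsModel.from_pretrained("facebook/mms-tts-spa")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-spa")

inputs = tokenizer("Hola, ¿cómo estás?", return_tensors="pt")
with torch.no_grad():
    output = model(**inputs)  # VitsModel output carries a .waveform tensor

waveform = output.waveform.squeeze().cpu().numpy()
sample_rate = model.config.sampling_rate  # 16000 for MMS checkpoints
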
@@ -120,25 +110,23 @@ def run_tts_inference(lang, text):
     if hasattr(output, "waveform"):
         waveform_tensor = output.waveform
     else:
-        # Some models might return a different attribute
-        raise RuntimeError("The TTS model output doesn't have 'waveform' attribute.")
+        raise RuntimeError("TTS model output does not contain 'waveform'.")
 
-    # Convert to numpy array
+    # Convert to numpy
     waveform = waveform_tensor.squeeze().cpu().numpy()
 
-    # Typically, MMS TTS uses 16 kHz
+    # MMS TTS typically uses 16 kHz
     sample_rate = 16000
     return (sample_rate, waveform)
 
 # -----------------------------------------------
-# 7. Prediction Function
+# 8. Prediction Function
 # -----------------------------------------------
 def predict(audio, text, target_language):
     """
-    1. If text is provided, use it directly as English text.
-       Else, if audio is provided, run ASR.
+    1. Obtain English text (from text input or ASR).
     2. Translate English -> target_language.
-    3. Run TTS with the correct approach for that language.
+    3. Run VITS-based TTS for that language.
     """
     # Step 1: English text
     if text.strip():
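Editor's note: the translation step itself is out of frame in this hunk; per the updated description below it uses Helsinki-NLP models, for which a typical (hedged) pattern looks like the sketch below. The exact checkpoint per language is defined elsewhere in app.py; opus-mt-en-es is shown as one plausible choice.

from transformers import pipeline

# Assumed shape of what get_translator(lang) wraps for Spanish:
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-en-es")
translated_text = translator("How are you today?")[0]["translation_text"]
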
@@ -150,7 +138,7 @@ def predict(audio, text, target_language):
     if audio_data.dtype not in [np.float32, np.float64]:
         audio_data = audio_data.astype(np.float32)
 
-    # Mono
+    # Convert stereo to mono if needed
     if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
         audio_data = np.mean(audio_data, axis=1)
 
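Editor's note: to see what the dtype/mono normalization does, here is a tiny self-contained example with a synthetic stereo int16 buffer (the format Gradio's numpy audio input typically yields).

import numpy as np

stereo = (np.random.randn(16000, 2) * 32767).astype(np.int16)  # 1 s of fake stereo

audio_data = stereo.astype(np.float32)              # ints -> float32
if audio_data.ndim > 1 and audio_data.shape[1] > 1:
    audio_data = audio_data.mean(axis=1)            # average channels to mono
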
@@ -181,7 +169,7 @@ def predict(audio, text, target_language):
     return english_text, translated_text, (sample_rate, waveform)
 
 # -----------------------------------------------
-# 8. Gradio Interface
+# 9. Gradio Interface
 # -----------------------------------------------
 iface = gr.Interface(
     fn=predict,
@@ -195,14 +183,14 @@ iface = gr.Interface(
         gr.Textbox(label="Translation (Target Language)"),
         gr.Audio(label="Synthesized Speech in Target Language")
     ],
-    title="Multimodal Language Learning Aid (VITS-based TTS)",
+    title="Multimodal Language Learning Aid (MMS TTS / VITS)",
     description=(
         "This app:\n"
         "1. Transcribes English speech (via ASR) or accepts English text.\n"
-        "2. Translates to Spanish, Chinese, or Japanese.\n"
-        "3. Synthesizes speech with VITS-based or SpeechT5-based models.\n\n"
-        "Note: Some models are experimental and may produce errors or poor quality.\n"
-        "Either upload/record English audio or enter text, then select a target language."
+        "2. Translates to Spanish, Chinese, or Japanese (Helsinki-NLP).\n"
+        "3. Synthesizes speech with VITS-based MMS TTS models.\n\n"
+        "Note: Ensure the MMS model IDs exist on Hugging Face. If not, you'll see an error.\n"
+        "Record/upload English audio or enter text, then select a target language."
     ),
     allow_flagging="never"
 )
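Editor's note: the launch call that typically closes an app.py like this is not shown in the diff; a sketch, assuming the Space runs the file directly:

if __name__ == "__main__":
    iface.launch()
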
 