Yilin0601 committed on
Commit
d744aff
·
verified ·
1 Parent(s): c098e72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -53
app.py CHANGED
@@ -2,19 +2,20 @@ import gradio as gr
2
  import torch
3
  import numpy as np
4
  import librosa
5
- from transformers import pipeline
 
6
 
7
- # --------------------------------------------------
8
- # ASR Pipeline (for English transcription)
9
- # --------------------------------------------------
10
  asr = pipeline(
11
  "automatic-speech-recognition",
12
  model="facebook/wav2vec2-base-960h"
13
  )
14
 
15
- # --------------------------------------------------
16
- # Mapping for Target Languages (Spanish, Chinese, Japanese)
17
- # --------------------------------------------------
18
  translation_models = {
19
  "Spanish": "Helsinki-NLP/opus-mt-en-es",
20
  "Chinese": "Helsinki-NLP/opus-mt-en-zh",
@@ -27,62 +28,143 @@ translation_tasks = {
27
  "Japanese": "translation_en_to_ja"
28
  }
29
 
30
- tts_models = {
31
- "Spanish": "facebook/mms-tts-spa",
32
- "Chinese": "facebook/mms-tts-che",
33
- "Japanese": "esnya/japanese_speecht5_tts"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  }
35
 
36
- # --------------------------------------------------
37
- # Caches for translator and TTS pipelines
38
- # --------------------------------------------------
39
  translator_cache = {}
40
- tts_cache = {}
41
 
42
def get_translator(target_language):
    """Return the translation pipeline for ``target_language``, building it once.

    The pipeline is created from the model and task registered for the
    language in ``translation_models`` / ``translation_tasks`` and memoised in
    ``translator_cache`` so later calls reuse the same object.
    """
    try:
        return translator_cache[target_language]
    except KeyError:
        pass  # first request for this language: build the pipeline below

    checkpoint = translation_models[target_language]
    task = translation_tasks[target_language]
    built = pipeline(task, model=checkpoint)
    translator_cache[target_language] = built
    return built
50
 
51
def get_tts(target_language):
    """Return the text-to-speech pipeline for ``target_language``, building it once.

    Raises:
        ValueError: If ``tts_models`` has no entry for the language, or if the
            pipeline for the registered model cannot be constructed.
    """
    try:
        return tts_cache[target_language]
    except KeyError:
        pass  # not loaded yet; fall through and build it

    model_name = tts_models.get(target_language)
    if model_name is None:
        raise ValueError(f"No TTS model available for {target_language}.")

    try:
        loaded = pipeline("text-to-speech", model=model_name)
    except Exception as err:
        raise ValueError(f"Failed to load TTS model for {target_language} with model '{model_name}'.\nError: {err}")

    tts_cache[target_language] = loaded
    return loaded
 
63
 
64
- # --------------------------------------------------
65
- # Prediction Function
66
- # --------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def predict(audio, text, target_language):
68
- # Step 1: Obtain English text from text input if provided, otherwise use ASR.
 
 
 
 
 
 
69
  if text.strip():
70
  english_text = text.strip()
71
  elif audio is not None:
72
  sample_rate, audio_data = audio
 
 
73
  if audio_data.dtype not in [np.float32, np.float64]:
74
  audio_data = audio_data.astype(np.float32)
 
 
75
  if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
76
  audio_data = np.mean(audio_data, axis=1)
 
 
77
  if sample_rate != 16000:
78
  audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
79
- input_audio = {"array": audio_data, "sampling_rate": 16000}
80
- asr_result = asr(input_audio)
 
81
  english_text = asr_result["text"]
82
  else:
83
  return "No input provided.", "", None
84
 
85
- # Step 2: Translate the English text to the target language.
86
  translator = get_translator(target_language)
87
  try:
88
  translation_result = translator(english_text)
@@ -90,38 +172,37 @@ def predict(audio, text, target_language):
90
  except Exception as e:
91
  return english_text, f"Translation error: {e}", None
92
 
93
- # Step 3: Synthesize speech using the TTS pipeline.
94
  try:
95
- tts_pipeline = get_tts(target_language)
96
- tts_result = tts_pipeline(translated_text)
97
- synthesized_audio = (tts_result["sample_rate"], tts_result["wav"])
98
  except Exception as e:
99
  return english_text, translated_text, f"TTS error: {e}"
100
 
101
- return english_text, translated_text, synthesized_audio
102
 
103
- # --------------------------------------------------
104
- # Gradio Interface Setup
105
- # --------------------------------------------------
106
  iface = gr.Interface(
107
  fn=predict,
108
  inputs=[
109
  gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"),
110
  gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)"),
111
- gr.Dropdown(choices=list(translation_models.keys()), value="Spanish", label="Target Language")
112
  ],
113
  outputs=[
114
  gr.Textbox(label="English Transcription"),
115
  gr.Textbox(label="Translation (Target Language)"),
116
  gr.Audio(label="Synthesized Speech in Target Language")
117
  ],
118
- title="Multimodal Language Learning Aid",
119
  description=(
120
- "This app provides three outputs:\n"
121
- "1. English transcription (from ASR or text input),\n"
122
- "2. Translation to Spanish, Chinese, or Japanese (using Helsinki-NLP models), and\n"
123
- "3. Synthetic speech in the target language (using Facebook MMS TTS or equivalent).\n\n"
124
- "Either record/upload an English audio sample or enter English text directly."
 
125
  ),
126
  allow_flagging="never"
127
  )
 
2
  import torch
3
  import numpy as np
4
  import librosa
5
+ from transformers import pipeline, VitsModel, AutoTokenizer
6
+ import scipy # if needed for processing
7
 
8
+ # -----------------------------------------------
9
+ # 1. ASR Pipeline (English)
10
+ # -----------------------------------------------
11
  asr = pipeline(
12
  "automatic-speech-recognition",
13
  model="facebook/wav2vec2-base-960h"
14
  )
15
 
16
+ # -----------------------------------------------
17
+ # 2. Translation Models (3 languages)
18
+ # -----------------------------------------------
19
  translation_models = {
20
  "Spanish": "Helsinki-NLP/opus-mt-en-es",
21
  "Chinese": "Helsinki-NLP/opus-mt-en-zh",
 
28
  "Japanese": "translation_en_to_ja"
29
  }
30
 
31
+ # -----------------------------------------------
32
+ # 3. TTS Model Configurations
33
+ # We'll load them manually (not with pipeline("text-to-speech"))
34
+ # -----------------------------------------------
35
+ # - Spanish (MMS TTS, uses VITS architecture)
36
+ # - Chinese (MMS TTS, uses VITS architecture)
37
+ # - Japanese (SpeechT5 or a VITS-based model—here we pick a SpeechT5 example)
38
# Per-language TTS settings: the Hub checkpoint id and the architecture tag
# that the loader uses to decide how the checkpoint is instantiated.
tts_config = dict(
    Spanish=dict(
        model_id="facebook/mms-tts-spa",
        architecture="vits",  # loaded through VitsModel
    ),
    # NOTE(review): "che" is the ISO 639-3 code for Chechen, not Chinese --
    # confirm that this is the intended checkpoint for the "Chinese" option.
    Chinese=dict(
        model_id="facebook/mms-tts-che",
        architecture="vits",
    ),
    Japanese=dict(
        model_id="esnya/japanese_speecht5_tts",
        architecture="speecht5",  # not a VITS checkpoint; handled separately
    ),
)
52
 
53
+ # -----------------------------------------------
54
+ # 4. Caches
55
+ # -----------------------------------------------
56
  translator_cache = {}
57
+ tts_model_cache = {} # store (model, tokenizer, architecture)
58
 
59
+ # -----------------------------------------------
60
+ # 5. Translator Helper
61
+ # -----------------------------------------------
62
def get_translator(lang):
    """Return the translation pipeline for ``lang``, creating it on first use.

    The model id and task name are looked up in ``translation_models`` and
    ``translation_tasks``; the resulting pipeline is stored in
    ``translator_cache`` and reused on subsequent calls.
    """
    try:
        return translator_cache[lang]
    except KeyError:
        pass  # cache miss: construct the pipeline below

    checkpoint = translation_models[lang]
    task = translation_tasks[lang]
    built = pipeline(task, model=checkpoint)
    translator_cache[lang] = built
    return built
70
 
71
+ # -----------------------------------------------
72
+ # 6. TTS Helper
73
+ # -----------------------------------------------
74
def get_tts_model(lang):
    """
    Return the cached ``(model, tokenizer, architecture)`` triple for ``lang``.

    The model and tokenizer are loaded from the Hugging Face Hub on the first
    request for a language and stored in ``tts_model_cache``; later calls
    return the cached triple.

    Args:
        lang: Target-language name; must be a key of ``tts_config``.

    Returns:
        Tuple ``(model, tokenizer, architecture)``.

    Raises:
        ValueError: If ``tts_config`` has no entry for ``lang``.
        RuntimeError: If the configured architecture is unknown or not
            loadable here, or if downloading/instantiating the model fails.
            The original exception is chained as ``__cause__``.
    """
    if lang in tts_model_cache:
        return tts_model_cache[lang]

    config = tts_config.get(lang)
    if config is None:
        raise ValueError(f"No TTS config found for language: {lang}")

    model_id = config["model_id"]
    arch = config["architecture"]

    try:
        if arch == "vits":
            # MMS-TTS checkpoints are VITS models with a matching tokenizer.
            model = VitsModel.from_pretrained(model_id)
            tokenizer = AutoTokenizer.from_pretrained(model_id)
        elif arch == "speecht5":
            # A SpeechT5 checkpoint does not share VITS's weights or output
            # format, so instantiating it through VitsModel cannot yield a
            # usable synthesizer. Fail with a clear message instead.
            raise NotImplementedError(
                "SpeechT5 checkpoints cannot be loaded with VitsModel; "
                "a SpeechT5-specific loader is required"
            )
        else:
            raise ValueError(f"Unknown TTS architecture: {arch}")
    except Exception as e:
        # Single error type for every load failure; keep the cause for debugging.
        raise RuntimeError(f"Failed to load TTS model {model_id}: {e}") from e

    tts_model_cache[lang] = (model, tokenizer, arch)
    return tts_model_cache[lang]
107
 
108
def run_tts_inference(lang, text):
    """
    Synthesize ``text`` with the TTS model configured for ``lang``.

    Args:
        lang: Target-language name, passed to ``get_tts_model``.
        text: Text to synthesize (already in the target language).

    Returns:
        Tuple ``(sample_rate, waveform)`` where ``waveform`` is a NumPy array
        obtained by squeezing the model's ``waveform`` output.

    Raises:
        RuntimeError: If the model output carries no ``waveform``.
        Any exception raised by ``get_tts_model`` propagates unchanged.
    """
    # The architecture tag is part of the cached triple but is not needed here.
    model, tokenizer, _arch = get_tts_model(lang)
    inputs = tokenizer(text, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(**inputs)

    waveform_tensor = getattr(output, "waveform", None)
    if waveform_tensor is None:
        raise RuntimeError("The TTS model output doesn't have 'waveform' attribute.")

    # Drop singleton dimensions and move to host memory for the audio player.
    waveform = waveform_tensor.squeeze().cpu().numpy()

    # Use the rate declared by the checkpoint so audio is not played back at
    # the wrong speed; fall back to 16 kHz only if the config omits it.
    model_config = getattr(model, "config", None)
    sample_rate = getattr(model_config, "sampling_rate", None) or 16000
    return (sample_rate, waveform)
132
+
133
+ # -----------------------------------------------
134
+ # 7. Prediction Function
135
+ # -----------------------------------------------
136
  def predict(audio, text, target_language):
137
+ """
138
+ 1. If text is provided, use it directly as English text.
139
+ Else, if audio is provided, run ASR.
140
+ 2. Translate English -> target_language.
141
+ 3. Run TTS with the correct approach for that language.
142
+ """
143
+ # Step 1: English text
144
  if text.strip():
145
  english_text = text.strip()
146
  elif audio is not None:
147
  sample_rate, audio_data = audio
148
+
149
+ # Convert to float32
150
  if audio_data.dtype not in [np.float32, np.float64]:
151
  audio_data = audio_data.astype(np.float32)
152
+
153
+ # Mono
154
  if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
155
  audio_data = np.mean(audio_data, axis=1)
156
+
157
+ # Resample to 16k
158
  if sample_rate != 16000:
159
  audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
160
+
161
+ asr_input = {"array": audio_data, "sampling_rate": 16000}
162
+ asr_result = asr(asr_input)
163
  english_text = asr_result["text"]
164
  else:
165
  return "No input provided.", "", None
166
 
167
+ # Step 2: Translation
168
  translator = get_translator(target_language)
169
  try:
170
  translation_result = translator(english_text)
 
172
  except Exception as e:
173
  return english_text, f"Translation error: {e}", None
174
 
175
+ # Step 3: TTS
176
  try:
177
+ sample_rate, waveform = run_tts_inference(target_language, translated_text)
 
 
178
  except Exception as e:
179
  return english_text, translated_text, f"TTS error: {e}"
180
 
181
+ return english_text, translated_text, (sample_rate, waveform)
182
 
183
+ # -----------------------------------------------
184
+ # 8. Gradio Interface
185
+ # -----------------------------------------------
186
  # Gradio UI definition. The three input components are listed in the same
  # order as predict's parameters (audio, text, target_language), and the
  # three output components match the three values predict returns.
  iface = gr.Interface(
      fn=predict,
      inputs=[
          # type="numpy": predict unpacks this value as (sample_rate, audio_data).
          gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"),
          gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)"),
          # Choices are spelled out here; they are expected to stay in sync with
          # the keys of the translation and TTS configuration dictionaries.
          gr.Dropdown(choices=["Spanish", "Chinese", "Japanese"], value="Spanish", label="Target Language")
      ],
      outputs=[
          gr.Textbox(label="English Transcription"),
          gr.Textbox(label="Translation (Target Language)"),
          # NOTE(review): on a TTS failure predict returns an error string in
          # this slot rather than audio -- confirm the component tolerates that.
          gr.Audio(label="Synthesized Speech in Target Language")
      ],
      title="Multimodal Language Learning Aid (VITS-based TTS)",
      description=(
          "This app:\n"
          "1. Transcribes English speech (via ASR) or accepts English text.\n"
          "2. Translates to Spanish, Chinese, or Japanese.\n"
          "3. Synthesizes speech with VITS-based or SpeechT5-based models.\n\n"
          "Note: Some models are experimental and may produce errors or poor quality.\n"
          "Either upload/record English audio or enter text, then select a target language."
      ),
      # NOTE(review): presumably disables the flagging button; verify the
      # parameter name against the installed Gradio version.
      allow_flagging="never"
  )