Yilin0601 committed on
Commit
5f61133
·
verified ·
1 Parent(s): c996092

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -225
app.py CHANGED
@@ -1,225 +1,13 @@
1
- import gradio as gr
2
- import torch
3
- import numpy as np
4
- import librosa
5
- import soundfile as sf
6
- import tempfile
7
- import os
8
-
9
- from transformers import pipeline, VitsModel, AutoTokenizer
10
- from datasets import load_dataset
11
-
12
- # For MeloTTS (Chinese and Japanese)
13
- try:
14
- from melo.api import TTS as MeloTTS
15
- except ImportError:
16
- raise ImportError("Please install the MeloTTS package (e.g., pip install myshell-ai/MeloTTS-Chinese)")
17
-
18
- # ------------------------------------------------------
19
- # 1. ASR Pipeline (English) using Wav2Vec2
20
- # ------------------------------------------------------
21
- asr = pipeline(
22
- "automatic-speech-recognition",
23
- model="facebook/wav2vec2-base-960h"
24
- )
25
-
26
- # ------------------------------------------------------
27
- # 2. Translation Models (8 languages)
28
- # ------------------------------------------------------
29
- translation_models = {
30
- "Spanish": "Helsinki-NLP/opus-mt-en-es",
31
- "Vietnamese": "Helsinki-NLP/opus-mt-en-vi",
32
- "Indonesian": "Helsinki-NLP/opus-mt-en-id",
33
- "Turkish": "Helsinki-NLP/opus-mt-en-trk",
34
- "Portuguese": "Helsinki-NLP/opus-mt-tc-big-en-pt",
35
- "Korean": "Helsinki-NLP/opus-mt-tc-big-en-ko",
36
- "Chinese": "Helsinki-NLP/opus-mt-en-zh",
37
- "Japanese": "Helsinki-NLP/opus-mt-en-jap"
38
- }
39
-
40
- translation_tasks = {
41
- "Spanish": "translation_en_to_es",
42
- "Vietnamese": "translation_en_to_vi",
43
- "Indonesian": "translation_en_to_id",
44
- "Turkish": "translation_en_to_tr",
45
- "Portuguese": "translation_en_to_pt",
46
- "Korean": "translation_en_to-ko",
47
- "Chinese": "translation_en_to_zh",
48
- "Japanese": "translation_en_to_ja"
49
- }
50
-
51
- # ------------------------------------------------------
52
- # 3. TTS Configuration
53
- # - MMS TTS (VITS) for: Spanish, Vietnamese, Indonesian, Turkish, Portuguese, Korean
54
- # - MeloTTS for: Chinese and Japanese
55
- # ------------------------------------------------------
56
- tts_config = {
57
- "Spanish": {"model_id": "facebook/mms-tts-spa", "architecture": "vits", "type": "mms"},
58
- "Vietnamese": {"model_id": "facebook/mms-tts-vie", "architecture": "vits", "type": "mms"},
59
- "Indonesian": {"model_id": "facebook/mms-tts-ind", "architecture": "vits", "type": "mms"},
60
- "Turkish": {"model_id": "facebook/mms-tts-tur", "architecture": "vits", "type": "mms"},
61
- "Portuguese": {"model_id": "facebook/mms-tts-por", "architecture": "vits", "type": "mms"},
62
- "Korean": {"model_id": "facebook/mms-tts-kor", "architecture": "vits", "type": "mms"},
63
- "Chinese": {"type": "melo"},
64
- "Japanese": {"type": "melo"}
65
- }
66
-
67
- # ------------------------------------------------------
68
- # 4. Global Caches for Translators and TTS Models
69
- # ------------------------------------------------------
70
- translator_cache = {}
71
- mms_tts_cache = {} # For MMS (VITS-based) TTS models
72
- melo_tts_cache = {} # For MeloTTS models (Chinese/Japanese)
73
-
74
- # ------------------------------------------------------
75
- # 5. Translator Helper
76
- # ------------------------------------------------------
77
- def get_translator(lang):
78
- if lang in translator_cache:
79
- return translator_cache[lang]
80
- model_name = translation_models[lang]
81
- task_name = translation_tasks[lang]
82
- translator = pipeline(task_name, model=model_name)
83
- translator_cache[lang] = translator
84
- return translator
85
-
86
- # ------------------------------------------------------
87
- # 6. MMS TTS (VITS) Helper for languages using MMS TTS
88
- # ------------------------------------------------------
89
- def load_mms_tts(lang):
90
- if lang in mms_tts_cache:
91
- return mms_tts_cache[lang]
92
- config = tts_config[lang]
93
- try:
94
- model = VitsModel.from_pretrained(config["model_id"])
95
- tokenizer = AutoTokenizer.from_pretrained(config["model_id"])
96
- mms_tts_cache[lang] = (model, tokenizer)
97
- except Exception as e:
98
- raise RuntimeError(f"Failed to load MMS TTS model for {lang} ({config['model_id']}): {e}")
99
- return mms_tts_cache[lang]
100
-
101
- def run_mms_tts(text, lang):
102
- model, tokenizer = load_mms_tts(lang)
103
- inputs = tokenizer(text, return_tensors="pt")
104
- with torch.no_grad():
105
- output = model(**inputs)
106
- if not hasattr(output, "waveform"):
107
- raise RuntimeError(f"MMS TTS model output for {lang} does not contain 'waveform'.")
108
- waveform = output.waveform.squeeze().cpu().numpy()
109
- sample_rate = 16000
110
- return sample_rate, waveform
111
-
112
- # ------------------------------------------------------
113
- # 7. MeloTTS Helper for Chinese and Japanese
114
- # ------------------------------------------------------
115
- def run_melo_tts(text, lang):
116
- """
117
- Uses the myshell-ai MeloTTS model.
118
- For Chinese, use language parameter 'ZH'; for Japanese, use 'JP'.
119
- """
120
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
121
- lang_param = 'ZH' if lang == "Chinese" else 'JP'
122
- if lang not in melo_tts_cache:
123
- try:
124
- model = MeloTTS(language=lang_param, device=device)
125
- melo_tts_cache[lang] = model
126
- except Exception as e:
127
- raise RuntimeError(f"Failed to load MeloTTS model for {lang}: {e}")
128
- else:
129
- model = melo_tts_cache[lang]
130
- speaker_ids = model.hps.data.spk2id
131
- # Assume the speaker key is the same as lang_param
132
- speaker_key = lang_param
133
- speed = 1.0
134
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
135
- tmp_name = tmp.name
136
- try:
137
- model.tts_to_file(text, speaker_ids[speaker_key], tmp_name, speed=speed)
138
- data, sr = sf.read(tmp_name)
139
- finally:
140
- if os.path.exists(tmp_name):
141
- os.remove(tmp_name)
142
- return sr, data
143
-
144
- # ------------------------------------------------------
145
- # 8. Main Prediction Function
146
- # ------------------------------------------------------
147
- def predict(audio, text, target_language):
148
- """
149
- 1. Obtain English text (via ASR if audio provided, else text).
150
- 2. Translate the English text to target_language.
151
- 3. Generate TTS audio using either MMS TTS (VITS) or MeloTTS.
152
- """
153
- # Step 1: Get English text.
154
- if text.strip():
155
- english_text = text.strip()
156
- elif audio is not None:
157
- sample_rate, audio_data = audio
158
- if audio_data.dtype not in [np.float32, np.float64]:
159
- audio_data = audio_data.astype(np.float32)
160
- if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
161
- audio_data = np.mean(audio_data, axis=1)
162
- if sample_rate != 16000:
163
- audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
164
- asr_input = {"array": audio_data, "sampling_rate": 16000}
165
- asr_result = asr(asr_input)
166
- english_text = asr_result["text"]
167
- else:
168
- return "No input provided.", "", None
169
-
170
- # Step 2: Translate.
171
- translator = get_translator(target_language)
172
- try:
173
- translation_result = translator(english_text)
174
- translated_text = translation_result[0]["translation_text"]
175
- except Exception as e:
176
- return english_text, f"Translation error: {e}", None
177
-
178
- # Step 3: TTS.
179
- try:
180
- tts_type = tts_config[target_language]["type"]
181
- if tts_type == "mms":
182
- sr, waveform = run_mms_tts(translated_text, target_language)
183
- elif tts_type == "melo":
184
- sr, waveform = run_melo_tts(translated_text, target_language)
185
- else:
186
- raise RuntimeError("Unknown TTS type for target language.")
187
- except Exception as e:
188
- return english_text, translated_text, f"TTS error: {e}"
189
-
190
- return english_text, translated_text, (sr, waveform)
191
-
192
- # ------------------------------------------------------
193
- # 9. Gradio Interface
194
- # ------------------------------------------------------
195
- language_choices = [
196
- "Spanish", "Vietnamese", "Indonesian", "Turkish", "Portuguese", "Korean", "Chinese", "Japanese"
197
- ]
198
-
199
- iface = gr.Interface(
200
- fn=predict,
201
- inputs=[
202
- gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"),
203
- gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)"),
204
- gr.Dropdown(choices=language_choices, value="Spanish", label="Target Language")
205
- ],
206
- outputs=[
207
- gr.Textbox(label="English Transcription"),
208
- gr.Textbox(label="Translation (Target Language)"),
209
- gr.Audio(label="Synthesized Speech")
210
- ],
211
- title="Multimodal Language Learning Aid",
212
- description=(
213
- "This app performs the following steps:\n"
214
- "1. Transcribes English speech using Wav2Vec2 (or accepts text input).\n"
215
- "2. Translates the English text to the target language using Helsinki-NLP MarianMT models.\n"
216
- "3. Synthesizes speech:\n"
217
- " - For Spanish, Vietnamese, Indonesian, Turkish, Portuguese, and Korean: uses Facebook MMS TTS (VITS-based).\n"
218
- " - For Chinese and Japanese: uses myshell-ai MeloTTS models.\n"
219
- "\nSelect your target language from the dropdown."
220
- ),
221
- allow_flagging="never"
222
- )
223
-
224
- if __name__ == "__main__":
225
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ torch
2
+ transformers>=4.33.0
3
+ gradio
4
+ librosa
5
+ numpy
6
+ scipy
7
+ accelerate
8
+ sentencepiece
9
+ soundfile
10
+ datasets
11
+ TTS
12
+ git+https://github.com/myshell-ai/MeloTTS-Chinese.git
13
+ git+https://github.com/myshell-ai/MeloTTS-Japanese.git