# app.py from a Hugging Face Space by Yilin0601.
# Last commit: "Update app.py" (a3f86f5, verified); file size: 9 kB.
import gradio as gr
import torch
import numpy as np
import librosa
import soundfile as sf
import tempfile
import os
from transformers import pipeline, VitsModel, AutoTokenizer
from datasets import load_dataset
# For Coqui TTS (XTTS-v2)
try:
from TTS.api import TTS as CoquiTTS
except ImportError:
raise ImportError("Please install Coqui TTS via `pip install TTS`.")
# ------------------------------------------------------
# 1. ASR Pipeline (English) using Wav2Vec2
# ------------------------------------------------------
# English speech recognizer; built once at module import and reused by predict().
asr = pipeline(
    task="automatic-speech-recognition",
    model="facebook/wav2vec2-base-960h",
)
# ------------------------------------------------------
# 2. Translation Models (8 languages)
# ------------------------------------------------------
# Maps each supported target language to the Hugging Face id of its
# English -> target translation checkpoint ("Helsinki-NLP/opus-mt-en-<code>").
# NOTE(review): presumably every id below exists on the Hub -- verify; a missing
# checkpoint would only surface when that language is first requested.
translation_models = {
    language: f"Helsinki-NLP/opus-mt-en-{code}"
    for language, code in (
        ("Spanish", "es"),
        ("Vietnamese", "vi"),
        ("Indonesian", "id"),
        ("Turkish", "tr"),
        ("Portuguese", "pt"),
        ("Korean", "ko"),
        ("Chinese", "zh"),
        ("Japanese", "ja"),
    )
}
# Pipeline task name for each target language, in the "translation_XX_to_YY"
# form that the transformers translation pipeline documents. Every entry must
# use underscores throughout; the Korean entry previously read
# "translation_en_to-ko", which does not match that form.
translation_tasks = {
    "Spanish": "translation_en_to_es",
    "Vietnamese": "translation_en_to_vi",
    "Indonesian": "translation_en_to_id",
    "Turkish": "translation_en_to_tr",
    "Portuguese": "translation_en_to_pt",
    "Korean": "translation_en_to_ko",
    "Chinese": "translation_en_to_zh",
    "Japanese": "translation_en_to_ja",
}
# ------------------------------------------------------
# 3. TTS Configuration
# - MMS TTS (VITS) for: Spanish, Vietnamese, Indonesian, Turkish, Portuguese, Korean
# - Coqui XTTS-v2 for: Chinese and Japanese
# ------------------------------------------------------
# Per-language speech-synthesis settings. "type" selects the backend used by
# predict(): "mms" entries also carry the checkpoint id and architecture,
# while "coqui" entries need nothing further.
tts_config = {
    **{
        language: {
            "model_id": f"facebook/mms-tts-{iso_code}",
            "architecture": "vits",
            "type": "mms",
        }
        for language, iso_code in (
            ("Spanish", "spa"),
            ("Vietnamese", "vie"),
            ("Indonesian", "ind"),
            ("Turkish", "tur"),
            ("Portuguese", "por"),
            ("Korean", "kor"),
        )
    },
    "Chinese": {"type": "coqui"},
    "Japanese": {"type": "coqui"},
}
# Language code handed to the Coqui model for each language routed to it.
# NOTE(review): XTTS-v2's published language list spells Chinese "zh-cn";
# confirm that the installed TTS release accepts the short "zh" form.
coqui_lang_map = {"Chinese": "zh", "Japanese": "ja"}
# ------------------------------------------------------
# 4. Global Caches for Translators and TTS Models
# ------------------------------------------------------
# Lazily filled, module-wide caches so that each model is loaded at most once.
translator_cache: dict = {}  # language name -> translation pipeline
mms_tts_cache: dict = {}  # language name -> (VitsModel, tokenizer) pair
coqui_tts_cache = None  # the one shared Coqui XTTS-v2 instance, once loaded
# ------------------------------------------------------
# 5. Translator Helper
# ------------------------------------------------------
def get_translator(lang):
    """Return the translation pipeline for *lang*, building it on first use.

    The checkpoint id and task name are looked up in ``translation_models``
    and ``translation_tasks``; the resulting pipeline is stored in
    ``translator_cache`` so later calls for the same language reuse it.
    A language missing from either table raises ``KeyError``.
    """
    try:
        return translator_cache[lang]
    except KeyError:
        pass  # first request for this language: build the pipeline below

    checkpoint, task = translation_models[lang], translation_tasks[lang]
    translator_cache[lang] = pipeline(task, model=checkpoint)
    return translator_cache[lang]
# ------------------------------------------------------
# 6. MMS TTS (VITS) Helper for languages using MMS TTS
# ------------------------------------------------------
def load_mms_tts(lang):
    """Return the cached ``(model, tokenizer)`` pair for *lang*'s MMS TTS voice.

    On the first call for a language the checkpoint named by
    ``tts_config[lang]["model_id"]`` is loaded and stored in
    ``mms_tts_cache``; later calls return the cached pair.

    Raises:
        KeyError: if *lang* has no entry in ``tts_config``.
        RuntimeError: if the entry names no MMS checkpoint (e.g. a language
            routed to another backend), or if loading the checkpoint fails.
    """
    if lang in mms_tts_cache:
        return mms_tts_cache[lang]

    # Resolve the id *before* the try block: the error message below needs it,
    # and indexing a missing key inside the handler would raise a second,
    # unrelated KeyError that hides the real problem.
    model_id = tts_config[lang].get("model_id")
    if model_id is None:
        raise RuntimeError(f"No MMS TTS model is configured for {lang}.")

    try:
        model = VitsModel.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except Exception as e:
        raise RuntimeError(
            f"Failed to load MMS TTS model for {lang} ({model_id}): {e}"
        ) from e

    # Cache only after both parts loaded, so a failed load is retried next time.
    mms_tts_cache[lang] = (model, tokenizer)
    return mms_tts_cache[lang]
def run_mms_tts(text, lang):
    """Synthesize *text* with the MMS (VITS) voice configured for *lang*.

    Returns:
        ``(sample_rate, waveform)``: the sampling rate in Hz and the generated
        samples as a NumPy array (singleton dimensions squeezed away).

    Raises:
        ValueError: if *text* is empty or whitespace-only.
        RuntimeError: if the model cannot be loaded or its output carries no
            ``waveform`` attribute.
    """
    if not text or not text.strip():
        raise ValueError("Cannot synthesize speech from empty text.")

    model, tokenizer = load_mms_tts(lang)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        output = model(**inputs)
    if not hasattr(output, "waveform"):
        raise RuntimeError(f"MMS TTS model output for {lang} does not contain 'waveform'.")
    waveform = output.waveform.squeeze().cpu().numpy()

    # Take the rate from the loaded checkpoint's config rather than assuming
    # it; 16000 remains the fallback when the config does not declare one.
    sample_rate = getattr(model.config, "sampling_rate", 16000)
    return sample_rate, waveform
# ------------------------------------------------------
# 7. Coqui TTS Helper for Chinese and Japanese
# ------------------------------------------------------
def load_coqui_tts():
    """Return the shared Coqui XTTS-v2 instance, creating it on the first call.

    The instance is kept in the module-level ``coqui_tts_cache``. Any failure
    while constructing it is re-raised as ``RuntimeError``.
    """
    global coqui_tts_cache
    if coqui_tts_cache is None:
        model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
        try:
            # Runs on CPU; pass gpu=True instead when a GPU is available.
            coqui_tts_cache = CoquiTTS(model_name, gpu=False)
        except Exception as e:
            raise RuntimeError(f"Failed to load Coqui XTTS-v2 TTS: {e}")
    return coqui_tts_cache
def run_coqui_tts(text, lang):
    """Synthesize *text* with Coqui XTTS-v2 for *lang* ("Chinese" or "Japanese").

    The model writes a WAV file, so the audio goes through a temporary file
    that is always removed afterwards.

    Returns:
        ``(sample_rate, data)`` as read back from the generated WAV file.

    Raises:
        KeyError: if *lang* has no entry in ``coqui_lang_map``.
        RuntimeError: if the Coqui model cannot be loaded.
    """
    coqui_tts = load_coqui_tts()
    synth_kwargs = {"text": text, "language": coqui_lang_map[lang]}

    # The Coqui API documents XTTS-v2 calls with either `speaker` (a built-in
    # voice) or `speaker_wav` (voice cloning). Use the first built-in voice
    # when the API exposes any; otherwise call exactly as before.
    # NOTE(review): assumes the installed TTS release lists built-in voices
    # via the `speakers` attribute -- confirm against the pinned version.
    built_in_voices = getattr(coqui_tts, "speakers", None)
    if built_in_voices:
        synth_kwargs["speaker"] = next(iter(built_in_voices))

    # Reserve a path only; the handle is closed so the model can write to it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_name = tmp.name
    try:
        coqui_tts.tts_to_file(file_path=tmp_name, **synth_kwargs)
        data, sr = sf.read(tmp_name)
    finally:
        try:
            os.remove(tmp_name)
        except FileNotFoundError:
            pass  # already gone; nothing to clean up
    return sr, data
# ------------------------------------------------------
# 8. Main Prediction Function
# ------------------------------------------------------
def _prepare_asr_audio(sample_rate, audio_data):
    """Return *audio_data* as a mono, 16 kHz floating-point array for the ASR pipeline."""
    samples = np.asarray(audio_data)
    if samples.dtype not in (np.float32, np.float64):
        samples = samples.astype(np.float32)
    if samples.ndim > 1:
        # Average over the channel axis. This also flattens a single-channel
        # (N, 1) array, which would otherwise reach the recognizer as 2-D.
        samples = samples.mean(axis=1)
    if sample_rate != 16000:
        samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000)
    return samples


def predict(audio, text, target_language):
    """Run the transcribe -> translate -> synthesize chain for one request.

    1. Obtain English text: typed *text* wins when non-blank; otherwise
       *audio* is transcribed with the ASR pipeline.
    2. Translate the English text to *target_language*.
    3. Synthesize the translation with MMS TTS (VITS) or Coqui XTTS-v2,
       as selected by ``tts_config``.

    Args:
        audio: ``(sample_rate, samples)`` tuple, or ``None`` when no audio
            was supplied.
        text: English text, or an empty string / ``None``.
        target_language: a key of ``translation_models`` / ``tts_config``.

    Returns:
        ``(english_text, translated_text, audio_out)`` where ``audio_out`` is
        ``(sample_rate, waveform)`` on success and ``None`` otherwise.
        Translation and TTS failures are reported inside the translation
        text field so the audio slot never receives a non-audio value.
    """
    # Step 1: Get English text. `text` may be None when no text was supplied.
    typed_text = (text or "").strip()
    if typed_text:
        english_text = typed_text
    elif audio is not None:
        sample_rate, audio_data = audio
        samples = _prepare_asr_audio(sample_rate, audio_data)
        english_text = asr({"array": samples, "sampling_rate": 16000})["text"]
    else:
        return "No input provided.", "", None

    # Step 2: Translate. Building the translator is inside the try so that a
    # model-load failure is reported like any other translation error.
    try:
        translator = get_translator(target_language)
        translated_text = translator(english_text)[0]["translation_text"]
    except Exception as e:
        return english_text, f"Translation error: {e}", None

    # Step 3: TTS.
    try:
        tts_type = tts_config[target_language]["type"]
        if tts_type == "mms":
            sr, waveform = run_mms_tts(translated_text, target_language)
        elif tts_type == "coqui":
            sr, waveform = run_coqui_tts(translated_text, target_language)
        else:
            raise RuntimeError("Unknown TTS type for target language.")
    except Exception as e:
        # Keep the translation visible and attach the error to it; the third
        # output is an audio component and must be audio data or None.
        return english_text, f"{translated_text}\n\n[TTS error: {e}]", None

    return english_text, translated_text, (sr, waveform)
# ------------------------------------------------------
# 9. Gradio Interface
# ------------------------------------------------------
# Target languages offered in the dropdown, in display order.
language_choices = (
    "Spanish Vietnamese Indonesian Turkish Portuguese Korean Chinese Japanese".split()
)
# Input widgets, created in the order predict() receives their values:
# (audio, text, target_language).
_audio_input = gr.Audio(type="numpy", label="Record/Upload English Audio (optional)")
_text_input = gr.Textbox(
    lines=4,
    placeholder="Or enter English text here",
    label="English Text Input (optional)",
)
_language_input = gr.Dropdown(
    choices=language_choices,
    value="Spanish",
    label="Target Language",
)

# Text shown under the title.
_app_description = (
    "This app performs the following steps:\n"
    "1. Transcribes English speech using Wav2Vec2 (or accepts text input).\n"
    "2. Translates the English text to the target language using Helsinki-NLP MarianMT models.\n"
    "3. Synthesizes speech:\n"
    " - For Spanish, Vietnamese, Indonesian, Turkish, Portuguese, and Korean: "
    "uses Facebook MMS TTS (VITS-based).\n"
    " - For Chinese and Japanese: uses Coqui XTTS-v2.\n"
    "\nSelect your target language from the dropdown."
)

# The three outputs line up with predict()'s return tuple:
# (english_text, translated_text, audio).
iface = gr.Interface(
    fn=predict,
    inputs=[_audio_input, _text_input, _language_input],
    outputs=[
        gr.Textbox(label="English Transcription"),
        gr.Textbox(label="Translation (Target Language)"),
        gr.Audio(label="Synthesized Speech"),
    ],
    title="Multimodal Language Learning Aid",
    description=_app_description,
    allow_flagging="never",
)
if __name__ == "__main__":
    # Start the web server only when run as a script: bind address 0.0.0.0,
    # port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    iface.launch(**launch_options)