# Author: Yilin0601 - "Update app.py" (commit e2fc711, verified)
# Residue from the Hugging Face file viewer: raw | history | blame | 7.2 kB
import gradio as gr
import torch
import numpy as np
import librosa
from transformers import pipeline, VitsModel, AutoTokenizer
import scipy # if needed for processing
# ------------------------------------------------------
# 1. ASR Pipeline (English)
# ------------------------------------------------------
# English speech-to-text model; predict() resamples its audio to 16 kHz
# before calling it.
asr = pipeline(
    task="automatic-speech-recognition",
    model="facebook/wav2vec2-base-960h",
)
# ------------------------------------------------------
# 2. Translation Models (3 languages)
# ------------------------------------------------------
# Target language -> ISO 639-1 code. The code is spliced into both the
# Helsinki-NLP model id and the transformers pipeline task name, so the two
# tables below always stay in sync.
_TARGET_LANG_CODES = {
    "Spanish": "es",
    "Chinese": "zh",
    "Japanese": "ja",
}

# Target language -> Hugging Face model id of the English->X Marian model.
translation_models = {
    language: f"Helsinki-NLP/opus-mt-en-{code}"
    for language, code in _TARGET_LANG_CODES.items()
}

# Target language -> task string passed to transformers.pipeline().
translation_tasks = {
    language: f"translation_en_to_{code}"
    for language, code in _TARGET_LANG_CODES.items()
}
# ------------------------------------------------------
# 3. TTS Model Configurations
# NOTE: MMS does not provide a Mandarin TTS model,
# so we skip TTS for Chinese.
# ------------------------------------------------------
# Target language -> {"model_id", "architecture"}, or None when the language
# has no TTS model. predict() returns text only for None entries.
tts_config = dict(
    Spanish=dict(model_id="facebook/mms-tts-spa", architecture="vits"),  # MMS Spanish
    Chinese=None,  # No MMS TTS for Chinese
    Japanese=dict(model_id="facebook/mms-tts-jpn", architecture="vits"),  # MMS Japanese
)
# ------------------------------------------------------
# 4. Caches
# ------------------------------------------------------
# Module-level caches keyed by target-language name. Later requests for a
# language reuse the model object that is already loaded.
translator_cache = dict()  # lang -> translation pipeline
tts_model_cache = dict()  # lang -> (model, tokenizer, architecture)
# ------------------------------------------------------
# 5. Translator Helper
# ------------------------------------------------------
def get_translator(lang):
    """
    Return the English -> ``lang`` translation pipeline, building it on first use.

    The pipeline is stored in ``translator_cache`` so later calls reuse it.
    Raises KeyError if ``lang`` is missing from the translation tables.
    """
    try:
        return translator_cache[lang]
    except KeyError:
        pass  # not loaded yet; build it below

    model_id = translation_models[lang]
    translator_cache[lang] = built = pipeline(translation_tasks[lang], model=model_id)
    return built
# ------------------------------------------------------
# 6. TTS Loading Helper
# ------------------------------------------------------
def get_tts_model(lang):
    """
    Load the TTS model and tokenizer for ``lang`` once, then serve it from cache.

    Returns:
        tuple: ``(model, tokenizer, architecture)``. The same tuple is stored in
        ``tts_model_cache[lang]``.

    Raises:
        ValueError: if ``lang`` has no TTS entry (its ``tts_config`` value is
            None or the key is missing, e.g. Chinese), or if the configured
            architecture is not one this loader supports.
        RuntimeError: if loading the model or tokenizer fails. The original
            exception is chained as ``__cause__``.
    """
    if lang in tts_model_cache:
        return tts_model_cache[lang]

    config = tts_config.get(lang)
    if config is None:
        # No TTS model for this language
        raise ValueError(f"No TTS config found for language: {lang}")

    model_id = config["model_id"]
    arch = config["architecture"]

    # Only VITS checkpoints are wired up. Refuse anything else explicitly
    # rather than forcing it through VitsModel.from_pretrained.
    if arch != "vits":
        raise ValueError(f"Unsupported TTS architecture {arch!r} for language: {lang}")

    try:
        model = VitsModel.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except Exception as e:
        # Keep the root cause (network, missing repo, ...) in the traceback.
        raise RuntimeError(f"Failed to load TTS model {model_id}: {e}") from e

    tts_model_cache[lang] = (model, tokenizer, arch)
    return tts_model_cache[lang]
# ------------------------------------------------------
# 7. TTS Inference Helper
# ------------------------------------------------------
def run_tts_inference(lang, text):
    """
    Synthesize ``text`` with the (cached) TTS model for ``lang``.

    Returns:
        tuple: ``(sample_rate, waveform)``. ``waveform`` is the model's
        ``waveform`` output squeezed into a NumPy array (presumably 1-D for a
        single utterance). This is the tuple form a ``gr.Audio`` output accepts.

    Raises:
        ValueError: if ``text`` is blank, if it tokenizes to no tokens, or
            (from ``get_tts_model``) if ``lang`` has no TTS model.
        RuntimeError: if the model output has no ``waveform`` attribute, or
            (from ``get_tts_model``) if the model fails to load.
    """
    if not text.strip():
        raise ValueError("Cannot synthesize speech from empty text.")

    model, tokenizer, _arch = get_tts_model(lang)
    inputs = tokenizer(text, return_tensors="pt")

    # NOTE(review): characters outside the tokenizer vocabulary may be dropped,
    # leaving nothing to synthesize. Fail with a readable message in that case.
    input_ids = inputs.get("input_ids")
    if input_ids is not None and input_ids.shape[-1] == 0:
        raise ValueError(f"TTS tokenizer produced no tokens for the {lang} text.")

    with torch.no_grad():
        output = model(**inputs)

    # VitsModel output is typically `.waveform`
    if not hasattr(output, "waveform"):
        raise RuntimeError("TTS model output does not contain 'waveform'.")
    waveform = output.waveform.squeeze().cpu().numpy()

    # Use the rate the checkpoint declares. Fall back to 16 kHz (the usual MMS
    # TTS rate) if the config does not provide one.
    sample_rate = getattr(model.config, "sampling_rate", None) or 16000
    return (sample_rate, waveform)
# ------------------------------------------------------
# 8. Prediction Function
# ------------------------------------------------------
_ASR_SAMPLE_RATE = 16000  # rate the ASR input is resampled to


def _audio_to_asr_array(sample_rate, audio_data):
    """Convert raw microphone/upload samples into a mono float array at 16 kHz."""
    audio_data = np.asarray(audio_data)

    # Signed integer PCM (e.g. int16) is rescaled to [-1, 1]. Other non-float
    # dtypes are just cast, and float input is kept as-is.
    if np.issubdtype(audio_data.dtype, np.signedinteger):
        scale = float(np.iinfo(audio_data.dtype).max)
        audio_data = audio_data.astype(np.float32) / scale
    elif audio_data.dtype not in (np.float32, np.float64):
        audio_data = audio_data.astype(np.float32)

    # Assumes a (samples, channels) layout for 2-D input. Average every 2-D
    # array, including single-channel (samples, 1), down to 1-D mono.
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)

    if sample_rate != _ASR_SAMPLE_RATE:
        audio_data = librosa.resample(
            audio_data, orig_sr=sample_rate, target_sr=_ASR_SAMPLE_RATE
        )
    return audio_data


def _transcribe(audio):
    """Run English ASR on a Gradio ``(sample_rate, samples)`` audio tuple."""
    sample_rate, audio_data = audio
    asr_input = {
        "array": _audio_to_asr_array(sample_rate, audio_data),
        "sampling_rate": _ASR_SAMPLE_RATE,
    }
    return asr(asr_input)["text"]


def predict(audio, text, target_language):
    """
    Turn English speech or text into a translation plus (optional) speech.

    1. Obtain English text: typed ``text`` if non-blank, otherwise ASR on ``audio``.
    2. Translate English -> ``target_language``.
    3. Run VITS-based TTS for that language if ``tts_config`` has an entry for it.

    Args:
        audio: Gradio numpy audio as a ``(sample_rate, samples)`` tuple, or None.
        text: English text from the textbox. None is treated as empty.
        target_language: Key into the translation/TTS tables, e.g. "Spanish".

    Returns:
        ``(english_text, translated_text, audio_out)``, one value per output
        component. ``audio_out`` is ``(sample_rate, waveform)`` or None. Errors
        are reported as text in the first two fields, and ``audio_out`` is
        then None.
    """
    # Step 1: English text (typed text takes priority over audio)
    typed = (text or "").strip()
    if typed:
        english_text = typed
    elif audio is not None:
        try:
            english_text = _transcribe(audio)
        except Exception as e:
            return f"ASR error: {e}", "", None
    else:
        return "No input provided.", "", None

    # Step 2: Translation. Loading the pipeline can fail too (download, unknown
    # language), so it sits inside the try.
    try:
        translator = get_translator(target_language)
        translated_text = translator(english_text)[0]["translation_text"]
    except Exception as e:
        return english_text, f"Translation error: {e}", None

    # Step 3: TTS, skipped when no TTS model is configured (e.g. Chinese)
    if tts_config.get(target_language) is None:
        return english_text, translated_text, None
    try:
        speech = run_tts_inference(target_language, translated_text)
    except Exception as e:
        # The third output is an Audio component, so it may only receive audio
        # or None. A message string there would be treated as a file path, so
        # the failure is reported in the translation text instead.
        return english_text, f"{translated_text}\n\n[TTS error: {e}]", None
    return english_text, translated_text, speech
# ------------------------------------------------------
# 9. Gradio Interface
# ------------------------------------------------------
# Input components, in the order of predict()'s parameters.
_iface_inputs = [
    gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"),
    gr.Textbox(
        lines=4,
        placeholder="Or enter English text here",
        label="English Text Input (optional)",
    ),
    gr.Dropdown(
        choices=["Spanish", "Chinese", "Japanese"],
        value="Spanish",
        label="Target Language",
    ),
]

# Output components, one per element of predict()'s returned 3-tuple.
_iface_outputs = [
    gr.Textbox(label="English Transcription"),
    gr.Textbox(label="Translation (Target Language)"),
    gr.Audio(label="Synthesized Speech (if available)"),
]

_iface_description = (
    "This app:\n"
    "1. Transcribes English speech (via ASR) or accepts English text.\n"
    "2. Translates to Spanish, Chinese, or Japanese (Helsinki-NLP).\n"
    "3. Synthesizes speech with VITS-based MMS TTS models for Spanish/Japanese.\n\n"
    "Note: MMS does NOT currently provide a Mandarin TTS model, so TTS is skipped for Chinese."
)

iface = gr.Interface(
    fn=predict,
    inputs=_iface_inputs,
    outputs=_iface_outputs,
    title="Multimodal Language Learning Aid (MMS TTS / VITS)",
    description=_iface_description,
    allow_flagging="never",
)
if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    iface.launch(server_port=7860, server_name="0.0.0.0")