import gradio as gr
import torch
import numpy as np
import librosa
from transformers import pipeline, VitsModel, AutoTokenizer
# ------------------------------------------------------
# 1. ASR Pipeline (English) using Wav2Vec2
# ------------------------------------------------------
asr = pipeline(
    "automatic-speech-recognition",
    model="facebook/wav2vec2-base-960h"
)
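# The pipeline accepts raw numpy audio as a dict, e.g.
#   asr({"array": samples, "sampling_rate": 16000})  # -> {"text": "HELLO WORLD"}
# Note that wav2vec2-base-960h emits uppercase, unpunctuated transcripts.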
# ------------------------------------------------------
# 2. Translation Models (3 languages)
# ------------------------------------------------------
translation_models = {
    "Spanish": "Helsinki-NLP/opus-mt-en-es",
    "Chinese": "Helsinki-NLP/opus-mt-en-zh",
    "Japanese": "Helsinki-NLP/opus-mt-en-ja"
}
translation_tasks = {
    "Spanish": "translation_en_to_es",
    "Chinese": "translation_en_to_zh",
    "Japanese": "translation_en_to_ja"
}
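# Each task/model pair builds a standard translation pipeline, e.g.
#   pipeline("translation_en_to_es", model="Helsinki-NLP/opus-mt-en-es")
# which returns a list like [{"translation_text": "..."}].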
# ------------------------------------------------------
# 3. TTS Model Configurations
# - Spanish: facebook/mms-tts-spa
# - Chinese: myshell-ai/MeloTTS-Chinese
# - Japanese: myshell-ai/MeloTTS-Japanese
# ------------------------------------------------------
tts_config = {
    "Spanish": {
        "model_id": "facebook/mms-tts-spa",
        "architecture": "vits"
    },
    "Chinese": {
        "model_id": "myshell-ai/MeloTTS-Chinese",
        "architecture": "vits"
    },
    "Japanese": {
        "model_id": "myshell-ai/MeloTTS-Japanese",
        "architecture": "vits"
    }
}
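# Note (assumption): the facebook/mms-tts-* checkpoints are native
# transformers VITS models, but the MeloTTS checkpoints may not load via
# VitsModel.from_pretrained; if so, get_tts_model() below reports the
# failure as a RuntimeError instead of crashing at import time.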
# ------------------------------------------------------
# 4. Caches
# ------------------------------------------------------
translator_cache = {}
tts_model_cache = {} # store (model, tokenizer, architecture)
# ------------------------------------------------------
# 5. Translator Helper
# ------------------------------------------------------
def get_translator(lang):
    if lang in translator_cache:
        return translator_cache[lang]
    model_name = translation_models[lang]
    task_name = translation_tasks[lang]
    translator = pipeline(task_name, model=model_name)
    translator_cache[lang] = translator
    return translator
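# Example (illustrative output):
#   get_translator("Spanish")("Good morning")[0]["translation_text"]  # e.g. "Buenos días"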
# ------------------------------------------------------
# 6. TTS Loading Helper
# ------------------------------------------------------
def get_tts_model(lang):
    """
    Loads (model, tokenizer, architecture) from Hugging Face once, then caches.
    """
    if lang in tts_model_cache:
        return tts_model_cache[lang]
    config = tts_config.get(lang)
    if config is None:
        raise ValueError(f"No TTS config found for language: {lang}")
    model_id = config["model_id"]
    arch = config["architecture"]
    try:
        # Attempt VITS-based loading
        model = VitsModel.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except Exception as e:
        raise RuntimeError(f"Failed to load TTS model {model_id}: {e}")
    tts_model_cache[lang] = (model, tokenizer, arch)
    return tts_model_cache[lang]
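# Subsequent calls for the same language reuse the cached (model, tokenizer,
# architecture) triple, so each checkpoint is downloaded and loaded only once.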
# ------------------------------------------------------
# 7. TTS Inference Helper
# ------------------------------------------------------
def run_tts_inference(lang, text):
    """
    Generates a waveform with the loaded TTS model and tokenizer.
    Returns (sample_rate, np_array).
    """
    model, tokenizer, arch = get_tts_model(lang)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        output = model(**inputs)
    # VitsModel output exposes the audio as `.waveform`
    if not hasattr(output, "waveform"):
        raise RuntimeError("TTS model output does not contain 'waveform' attribute.")
    waveform = output.waveform.squeeze().cpu().numpy()
    # Read the sampling rate from the model config; MMS VITS models use 16 kHz
    sample_rate = getattr(model.config, "sampling_rate", 16000)
    return (sample_rate, waveform)
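# Example (illustrative):
#   sr, wav = run_tts_inference("Spanish", "Hola, ¿cómo estás?")
# `wav` is a 1-D float numpy array, matching gr.Audio's (sample_rate, array) format.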
# ------------------------------------------------------
# 8. Prediction Function
# ------------------------------------------------------
def predict(audio, text, target_language):
    """
    1. Obtain English text (ASR with Wav2Vec2, or direct text input).
    2. Translate English -> target_language.
    3. Synthesize speech in that language with the configured TTS model.
    """
    # Step 1: English text
    if text.strip():
        english_text = text.strip()
    elif audio is not None:
        sample_rate, audio_data = audio
        # Gradio delivers integer PCM (int16 by default); normalize to float32 in [-1, 1]
        if np.issubdtype(audio_data.dtype, np.integer):
            max_val = np.iinfo(audio_data.dtype).max
            audio_data = audio_data.astype(np.float32) / max_val
        elif audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
        # Stereo -> mono if needed
        if audio_data.ndim > 1 and audio_data.shape[1] > 1:
            audio_data = np.mean(audio_data, axis=1)
        # Resample to 16 kHz if needed (Wav2Vec2 expects 16 kHz input)
        if sample_rate != 16000:
            audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
        asr_input = {"array": audio_data, "sampling_rate": 16000}
        asr_result = asr(asr_input)
        english_text = asr_result["text"]
    else:
        return "No input provided.", "", None
    # Step 2: Translate
    translator = get_translator(target_language)
    try:
        translation_result = translator(english_text)
        translated_text = translation_result[0]["translation_text"]
    except Exception as e:
        return english_text, f"Translation error: {e}", None
    # Step 3: TTS
    try:
        sample_rate, waveform = run_tts_inference(target_language, translated_text)
    except Exception as e:
        # gr.Audio cannot render a string, so report the error in the text output
        return english_text, f"{translated_text}\n[TTS error: {e}]", None
    return english_text, translated_text, (sample_rate, waveform)
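# predict() always returns (transcription, translation, audio); the audio slot
# is None on any failure so each Gradio output component receives a valid value.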
# ------------------------------------------------------
# 9. Gradio Interface
# ------------------------------------------------------
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"),
        gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)"),
        gr.Dropdown(choices=["Spanish", "Chinese", "Japanese"], value="Spanish", label="Target Language")
    ],
    outputs=[
        gr.Textbox(label="English Transcription"),
        gr.Textbox(label="Translation (Target Language)"),
        gr.Audio(label="Synthesized Speech")
    ],
    title="Multimodal Language Learning Aid",
    description=(
        "1. Transcribes English speech using Wav2Vec2 (or takes English text).\n"
        "2. Translates to Spanish, Chinese, or Japanese (Helsinki-NLP models).\n"
        "3. Provides synthetic speech with TTS models.\n"
    ),
    allow_flagging="never"
)
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)