# Source: Hugging Face Space file view of app.py (raw / history / blame), 8.28 kB.
# Last commit: cc8959c (verified), "Update app.py", shown with Yilin0601's picture.
import gradio as gr
import torch
import numpy as np
import librosa
import soundfile as sf # likely needed by the pipeline or local saving
from transformers import pipeline, VitsModel, AutoTokenizer
from datasets import load_dataset
# ------------------------------------------------------
# 1. ASR Pipeline (English) - Wav2Vec2
# ------------------------------------------------------
# English speech recognizer, constructed once when the module is imported so
# every request reuses the same pipeline object.
ASR_MODEL_ID = "facebook/wav2vec2-base-960h"
asr = pipeline(task="automatic-speech-recognition", model=ASR_MODEL_ID)
# ------------------------------------------------------
# 2. Translation Models (3 languages)
# ------------------------------------------------------
# One row per supported target language: (display name, model id, pipeline task).
# Both lookup dicts below are derived from this table so they cannot drift apart.
# NOTE(review): confirm each model id below resolves on the Hugging Face Hub.
_TRANSLATION_TABLE = (
    ("Spanish", "Helsinki-NLP/opus-mt-en-es", "translation_en_to_es"),
    ("Chinese", "Helsinki-NLP/opus-mt-en-zh", "translation_en_to_zh"),
    ("Japanese", "Helsinki-NLP/opus-mt-en-ja", "translation_en_to_ja"),
)

# Language name -> translation model id.
translation_models = {name: model_id for name, model_id, _ in _TRANSLATION_TABLE}

# Language name -> Transformers pipeline task string.
translation_tasks = {name: task for name, _, task in _TRANSLATION_TABLE}
# ------------------------------------------------------
# 3. TTS Configuration
# - Spanish: VITS-based MMS TTS
# - Chinese & Japanese: Microsoft SpeechT5
# ------------------------------------------------------
# We'll store them as keys for convenience
# Display names of the target languages, reused as lookup keys.
SPANISH_KEY, CHINESE_KEY, JAPANESE_KEY = "Spanish", "Chinese", "Japanese"

# Settings for the Spanish-only VITS text-to-speech model.
mms_spanish_config = dict(
    model_id="facebook/mms-tts-spa",
    architecture="vits",
)
# ------------------------------------------------------
# 4. Create TTS Pipelines / Models Once (Caching)
# ------------------------------------------------------
# Translation pipelines, keyed by target-language name.
translator_cache = dict()

# Module-level slots for the TTS objects; None means "not loaded yet".
#   vits_model_cache            -> for Spanish
#   speech_t5_pipeline_cache    -> for Chinese/Japanese
#   speech_t5_speaker_embedding -> speaker embedding used with SpeechT5
vits_model_cache = speech_t5_pipeline_cache = speech_t5_speaker_embedding = None
def get_translator(lang):
    """
    Fetch the translation pipeline for ``lang``, building it on first use.

    The pipeline is created from the ``translation_tasks`` / ``translation_models``
    entries for ``lang`` and stored in ``translator_cache`` so later calls reuse it.
    A ``lang`` missing from ``translation_models`` raises ``KeyError``.
    """
    try:
        return translator_cache[lang]
    except KeyError:
        pass  # Not built yet; fall through and create it.

    # Look up the model first, then the task, then build and remember the pipeline.
    checkpoint = translation_models[lang]
    task = translation_tasks[lang]
    built = pipeline(task, model=checkpoint)
    translator_cache[lang] = built
    return built
def load_spanish_vits():
    """
    Load and cache the Spanish VITS model + tokenizer (facebook/mms-tts-spa).

    Returns:
        The ``(model, tokenizer)`` tuple stored in ``vits_model_cache``. The
        first successful call loads both; later calls return the cached tuple.

    Raises:
        RuntimeError: if either the model or the tokenizer fails to load. The
            original exception is attached as ``__cause__``.
    """
    global vits_model_cache
    if vits_model_cache is not None:
        return vits_model_cache

    model_id = mms_spanish_config["model_id"]
    try:
        model = VitsModel.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except Exception as e:
        # Chain the cause so the underlying traceback is not lost.
        raise RuntimeError(f"Failed to load Spanish TTS model {model_id}: {e}") from e

    # Only cache once both pieces loaded, so a failed attempt is retried next call.
    vits_model_cache = (model, tokenizer)
    return vits_model_cache
def load_speech_t5_pipeline():
    """
    Load and cache the Microsoft SpeechT5 text-to-speech pipeline
    and a default speaker embedding.

    Each piece is cached as soon as it loads, so if the speaker-embedding
    download fails the already-loaded pipeline is kept and only the embedding
    is retried on the next call.

    Returns:
        ``(pipeline, speaker_embedding)`` as stored in
        ``speech_t5_pipeline_cache`` and ``speech_t5_speaker_embedding``.

    Raises:
        RuntimeError: if the pipeline or the speaker embedding cannot be
            loaded. The original exception is attached as ``__cause__``.
    """
    global speech_t5_pipeline_cache, speech_t5_speaker_embedding

    if speech_t5_pipeline_cache is None:
        try:
            # The pipeline is named "text-to-speech" in Transformers >= 4.29
            speech_t5_pipeline_cache = pipeline("text-to-speech", model="microsoft/speecht5_tts")
        except Exception as e:
            raise RuntimeError(f"Failed to load Microsoft SpeechT5 pipeline: {e}") from e

    if speech_t5_speaker_embedding is None:
        try:
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            # Just pick an arbitrary index for speaker embedding; unsqueeze adds
            # a leading dimension of size 1.
            speech_t5_speaker_embedding = torch.tensor(
                embeddings_dataset[7306]["xvector"]
            ).unsqueeze(0)
        except Exception as e:
            raise RuntimeError(f"Failed to load default speaker embedding: {e}") from e

    return speech_t5_pipeline_cache, speech_t5_speaker_embedding
# ------------------------------------------------------
# 5. TTS Inference Helpers
# ------------------------------------------------------
def run_vits_inference(text):
    """
    For Spanish TTS using MMS (facebook/mms-tts-spa).

    Args:
        text: the text to synthesize.

    Returns:
        ``(sample_rate, waveform)`` where ``waveform`` is a NumPy array with
        size-1 dimensions squeezed out.

    Raises:
        RuntimeError: if the model output has no ``waveform`` attribute, or
            (from ``load_spanish_vits``) if the model cannot be loaded.
    """
    model, tokenizer = load_spanish_vits()
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        output = model(**inputs)
    if not hasattr(output, "waveform"):
        raise RuntimeError("VITS output does not contain 'waveform'.")
    waveform = output.waveform.squeeze().cpu().numpy()
    # Read the rate from the loaded model's config rather than hard-coding it,
    # so a different checkpoint is not labelled with the wrong rate; 16000 is
    # kept as the fallback the previous code assumed.
    sample_rate = getattr(model.config, "sampling_rate", 16000)
    return sample_rate, waveform
def run_speecht5_inference(text):
    """
    Synthesize ``text`` with the cached Microsoft SpeechT5 pipeline
    (the path used for Chinese and Japanese).

    Returns ``(sample_rate, waveform)`` taken from the pipeline result's
    'sampling_rate' and 'audio' entries.
    """
    tts, xvector = load_speech_t5_pipeline()
    # The speaker embedding is forwarded to the model through forward_params.
    synthesized = tts(text, forward_params={"speaker_embeddings": xvector})
    audio, rate = synthesized["audio"], synthesized["sampling_rate"]
    return rate, audio
# ------------------------------------------------------
# 6. Main Prediction Function
# ------------------------------------------------------
def _prepare_asr_audio(sample_rate, audio_data):
    """Turn a recording into the 16 kHz mono float samples passed to the ASR pipeline."""
    samples = np.asarray(audio_data)
    source_dtype = samples.dtype

    # Convert to float32 if needed. Integer PCM is additionally scaled into
    # [-1.0, 1.0] so the amplitude range matches float input.
    if source_dtype not in (np.float32, np.float64):
        samples = samples.astype(np.float32)
        if np.issubdtype(source_dtype, np.integer):
            limits = np.iinfo(source_dtype)
            samples = samples / np.float32(max(abs(limits.min), limits.max))

    # Multi-channel -> mono. Any 2-D input is reduced, including a single
    # channel stored as shape (N, 1), so a 1-D array always comes out.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)

    # Resample to 16k if needed
    if sample_rate != 16000:
        samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000)
    return samples


def predict(audio, text, target_language):
    """
    1. Get English text (typed text wins; otherwise ASR on the audio).
    2. Translate to target_language.
    3. TTS with the chosen approach (VITS for Spanish, SpeechT5 for Chinese/Japanese).

    Args:
        audio: ``(sample_rate, samples)`` tuple, or None.
        text: English text, may be empty or None.
        target_language: key into the translation tables (e.g. "Spanish").

    Returns:
        ``(english_text, translated_text, audio_out)`` where ``audio_out`` is
        ``(sample_rate, waveform)`` on success and None otherwise. Translation
        and TTS failures are reported inside ``translated_text``.
    """
    # Step 1: English text. `text` may be None, so normalise before stripping.
    typed_text = (text or "").strip()
    if typed_text:
        english_text = typed_text
    elif audio is not None:
        sample_rate, audio_data = audio
        samples = _prepare_asr_audio(sample_rate, audio_data)
        asr_result = asr({"array": samples, "sampling_rate": 16000})
        english_text = asr_result["text"]
    else:
        return "No input provided.", "", None

    # Step 2: Translate. Loading the translator is inside the try so a load
    # failure is reported the same way as a translation failure.
    try:
        translator = get_translator(target_language)
        translated_text = translator(english_text)[0]["translation_text"]
    except Exception as e:
        return english_text, f"Translation error: {e}", None

    # Step 3: TTS
    try:
        if target_language == SPANISH_KEY:
            sr, waveform = run_vits_inference(translated_text)
        else:
            # Chinese or Japanese -> SpeechT5
            sr, waveform = run_speecht5_inference(translated_text)
    except Exception as e:
        # The third output is an audio component; a plain string there would be
        # treated as a file path. Report the error in the text output instead
        # and leave the audio empty.
        return english_text, f"{translated_text}\n\n[TTS error: {e}]", None

    return english_text, translated_text, (sr, waveform)
# ------------------------------------------------------
# 7. Gradio Interface
# ------------------------------------------------------
# Input widgets, in the order predict() receives them: audio, text, language.
_INPUT_COMPONENTS = [
    gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"),
    gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)"),
    gr.Dropdown(choices=["Spanish", "Chinese", "Japanese"], value="Spanish", label="Target Language"),
]

# Output widgets, in the order predict() returns them.
_OUTPUT_COMPONENTS = [
    gr.Textbox(label="English Transcription"),
    gr.Textbox(label="Translation (Target Language)"),
    gr.Audio(label="Synthesized Speech"),
]

# Help text shown in the UI.
_DESCRIPTION = (
    "1. Transcribes English speech using Wav2Vec2 (or takes English text).\n"
    "2. Translates to Spanish, Chinese, or Japanese (via Helsinki-NLP models).\n"
    "3. Synthesizes speech:\n"
    " - Spanish -> facebook/mms-tts-spa (VITS)\n"
    " - Chinese & Japanese -> microsoft/speecht5_tts (SpeechT5)\n\n"
    "Note: SpeechT5 is not officially trained for Japanese, so results may vary.\n"
    "You can also try inputting short, clear audio for best ASR results."
)

iface = gr.Interface(
    fn=predict,
    inputs=_INPUT_COMPONENTS,
    outputs=_OUTPUT_COMPONENTS,
    title="Multimodal Language Learning Aid",
    description=_DESCRIPTION,
    allow_flagging="never",
)

# Only start the server when the file is run as a script, not when imported.
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)