import gradio as gr
import torch
import numpy as np
import librosa
import soundfile as sf
import tempfile
import os
from transformers import (
pipeline,
VitsModel,
AutoTokenizer
)
# For Coqui TTS
try:
from TTS.api import TTS as CoquiTTS
except ImportError:
raise ImportError("Please install Coqui TTS via `pip install TTS`.")
# ------------------------------------------------------
# 1. ASR Pipeline (English) using Wav2Vec2
# ------------------------------------------------------
asr = pipeline(
"automatic-speech-recognition",
model="facebook/wav2vec2-base-960h"
)
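# Input sketch (assumption: called directly for debugging) — the HF ASR pipeline
# accepts a dict carrying a float waveform plus its sampling rate:
#   asr({"array": np.zeros(16000, dtype=np.float32), "sampling_rate": 16000})["text"]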
# ------------------------------------------------------
# 2. Translation Models (3 languages)
# ------------------------------------------------------
translation_models = {
"Spanish": "Helsinki-NLP/opus-mt-en-es",
"Chinese": "Helsinki-NLP/opus-mt-en-zh",
"Japanese": "Helsinki-NLP/opus-mt-en-ja"
}
translation_tasks = {
"Spanish": "translation_en_to_es",
"Chinese": "translation_en_to_zh",
"Japanese": "translation_en_to_ja"
}
# ------------------------------------------------------
# 3. TTS Config:
# - Spanish: MMS TTS (facebook/mms-tts-spa)
# - Chinese, Japanese: Coqui XTTS-v2 (tts_models/multilingual/multi-dataset/xtts_v2)
# ------------------------------------------------------
SPANISH = "Spanish"
CHINESE = "Chinese"
JAPANESE = "Japanese"
# For Spanish (MMS)
mms_spanish_config = {
"model_id": "facebook/mms-tts-spa",
"architecture": "vits"
}
# We'll map Chinese/Japanese to Coqui language codes
coqui_lang_map = {
CHINESE: "zh",
JAPANESE: "ja"
}
# ------------------------------------------------------
# 4. Global Caches
# ------------------------------------------------------
translator_cache = {}
spanish_vits_cache = None
coqui_tts_cache = None
def get_translator(lang):
"""
Return a cached MarianMT translator for the specified language.
"""
if lang in translator_cache:
return translator_cache[lang]
model_name = translation_models[lang]
task_name = translation_tasks[lang]
translator = pipeline(task_name, model=model_name)
translator_cache[lang] = translator
return translator
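# Usage sketch (translation wording is illustrative, not guaranteed):
#   get_translator(SPANISH)("Good morning")[0]["translation_text"]  # e.g. "Buenos días"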
# ------------------------------------------------------
# 5. Spanish TTS: MMS (VITS)
# ------------------------------------------------------
def load_spanish_vits():
"""
Load and cache the Spanish MMS TTS model (VITS).
"""
global spanish_vits_cache
if spanish_vits_cache is not None:
return spanish_vits_cache
try:
model = VitsModel.from_pretrained(mms_spanish_config["model_id"])
tokenizer = AutoTokenizer.from_pretrained(mms_spanish_config["model_id"])
spanish_vits_cache = (model, tokenizer)
except Exception as e:
raise RuntimeError(f"Failed to load Spanish TTS model {mms_spanish_config['model_id']}: {e}")
return spanish_vits_cache
def run_spanish_tts(text):
"""
Run MMS TTS (VITS) for Spanish text.
Returns (sample_rate, waveform).
"""
model, tokenizer = load_spanish_vits()
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
output = model(**inputs)
if not hasattr(output, "waveform"):
raise RuntimeError("Spanish TTS model output does not contain 'waveform'.")
waveform = output.waveform.squeeze().cpu().numpy()
    # MMS VITS models expose their output rate in the config (16 kHz for mms-tts-spa).
    sample_rate = getattr(model.config, "sampling_rate", 16000)
return sample_rate, waveform
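# Debugging sketch (assumes a plain Python session rather than the Gradio app):
#   sr, wav = run_spanish_tts("Hola, ¿cómo estás?")
#   sf.write("demo_es.wav", wav, sr)  # demo_es.wav is a hypothetical output path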
# ------------------------------------------------------
# 6. Chinese/Japanese TTS: Coqui XTTS-v2
# ------------------------------------------------------
def load_coqui_tts():
"""
Load and cache the Coqui XTTS-v2 model (multilingual).
"""
global coqui_tts_cache
if coqui_tts_cache is not None:
return coqui_tts_cache
try:
# If you have a GPU on HF Spaces, you can set gpu=True.
# If not, set gpu=False to run on CPU (slower).
coqui_tts_cache = CoquiTTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=False)
except Exception as e:
raise RuntimeError("Failed to load Coqui XTTS-v2 TTS: %s" % e)
return coqui_tts_cache
def run_coqui_tts(text, lang):
"""
Run Coqui TTS for Chinese or Japanese text.
We specify the language code from coqui_lang_map.
Returns (sample_rate, waveform).
"""
coqui_tts = load_coqui_tts()
lang_code = coqui_lang_map[lang] # "zh" or "ja"
# We must output to a file, then read it back.
# Use a temporary file to store the wave.
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp_name = tmp.name
    try:
        # NOTE: XTTS-v2 is multi-speaker. Depending on the Coqui TTS version,
        # this call may require either `speaker_wav=<reference.wav>` (voice
        # cloning) or a built-in voice via `speaker=` (names are listed in
        # `coqui_tts.speakers`); add one if you hit a "no speaker" error.
        coqui_tts.tts_to_file(
            text=text,
            file_path=tmp_name,
            language=lang_code
        )
data, sr = sf.read(tmp_name)
finally:
# Cleanup the temporary file
if os.path.exists(tmp_name):
os.remove(tmp_name)
return sr, data
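# Debugging sketch (assumes the XTTS-v2 weights have already been downloaded):
#   sr, wav = run_coqui_tts("你好，很高兴认识你。", CHINESE)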
# ------------------------------------------------------
# 7. Main Prediction Function
# ------------------------------------------------------
def predict(audio, text, target_language):
"""
1. Get English text (ASR if audio provided, else text).
2. Translate to target_language.
3. TTS with the chosen approach:
- Spanish -> MMS TTS (VITS)
- Chinese/Japanese -> Coqui XTTS-v2
"""
# Step 1: English text
    if text and text.strip():
english_text = text.strip()
elif audio is not None:
sample_rate, audio_data = audio
        # Gradio typically delivers int16 PCM; normalize to float32 in [-1, 1]
        # because Wav2Vec2 expects unit-scale float waveforms.
        if np.issubdtype(audio_data.dtype, np.integer):
            audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
        elif audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
        # Downmix stereo to mono
        if audio_data.ndim > 1 and audio_data.shape[1] > 1:
            audio_data = np.mean(audio_data, axis=1)
# Resample to 16k if needed
if sample_rate != 16000:
audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
asr_input = {"array": audio_data, "sampling_rate": 16000}
asr_result = asr(asr_input)
english_text = asr_result["text"]
else:
return "No input provided.", "", None
# Step 2: Translate
translator = get_translator(target_language)
try:
translation_result = translator(english_text)
translated_text = translation_result[0]["translation_text"]
except Exception as e:
return english_text, f"Translation error: {e}", None
# Step 3: TTS
try:
if target_language == SPANISH:
sr, waveform = run_spanish_tts(translated_text)
else:
# Chinese or Japanese
sr, waveform = run_coqui_tts(translated_text, target_language)
except Exception as e:
return english_text, translated_text, f"TTS error: {e}"
return english_text, translated_text, (sr, waveform)
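# End-to-end sketch without the UI (text-only path, so audio=None):
#   eng, trans, speech = predict(None, "Where is the train station?", JAPANESE)
#   # speech is (sample_rate, waveform) on success, or an error string on failure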
# ------------------------------------------------------
# 8. Gradio Interface
# ------------------------------------------------------
iface = gr.Interface(
fn=predict,
inputs=[
gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"),
gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)"),
gr.Dropdown(choices=[SPANISH, CHINESE, JAPANESE], value=SPANISH, label="Target Language")
],
outputs=[
gr.Textbox(label="English Transcription"),
gr.Textbox(label="Translation (Target Language)"),
gr.Audio(label="Synthesized Speech")
],
title="Multimodal Language Learning Aid",
description=(
"1. Transcribes English speech using Wav2Vec2 (or takes English text).\n"
"2. Translates to Spanish, Chinese, or Japanese (via Helsinki-NLP).\n"
"3. Synthesizes speech:\n"
" - Spanish -> facebook/mms-tts-spa (VITS)\n"
" - Chinese & Japanese -> Coqui XTTS-v2 (multilingual TTS)\n\n"
"Note: The Coqui model is 'tts_models/multilingual/multi-dataset/xtts_v2' and expects language codes.\n"
"If you need voice cloning, set `speaker_wav` in `tts_to_file()`. By default, it uses a single generic voice."
),
allow_flagging="never"
)
if __name__ == "__main__":
iface.launch(server_name="0.0.0.0", server_port=7860)