Yilin0601's picture
Update app.py
be4098e verified
raw
history blame
6.49 kB
import gradio as gr
import torch
import numpy as np
import librosa
from transformers import pipeline
import scipy # imported if needed for processing
# --------------------------------------------------
# ASR Pipeline (for English transcription)
# --------------------------------------------------
# Built once at import time so every request reuses the same loaded model.
asr = pipeline(
    task="automatic-speech-recognition",
    model="facebook/wav2vec2-base-960h",
)
# --------------------------------------------------
# Mapping for Target Languages and Translation Pipelines
# --------------------------------------------------
# One row per target language: (display name, model id, pipeline task name).
# Both lookup dicts below are derived from this table so that they always
# share the same keys in the same order.
_TRANSLATION_TABLE = (
    ("Spanish", "Helsinki-NLP/opus-mt-en-es", "translation_en_to_es"),
    ("French", "Helsinki-NLP/opus-mt-en-fr", "translation_en_to_fr"),
    ("German", "Helsinki-NLP/opus-mt-en-de", "translation_en_to_de"),
    ("Chinese", "Helsinki-NLP/opus-mt-en-zh", "translation_en_to_zh"),
    ("Russian", "Helsinki-NLP/opus-mt-en-ru", "translation_en_to_ru"),
    ("Arabic", "Helsinki-NLP/opus-mt-en-ar", "translation_en_to_ar"),
    ("Portuguese", "Helsinki-NLP/opus-mt-en-pt", "translation_en_to_pt"),
    ("Japanese", "Helsinki-NLP/opus-mt-en-ja", "translation_en_to_ja"),
    ("Italian", "Helsinki-NLP/opus-mt-en-it", "translation_en_to_it"),
    ("Korean", "Helsinki-NLP/opus-mt-en-ko", "translation_en_to_ko"),
)

# Display name -> translation model id.
translation_models = {
    language: model_id for language, model_id, _ in _TRANSLATION_TABLE
}

# Display name -> task string passed to pipeline().
translation_tasks = {
    language: task for language, _, task in _TRANSLATION_TABLE
}
# --------------------------------------------------
# TTS Models (using real Facebook MMS TTS & others)
# --------------------------------------------------
# Display name -> model id handed to the text-to-speech pipeline.
# NOTE(review): "che" is the ISO 639-3 code for Chechen, not a Chinese
#   variety -- confirm which checkpoint was intended for "Chinese".
# NOTE(review): the "Italian" id has three path segments, which does not
#   look like a Hub "namespace/name" repo id -- verify that it loads.
# NOTE(review): the "Japanese" entry is not an MMS checkpoint -- confirm it
#   works with the plain text-to-speech pipeline call used in this file.
tts_models = dict(
    (
        ("Spanish", "facebook/mms-tts-spa"),
        ("French", "facebook/mms-tts-fra"),
        ("German", "facebook/mms-tts-deu"),
        ("Chinese", "facebook/mms-tts-che"),
        ("Russian", "facebook/mms-tts-rus"),
        ("Arabic", "facebook/mms-tts-ara"),
        ("Portuguese", "facebook/mms-tts-por"),
        ("Japanese", "esnya/japanese_speecht5_tts"),
        ("Italian", "tts_models/it/tacotron2"),
        ("Korean", "facebook/mms-tts-kor"),
    )
)
# --------------------------------------------------
# Caches for translator and TTS pipelines
# --------------------------------------------------
# Keyed by target-language display name; filled lazily by get_translator()
# and get_tts() so each model is loaded at most once per process.
translator_cache, tts_cache = {}, {}
def get_translator(target_language):
    """
    Return the translation pipeline for ``target_language``.

    The pipeline is looked up in ``translator_cache`` first; on a miss it is
    built from the ``translation_models`` / ``translation_tasks`` entries for
    that language, stored in the cache, and returned.
    """
    try:
        return translator_cache[target_language]
    except KeyError:
        # Not built yet: fall through and create it.
        pass

    checkpoint = translation_models[target_language]
    task = translation_tasks[target_language]
    translator_cache[target_language] = pipeline(task, model=checkpoint)
    return translator_cache[target_language]
def get_tts(target_language):
    """
    Retrieve or create a TTS pipeline for the specified language.

    Pipelines are created on first use and stored in ``tts_cache``. A failed
    load is not cached, so the next call for the same language tries again.

    Args:
        target_language: Display name of the language; must be a key of
            ``tts_models``.

    Returns:
        The text-to-speech pipeline for that language.

    Raises:
        ValueError: If ``tts_models`` has no entry for ``target_language``,
            or if building the pipeline fails (the original exception is
            attached as ``__cause__``).
    """
    if target_language in tts_cache:
        return tts_cache[target_language]

    model_name = tts_models.get(target_language)
    if model_name is None:
        raise ValueError(f"No TTS model available for {target_language}.")

    try:
        tts_pipeline = pipeline("text-to-speech", model=model_name)
    except Exception as e:
        # Chain explicitly so the underlying load failure stays visible in
        # the traceback as the direct cause of this ValueError.
        raise ValueError(
            f"Failed to load TTS model for {target_language} with model '{model_name}'.\nError: {e}"
        ) from e

    tts_cache[target_language] = tts_pipeline
    return tts_pipeline
# --------------------------------------------------
# Prediction Function
# --------------------------------------------------
def _prepare_asr_audio(sample_rate, audio_data):
    """Convert a ``(sample_rate, samples)`` recording to mono float32 at 16 kHz."""
    samples = np.asarray(audio_data)
    if np.issubdtype(samples.dtype, np.integer):
        # Integer PCM: rescale into the conventional [-1, 1] float range.
        samples = samples.astype(np.float32) / np.iinfo(samples.dtype).max
    else:
        samples = samples.astype(np.float32)
    if samples.ndim > 1:
        # (samples, channels) -> (samples,). Applied to single-channel 2-D
        # input as well, so the ASR pipeline always receives a 1-D array.
        samples = samples.mean(axis=1)
    if sample_rate != 16000:
        samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000)
    return samples


def predict(audio, text, target_language):
    """
    Run the transcription -> translation -> speech synthesis chain.

    1. Obtain English text (from text input or ASR).
    2. Translate English -> target_language.
    3. Synthesize speech in target_language.

    Args:
        audio: ``(sample_rate, samples)`` tuple as delivered by the
            ``type="numpy"`` audio input, or ``None`` when no audio is given.
        text: English text typed by the user; empty, whitespace-only, or
            ``None`` means "use the audio instead".
        target_language: Display name of the target language (a key of
            ``translation_models``).

    Returns:
        A 3-tuple ``(english_text, translated_text, audio_out)`` matching the
        three interface outputs. ``audio_out`` is ``(sampling_rate, waveform)``
        on success and ``None`` whenever a step fails or no input was given;
        failures are reported as text in the second element.
    """
    # Step 1: Get English text from text input (if provided) or from ASR.
    typed_text = (text or "").strip()  # tolerate a missing (None) text value
    if typed_text:
        english_text = typed_text
    elif audio is not None:
        sample_rate, audio_data = audio
        waveform = _prepare_asr_audio(sample_rate, audio_data)
        asr_result = asr({"array": waveform, "sampling_rate": 16000})
        english_text = asr_result["text"]
    else:
        return "No input provided.", "", None

    # Step 2: Translation. Loading the model is inside the try so that a
    # load failure is reported the same way as a failure while translating.
    try:
        translator = get_translator(target_language)
        translation_result = translator(english_text)
        translated_text = translation_result[0]["translation_text"]
    except Exception as e:
        return english_text, f"Translation error: {e}", None

    # Step 3: TTS synthesis using Facebook MMS TTS (or alternative) pipeline.
    try:
        tts_pipeline = get_tts(target_language)
        tts_result = tts_pipeline(translated_text)
        # The transformers text-to-speech pipeline returns a dict with
        # "audio" and "sampling_rate". Squeeze away a leading batch axis so
        # the audio output receives a 1-D mono waveform.
        speech = np.squeeze(np.asarray(tts_result["audio"]))
        synthesized_audio = (tts_result["sampling_rate"], speech)
    except Exception as e:
        # The third output is an audio component and cannot display text, so
        # the error is appended to the translation and no audio is returned.
        return english_text, f"{translated_text}\n\n[TTS error: {e}]", None

    return english_text, translated_text, synthesized_audio
# --------------------------------------------------
# Gradio Interface Setup
# --------------------------------------------------
# predict() takes (audio, text, target language) and fills the three outputs
# in order: English transcription, translation, synthesized speech.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Audio(
            label="Record/Upload English Audio (optional)",
            type="numpy",
        ),
        gr.Textbox(
            label="English Text Input (optional)",
            placeholder="Or enter English text here",
            lines=4,
        ),
        gr.Dropdown(
            label="Target Language",
            choices=list(translation_models),
            value="Spanish",
        ),
    ],
    outputs=[
        gr.Textbox(label="English Transcription"),
        gr.Textbox(label="Translation (Target Language)"),
        gr.Audio(label="Synthesized Speech in Target Language"),
    ],
    title="Multimodal Language Learning Aid",
    description="\n".join(
        [
            "This app provides three outputs:",
            "1. English transcription (from ASR or text input),",
            "2. Translation to a target language (using Helsinki-NLP models), and",
            "3. Synthetic speech in the target language (using Facebook MMS TTS or equivalent).",
            "",
            "Select one of the top 10 commonly used languages from the dropdown.",
            "Either record/upload an English audio sample or enter English text directly.",
        ]
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()