tdurzynski's picture
Update app.py
f4f0bbf verified
raw
history blame
6.06 kB
"""
Real-time Speech Translation Demo
This demo performs the following:
1. Accepts a 15-second audio recording from the microphone.
2. Uses OpenAI’s Whisper model to transcribe the speech.
3. Splits the transcription into segments (each roughly corresponding to a sentence).
4. Translates each segment on-the-fly using Facebook’s M2M100 model (via Hugging Face Transformers).
5. Streams the cumulative translation output to the user.
Make sure to install all dependencies from requirements.txt.
"""
import gradio as gr
import whisper
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
# -----------------------------------------------------------------------------
# Global Model Loading
# -----------------------------------------------------------------------------
# Every model is built once, when this module is imported, and then reused by
# all requests. Expect startup to take a few seconds while they load.

# Speech recognition: the Whisper "base" checkpoint is a compromise between
# speed and accuracy; pass a larger checkpoint name for better transcripts.
whisper_model = whisper.load_model("base")

# Translation: M2M100 translates directly between many language pairs. The
# tokenizer and the model are loaded from the same checkpoint so they agree.
_M2M100_CHECKPOINT = "facebook/m2m100_418M"
tokenizer = M2M100Tokenizer.from_pretrained(_M2M100_CHECKPOINT)
m2m100_model = M2M100ForConditionalGeneration.from_pretrained(_M2M100_CHECKPOINT)
# -----------------------------------------------------------------------------
# Define Supported Languages
# -----------------------------------------------------------------------------
# Target languages offered in the UI, as display name -> M2M100 language code.
# Insertion order is significant: it is the order the names appear in the
# target-language dropdown. M2M100 supports many more languages (see the
# M2M100 docs); add entries here to expose them.
LANGUAGES = dict(
    English="en",
    Spanish="es",
    French="fr",
    German="de",
    Chinese="zh",
)
# -----------------------------------------------------------------------------
# Main Processing Function
# -----------------------------------------------------------------------------
def translate_audio(audio, target_language):
"""
Process the input audio, transcribe it using Whisper, and translate each segment
to the chosen target language. Yields a cumulative translation string for streaming.
Parameters:
audio (str): Path to the recorded audio file.
target_language (str): Display name of the target language (e.g., "English").
Yields:
str: The cumulative translated text after processing each segment.
"""
if audio is None:
yield "No audio provided."
return
# Transcribe the audio file using Whisper.
# Using fp16=False to ensure compatibility on CPUs.
result = whisper_model.transcribe(audio, fp16=False)
# Extract the detected source language from the transcription result.
# (Whisper returns a language code, for example "en" for English.)
source_lang = result.get("language", "en")
# Get the target language code from our mapping; default to English if not found.
target_lang_code = LANGUAGES.get(target_language, "en")
cumulative_translation = ""
# Iterate over each segment from the transcription.
# Each segment is a dict with keys such as "start", "end", and "text".
for segment in result.get("segments", []):
# Clean up the segment text.
segment_text = segment.get("text", "").strip()
if segment_text == "":
continue
# If the source and target languages are the same, no translation is needed.
if source_lang == target_lang_code:
translated_segment = segment_text
else:
# Set the tokenizer's source language for proper translation.
tokenizer.src_lang = source_lang
# Tokenize the segment text.
encoded = tokenizer(segment_text, return_tensors="pt")
# Generate translation tokens.
# The 'forced_bos_token_id' parameter forces the model to generate text in the target language.
generated_tokens = m2m100_model.generate(
**encoded,
forced_bos_token_id=tokenizer.get_lang_id(target_lang_code)
)
# Decode the tokens to obtain the translated text.
translated_segment = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# Append the new translation segment to the cumulative output.
cumulative_translation += translated_segment + " "
# Yield the updated cumulative translation to simulate streaming output.
yield cumulative_translation.strip()
# -----------------------------------------------------------------------------
# Gradio Interface Definition
# -----------------------------------------------------------------------------
# Recording limit in seconds. It is enforced by the Audio component
# (max_length) and quoted in the UI text, so the promise and the behavior stay
# in sync.
MAX_RECORDING_SECONDS = 15

with gr.Blocks() as demo:
    gr.Markdown("# Real-time Speech Translation Demo")
    gr.Markdown(
        "Speak into the microphone and your speech will be transcribed and translated "
        f"segment-by-segment. (Recording is limited to {MAX_RECORDING_SECONDS} seconds.)"
    )
    with gr.Row():
        # Audio input: records from the microphone and passes a file path
        # (type="filepath") to translate_audio.
        audio_input = gr.Audio(
            sources=["microphone"],
            type="filepath",
            label=f"Record your speech (max {MAX_RECORDING_SECONDS} seconds)",
            # Enforce the limit that the label and description promise.
            max_length=MAX_RECORDING_SECONDS,
            elem_id="audio_input",
        )
        # Dropdown to select the target language (display names from LANGUAGES).
        target_lang_dropdown = gr.Dropdown(
            choices=list(LANGUAGES.keys()),
            value="English",
            label="Select Target Language",
        )
    # Output textbox for displaying the (streaming) translation.
    output_text = gr.Textbox(label="Translated Text", lines=10)

    # Connect the audio input and dropdown to the translation function.
    # Since translate_audio is a generator (it yields partial results), Gradio
    # streams the output.
    audio_input.change(
        fn=translate_audio,
        inputs=[audio_input, target_lang_dropdown],
        outputs=output_text,
    )

# Launch the Gradio app (suitable for Hugging Face Spaces).
demo.launch()