"""
Real-time Speech Translation Demo with Restart Option

This demo performs the following:
  1. Accepts up to 15 seconds of audio recording from the microphone.
  2. Uses OpenAI’s Whisper model to transcribe the speech.
  3. Splits the transcription into segments and translates each segment
     on-the-fly using Facebook’s M2M100 model.
  4. Streams the cumulative translation output to the user.
  5. Provides a "Restart Recording" button that resets the audio input and
     translation output.

Make sure to install all dependencies from requirements.txt.
"""

import gradio as gr
import whisper
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

# -----------------------------------------------------------------------------
# Global Model Loading
# -----------------------------------------------------------------------------
# Load the Whisper model (using the "base" model for a balance between speed and accuracy).
whisper_model = whisper.load_model("base")  # Change model size as needed

# Load the M2M100 model and tokenizer for translation.
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
m2m100_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
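
# Note: everything runs on CPU here, matching fp16=False in transcribe() below.
# To try a GPU instead (assumes a CUDA device with memory for both models), one
# could load Whisper with whisper.load_model("base", device="cuda"), move the
# translator with m2m100_model.to("cuda"), and send the tokenized inputs to the
# same device with encoded.to("cuda") before calling generate().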

# -----------------------------------------------------------------------------
# Define Supported Languages
# -----------------------------------------------------------------------------
# Display names mapped to ISO 639-1 codes (shared by Whisper and M2M100).
LANGUAGES = {
    "English": "en",
    "Spanish": "es",
    "French": "fr",
    "German": "de",
    "Chinese": "zh",
    "Polish": "pl"
}
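
# Whisper can detect languages that M2M100 was not trained on, and setting
# tokenizer.src_lang to such a code raises a KeyError deeper in the tokenizer.
# This small guard is an addition to the original demo; falling back to
# English is an assumption of this sketch.
def resolve_source_lang(detected: str) -> str:
    """Return the detected language code if M2M100 supports it, else 'en'."""
    try:
        tokenizer.get_lang_id(detected)
        return detected
    except KeyError:
        return "en"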

# -----------------------------------------------------------------------------
# Main Processing Function
# -----------------------------------------------------------------------------
def translate_audio(audio, target_language):
    """
    Process the input audio, transcribe it using Whisper, and translate each segment
    to the chosen target language. Yields cumulative translation output for streaming.
    """
    if audio is None:
        yield "No audio provided."
        return

    # Transcribe the audio using Whisper (fp16=False for CPU compatibility).
    result = whisper_model.transcribe(audio, fp16=False)
    source_lang = resolve_source_lang(result.get("language", "en"))
    target_lang_code = LANGUAGES.get(target_language, "en")

    cumulative_translation = ""
    for segment in result.get("segments", []):
        segment_text = segment.get("text", "").strip()
        if segment_text == "":
            continue

        if source_lang == target_lang_code:
            translated_segment = segment_text
        else:
            # Tell the tokenizer the source language so it encodes the segment
            # with the right language prefix token.
            tokenizer.src_lang = source_lang
            encoded = tokenizer(segment_text, return_tensors="pt")
            # Inference only: skip gradient tracking to save memory.
            with torch.no_grad():
                generated_tokens = m2m100_model.generate(
                    **encoded,
                    forced_bos_token_id=tokenizer.get_lang_id(target_lang_code)
                )
            translated_segment = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

        cumulative_translation += translated_segment + " "
        yield cumulative_translation.strip()
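
# Quick offline check of the pipeline, outside the UI (a hypothetical example;
# assumes a local recording named sample.wav exists):
#   for partial in translate_audio("sample.wav", "Spanish"):
#       print(partial)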

# -----------------------------------------------------------------------------
# Restart Function
# -----------------------------------------------------------------------------
def restart_recording():
    """
    Reset the recording section by clearing the audio input and the translation output.
    Returns:
      - None for the audio input (clearing it)
      - An empty string for the translation textbox.
    """
    return None, ""

# -----------------------------------------------------------------------------
# Gradio Interface Definition
# -----------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Real-time Speech Translation Demo")
    gr.Markdown(
        "Speak into the microphone and your speech will be transcribed and translated "
        "segment-by-segment. (Recording is limited to 15 seconds.)"
    )
    
    with gr.Row():
        # 'sources' takes a list; here the microphone is the only input source.
        # max_length (in seconds) rejects recordings longer than 15 seconds,
        # backing the cap promised in the label (recent Gradio 4.x releases).
        audio_input = gr.Audio(
            sources=["microphone"],
            type="filepath",
            max_length=15,
            label="Record your speech (max 15 seconds)",
            elem_id="audio_input"
        )
        target_lang_dropdown = gr.Dropdown(
            choices=list(LANGUAGES.keys()),
            value="English",
            label="Select Target Language"
        )
    
    # Output textbox for displaying the (streaming) translation.
    output_text = gr.Textbox(label="Translated Text", lines=10)
    
    # Restart button to clear the current recording and translation.
    restart_button = gr.Button("Restart Recording")
    
    # When the user records new audio, stream the translation. Using .input
    # instead of .change means the event fires only on user edits, so clearing
    # the audio programmatically via the restart button does not re-trigger
    # translate_audio with None (which would print "No audio provided.").
    audio_input.input(
        fn=translate_audio,
        inputs=[audio_input, target_lang_dropdown],
        outputs=output_text
    )
    
    # When the restart button is clicked, clear both the audio input and translation output.
    restart_button.click(
        fn=restart_recording,
        inputs=[],
        outputs=[audio_input, output_text]
    )

# Launch the Gradio app (suitable for Hugging Face Spaces). The streaming
# generator output runs through Gradio's queue, which Gradio 4.x enables by
# default; on Gradio 3.x you would call demo.queue() before launch().
demo.launch()