import gradio as gr
import numpy as np
import torch
import spaces
import nemo.collections.asr as nemo_asr
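
# Runtime dependencies implied by the imports above (pin them in the Space's
# requirements.txt as appropriate): gradio, numpy, torch, spaces, nemo_toolkit[asr].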

# Check if CUDA is available
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

# Initialize the ASR model
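# (On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of the decorated call; elsewhere it is a no-op.)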
@spaces.GPU
def load_model():
    print("Loading ASR model...")
    # Load the NVIDIA NeMo ASR model
    model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
    # Move model to GPU if available
    if torch.cuda.is_available():
        model = model.cuda()
    print(f"Model loaded on device: {model.device}")
    return model

# Global variable to store the model
model = load_model()

def transcribe(audio, state=""):
    """
    Transcribe a streamed audio chunk and append it to the running transcript.
    """
    # Skip processing if no audio is provided
    if audio is None:
        return state, state

    # With type="numpy", Gradio delivers a (sample_rate, samples) tuple
    sample_rate, audio_data = audio

    # Convert to mono float32 in [-1, 1]; the model expects 16 kHz input
    if audio_data.dtype == np.int16:
        audio_data = audio_data.astype(np.float32) / 32768.0
    else:
        audio_data = audio_data.astype(np.float32)
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)
    if sample_rate != 16000:
        # Simple linear resampling; a dedicated resampler (e.g. torchaudio) is preferable
        target_len = int(len(audio_data) * 16000 / sample_rate)
        audio_data = np.interp(
            np.linspace(0, len(audio_data) - 1, target_len),
            np.arange(len(audio_data)),
            audio_data,
        )

    # Run the ASR model (recent NeMo releases accept in-memory numpy arrays)
    with torch.no_grad():
        result = model.transcribe([audio_data])[0]
    # Recent NeMo versions return Hypothesis objects; older ones return plain strings
    transcription = result.text if hasattr(result, "text") else result

    # Append the new transcription to the state
    if state == "":
        new_state = transcription
    else:
        new_state = state + " " + transcription

    return new_state, new_state

# Define the Gradio interface
with gr.Blocks(title="Real-time Speech-to-Text with NeMo") as demo:
    gr.Markdown("# ๐ŸŽ™๏ธ Real-time Speech-to-Text Transcription")
    gr.Markdown("Powered by NVIDIA NeMo and the parakeet-tdt-0.6b-v2 model")
    
    with gr.Row():
        with gr.Column(scale=2):
            audio_input = gr.Audio(
                sources=["microphone"],  # `sources` on Gradio 4.x; older 3.x releases used source="microphone"
                type="numpy",
                streaming=True,
                label="Speak into your microphone"
            )
            
            clear_btn = gr.Button("Clear Transcript")
            
        with gr.Column(scale=3):
            text_output = gr.Textbox(
                label="Transcription", 
                placeholder="Your speech will appear here...",
                lines=10
            )
            streaming_text = gr.Textbox(
                label="Real-time Transcription", 
                placeholder="Real-time results will appear here...",
                lines=2
            )
    
    # State to store the ongoing transcription
    state = gr.State("")
    
    # Handle the audio stream
    audio_input.stream(
        fn=transcribe,
        inputs=[audio_input, state],
        outputs=[state, streaming_text],
    )
    
    # Clear the transcription
    def clear_transcription():
        return "", "", ""
    
    clear_btn.click(
        fn=clear_transcription,
        inputs=[],
        outputs=[text_output, streaming_text, state]
    )
    
    # Mirror the running transcript into the main textbox (gr.State.change requires a recent Gradio release)
    state.change(
        fn=lambda s: s,
        inputs=[state],
        outputs=[text_output]
    )
    
    gr.Markdown("## ๐Ÿ“ Instructions")
    gr.Markdown("""
    1. Click the microphone button to start recording
    2. Speak clearly into your microphone
    3. The transcription will appear in real-time
    4. Click 'Clear Transcript' to start a new transcription
    """)

# Launch the app
if __name__ == "__main__":
    demo.launch()
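
# Minimal offline sanity check of the ASR path, bypassing the UI (a sketch:
# "sample.wav" is a hypothetical local 16 kHz mono file):
#
#     out = model.transcribe(["sample.wav"])[0]
#     print(out.text if hasattr(out, "text") else out)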