# Hugging Face Space status: Running on Zero (ZeroGPU)
import os | |
import gradio as gr | |
import torch | |
import nemo.collections.asr as nemo_asr | |
from omegaconf import OmegaConf | |
import time | |
import spaces | |
import librosa | |
# Important: Don't initialize CUDA in the main process for Spaces
# The model will be loaded in the worker process through the GPU decorator
#
# Lazily-populated cache for the NeMo ASR model: ``load_model`` assigns it on
# first use and returns the cached instance on later calls.
# NOTE(review): no ``@spaces.GPU``-decorated function is visible in this
# source, so the "worker process" loading described above may not actually
# happen -- confirm against the deployed Space.
model: "nemo_asr.models.EncDecRNNTBPEModel | None" = None
def load_model():
    """Return the parakeet-tdt-0.6b-v2 ASR model, loading it on the first call.

    The instance is cached in the module-level ``model`` global, so every call
    after the first returns that same object without touching NeMo again.
    While loading, CUDA availability and the model's device are printed.

    NOTE(review): the log line says "worker process", but whether this runs in
    a separate GPU worker depends on the caller being wrapped by a GPU
    decorator -- confirm.
    """
    global model
    if model is not None:
        # Already loaded: reuse the cached instance.
        return model

    print("Loading model in worker process")
    has_cuda = torch.cuda.is_available()
    print(f"CUDA available: {has_cuda}")
    if has_cuda:
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")

    asr_cls = nemo_asr.models.EncDecRNNTBPEModel
    model = asr_cls.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
    print(f"Model loaded on device: {model.device}")
    return model
def transcribe(audio, state="", audio_buffer=None, last_processed_time=0):
    """Incrementally transcribe a streamed microphone recording.

    Called once per streamed audio chunk. Each incoming chunk is normalised to
    mono float32 in [-1, 1] and appended to ``audio_buffer``. Once at least one
    full window of audio is buffered:
      * every complete window is resampled to 16 kHz;
      * each window is written to a private temporary WAV file and
        transcribed with the NeMo model;
      * the non-empty text is appended to ``state``.
    Transcribed audio is then dropped from the buffer.

    Windows do not overlap, so each sample is transcribed exactly once.
    Overlapping windows would need de-duplication before concatenating text,
    otherwise words repeat in the transcript.

    Args:
        audio: ``(sample_rate, samples)`` tuple as produced by
            ``gr.Audio(type="numpy")``. ``None`` or a bare ``int`` is skipped;
            any other shape is rejected.
        state: Transcript accumulated so far.
        audio_buffer: List of buffered mono float32 chunks (at the stream's own
            sample rate) not yet transcribed, or ``None`` on the first call.
            The list passed in is not mutated.
        last_processed_time: Total seconds of stream audio already consumed by
            transcription windows.

    Returns:
        ``(transcript, transcript, audio_buffer, last_processed_time)``: the
        transcript twice (state and display), the remaining untranscribed
        audio, and the updated processed time. If transcription raises, the
        original ``state`` and time are returned with the new chunk still
        buffered, so it is retried on the next call.

    NOTE(review): ``spaces`` is imported at the top of this file but this
    function is not decorated with ``@spaces.GPU``; on ZeroGPU hardware it
    presumably runs without a GPU -- confirm this is intended.
    """
    import os
    import tempfile

    import numpy as np

    chunk_duration = 3.0  # seconds of audio per transcription window
    step_size = chunk_duration  # no overlap: avoids duplicated text
    target_sr = 16000  # rate the audio is resampled to before transcription

    # Work on a copy so the caller's list is never modified in place.
    audio_buffer = list(audio_buffer) if audio_buffer else []

    if audio is None or isinstance(audio, int):
        print(f"Skipping invalid audio input: {type(audio)}")
        return state, state, audio_buffer, last_processed_time

    print(f"Received audio input of type: {type(audio)}")
    if not (
        isinstance(audio, tuple)
        and len(audio) == 2
        and isinstance(audio[1], np.ndarray)
    ):
        print(f"Invalid audio input format: {type(audio)}")
        return state, state, audio_buffer, last_processed_time

    sample_rate, audio_data = audio
    print(f"Sample rate: {sample_rate}, Audio shape: {audio_data.shape}")

    # Normalise to mono float32 in [-1, 1]. Integer PCM (Gradio's numpy audio
    # is int16) must be scaled: soundfile clips float samples outside
    # [-1, 1], so unscaled integers would be written as near-square waves.
    samples = audio_data
    if np.issubdtype(samples.dtype, np.signedinteger):
        samples = samples.astype(np.float32) / float(-np.iinfo(samples.dtype).min)
    else:
        samples = samples.astype(np.float32)
    if samples.ndim > 1:
        samples = samples.mean(axis=1)  # (samples, channels) -> (samples,)
    audio_buffer.append(samples)

    # The buffer always holds audio at the stream's own sample rate, so its
    # duration is consistent across calls.
    total_duration = sum(len(chunk) for chunk in audio_buffer) / sample_rate
    print(f"Total buffered duration: {total_duration:.2f}s")
    if total_duration < chunk_duration:
        print(f"Buffering audio, total duration: {total_duration:.2f}s")
        return state, state, audio_buffer, last_processed_time

    # Heavy dependencies and the model are only needed once a full window is
    # available.
    import librosa
    import soundfile as sf

    model = load_model()

    window_len = round(chunk_duration * sample_rate)
    step_len = max(1, round(step_size * sample_rate))
    new_state = state
    start = 0  # sample offset into the buffer; the buffer starts unprocessed
    try:
        full_audio = np.concatenate(audio_buffer)
        # A private directory per call: no clashes between concurrent
        # sessions, and the WAV is removed even if transcription raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            wav_path = os.path.join(tmp_dir, "chunk.wav")
            while start + window_len <= len(full_audio):
                window = full_audio[start:start + window_len]
                if sample_rate != target_sr:
                    window = librosa.resample(
                        window, orig_sr=sample_rate, target_sr=target_sr
                    )
                offset = last_processed_time + start / sample_rate
                print(
                    f"Processing chunk from {offset:.2f}s "
                    f"to {offset + chunk_duration:.2f}s"
                )
                sf.write(wav_path, window, samplerate=target_sr)
                hypothesis = model.transcribe([wav_path])[0]
                # Newer NeMo returns Hypothesis objects; tolerate plain strings.
                transcription = getattr(hypothesis, "text", hypothesis)
                print(f"Transcription: {transcription}")
                if transcription.strip():
                    new_state = (
                        f"{new_state} {transcription}" if new_state else transcription
                    )
                start += step_len
    except Exception as e:
        print(f"Error processing audio: {e}")
        return state, state, audio_buffer, last_processed_time

    # Keep only the audio that has not been transcribed yet; the next call
    # therefore starts again at offset 0 of the buffer.
    audio_buffer = [full_audio[start:]] if start < len(full_audio) else []
    last_processed_time += start / sample_rate
    print(f"New state: {new_state}")
    return new_state, new_state, audio_buffer, last_processed_time
# Define the Gradio interface
with gr.Blocks(title="Real-time Speech-to-Text with NeMo") as demo:
    gr.Markdown("# 🎙️ Real-time Speech-to-Text Transcription")
    gr.Markdown("Powered by NVIDIA NeMo and the parakeet-tdt-0.6b-v2 model")

    with gr.Row():
        with gr.Column(scale=2):
            audio_input = gr.Audio(
                sources=["microphone"],
                type="numpy",
                streaming=True,
                label="Speak into your microphone",
            )
            clear_btn = gr.Button("Clear Transcript")
        with gr.Column(scale=3):
            text_output = gr.Textbox(
                label="Transcription",
                placeholder="Your speech will appear here...",
                lines=10,
            )
            streaming_text = gr.Textbox(
                label="Real-time Transcription",
                placeholder="Real-time results will appear here...",
                lines=2,
            )

    # Per-session state: the accumulated transcript, buffered audio not yet
    # transcribed, and the processed-time counter maintained by transcribe.
    state = gr.State("")
    audio_buffer = gr.State(value=None)
    last_processed_time = gr.State(value=0)

    # Handle the audio stream: every streamed chunk goes through transcribe.
    audio_input.stream(
        fn=transcribe,
        inputs=[audio_input, state, audio_buffer, last_processed_time],
        outputs=[state, streaming_text, audio_buffer, last_processed_time],
    )

    def clear_transcription():
        """Reset the transcript state, both text boxes, the buffer and the time."""
        # ``state`` must be reset as well: otherwise the next streamed chunk
        # resumes from the old transcript and the cleared text reappears.
        return "", "", "", None, 0

    clear_btn.click(
        fn=clear_transcription,
        inputs=[],
        outputs=[state, text_output, streaming_text, audio_buffer, last_processed_time],
    )

    # Mirror the accumulated transcript into the main text box.
    state.change(
        fn=lambda s: s,
        inputs=[state],
        outputs=[text_output],
    )

    gr.Markdown("## 📝 Instructions")
    gr.Markdown("""
1. Click the microphone button to start recording
2. Speak clearly into your microphone
3. The transcription will appear in real-time
4. Click 'Clear Transcript' to start a new transcription
""")

# Launch the app
if __name__ == "__main__":
    demo.launch()