import queue

import gradio as gr
import numpy as np
import sounddevice as sd
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
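
# Load the fine-tuned Swahili wav2vec2 checkpoint and its processor once at
# startup so each transcription request skips the model-loading cost.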
model_name = "Futuresony/Future-sw_ASR-24-02-2025"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
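
# Audio blocks captured by the sounddevice callback are handed to the
# transcription loop through this thread-safe queue; `streaming` gates
# the capture loop.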
q = queue.Queue()
streaming = True
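
# sounddevice invokes this callback from its own audio thread for every
# captured block; it only copies the samples onto the queue.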
def callback(indata, frames, time, status):
    if status:
        print(status)
    q.put(indata.copy())
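
# Main capture-and-transcribe loop: open the microphone, batch queued audio,
# and yield one transcription per batch.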
def transcribe_stream():
    samplerate = 16000  # wav2vec2 models expect 16 kHz mono input
    with sd.InputStream(samplerate=samplerate, channels=1, callback=callback):
        while streaming:
            audio_data = []
            try:
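                # Accumulate a few blocks so each forward pass sees a
                # reasonably sized window of speech.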
                for _ in range(5):
                    audio_chunk = q.get(timeout=1)
                    audio_data.append(audio_chunk)

                audio_np = np.concatenate(audio_data, axis=0).flatten()
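
                # Greedy CTC decoding: argmax over the logits, then map the
                # token ids back to text.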
                input_values = processor(
                    audio_np, sampling_rate=samplerate, return_tensors="pt"
                ).input_values
                with torch.no_grad():
                    logits = model(input_values).logits
                predicted_ids = torch.argmax(logits, dim=-1)
                transcription = processor.batch_decode(predicted_ids)[0]

                yield transcription
            except queue.Empty:
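                # No audio arrived within the timeout; keep listening.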
                continue
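
# Gradio treats a generator function as a streaming output: each yielded
# string replaces the textbox contents.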
def live_transcription():
    # Delegate to the capture generator so Gradio receives each partial
    # transcription instead of a bare generator object.
    yield from transcribe_stream()
interface = gr.Interface(
    fn=live_transcription,
    inputs=None,
    outputs=gr.Textbox(label="Live Transcription"),
    live=True,
    title="Swahili Live Streaming ASR",
    description="Speak continuously, and the subtitles will appear in real time.",
)

# No background thread is needed: Gradio itself iterates the generator
# returned by live_transcription() while the app is running.
if __name__ == "__main__":
    # Generator outputs need Gradio's request queue (enabled by default in
    # recent releases).
    interface.queue().launch()