import time
import asyncio
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
# Import your model and VAD libraries.
from silero_vad import VADIterator, load_silero_vad
from transformers import AutoProcessor, pipeline
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
# Load the ONNX-exported Whisper tiny.en model and build an ASR pipeline.
processor = AutoProcessor.from_pretrained("optimum/whisper-tiny.en")
model = ORTModelForSpeechSeq2Seq.from_pretrained("optimum/whisper-tiny.en")
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)
# Constants
SAMPLING_RATE = 16000
CHUNK_SIZE = 512 # Required for Silero VAD at 16kHz.
LOOKBACK_CHUNKS = 5
MAX_SPEECH_SECS = 15 # Maximum duration for a single transcription segment.
MIN_REFRESH_SECS = 1 # Minimum interval for sending partial updates.
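# Optional warm-up (a sketch; safe to remove): transcribing one second of
# silence up front pays the ONNX session-initialization cost before the first
# real request, mirroring the warm-up run in the Transcriber class below.
_ = pipe({"sampling_rate": SAMPLING_RATE, "raw": np.zeros(SAMPLING_RATE, dtype=np.float32)})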
app = FastAPI()
# Alternative transcriber based on the Moonshine ONNX models, kept for reference
# but currently disabled in favor of the Whisper ONNX pipeline above:
# class Transcriber:
#     def __init__(self, model_name: str, rate: int = 16000):
#         if rate != 16000:
#             raise ValueError("Moonshine supports sampling rate 16000 Hz.")
#         self.model = MoonshineOnnxModel(model_name=model_name)
#         self.rate = rate
#         self.tokenizer = load_tokenizer()
#         # Statistics (optional)
#         self.inference_secs = 0
#         self.number_inferences = 0
#         self.speech_secs = 0
#         # Warmup run.
#         self.__call__(np.zeros(int(rate), dtype=np.float32))
#
#     def __call__(self, speech: np.ndarray) -> str:
#         """Returns a transcription of the given speech (a float32 numpy array)."""
#         self.number_inferences += 1
#         self.speech_secs += len(speech) / self.rate
#         start_time = time.time()
#         tokens = self.model.generate(speech[np.newaxis, :].astype(np.float32))
#         text = self.tokenizer.decode_batch(tokens)[0]
#         self.inference_secs += time.time() - start_time
#         return text
def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
    """
    Convert 16-bit PCM bytes into a float32 numpy array with values in [-1, 1].
    """
    int_data = np.frombuffer(pcm_data, dtype=np.int16)
    float_data = int_data.astype(np.float32) / 32768.0
    return float_data
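# The browser page below sends little-endian 16-bit PCM frames (see its
# floatTo16BitPCM helper), so each binary WebSocket message can be decoded with
# pcm16_to_float32 before being appended to the rolling speech buffer.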
# Initialize models.
# model_name_tiny = "moonshine/tiny"
# model_name_base = "moonshine/base"
# transcriber_tiny = Transcriber(model_name=model_name_tiny, rate=SAMPLING_RATE)
# transcriber_base = Transcriber(model_name=model_name_base, rate=SAMPLING_RATE)
vad_model = load_silero_vad(onnx=True)
vad_iterator = VADIterator(
    model=vad_model,
    sampling_rate=SAMPLING_RATE,
    threshold=0.5,
    min_silence_duration_ms=300,
)
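# VADIterator consumes fixed-size chunks and returns None while nothing changes,
# or a dict with a "start" or "end" key (sample offsets) when speech begins or
# ends; those events drive the segmentation logic in the WebSocket endpoint.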
@app.websocket("/ws/transcribe")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    caption_cache = []
    lookback_size = LOOKBACK_CHUNKS * CHUNK_SIZE
    speech = np.empty(0, dtype=np.float32)
    recording = False
    last_partial_time = time.time()
    last_output = ""
    # Model switching is currently disabled (the Moonshine transcribers above are
    # commented out); every segment is transcribed with the Whisper ONNX pipeline.
    try:
        while True:
            data = await websocket.receive()
            if data["type"] == "websocket.receive":
                # Text frames carry model-switch commands from the page.
                if data.get("text") == "switch_to_tiny":
                    # current_model = transcriber_tiny
                    continue
                elif data.get("text") == "switch_to_base":
                    # current_model = transcriber_base
                    continue
                # Binary frames carry 16-bit PCM audio.
                chunk = pcm16_to_float32(data["bytes"])
                speech = np.concatenate((speech, chunk))
                if not recording:
                    # Keep only a short lookback window until speech is detected.
                    speech = speech[-lookback_size:]
                vad_result = vad_iterator(chunk)
                current_time = time.time()
                if vad_result:
                    if "start" in vad_result and not recording:
                        recording = True
                        await websocket.send_json({"type": "status", "message": "speaking_started"})
                    if "end" in vad_result and recording:
                        recording = False
                        text = pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
                        await websocket.send_json({"type": "final", "transcript": text})
                        caption_cache.append(text)
                        speech = np.empty(0, dtype=np.float32)
                        # Soft-reset the VAD state for the next utterance.
                        vad_iterator.triggered = False
                        vad_iterator.temp_end = 0
                        vad_iterator.current_sample = 0
                        await websocket.send_json({"type": "status", "message": "speaking_stopped"})
                elif recording:
                    if (len(speech) / SAMPLING_RATE) > MAX_SPEECH_SECS:
                        # Force a final transcript if the segment grows too long.
                        recording = False
                        text = pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
                        await websocket.send_json({"type": "final", "transcript": text})
                        caption_cache.append(text)
                        speech = np.empty(0, dtype=np.float32)
                        vad_iterator.triggered = False
                        vad_iterator.temp_end = 0
                        vad_iterator.current_sample = 0
                        await websocket.send_json({"type": "status", "message": "speaking_stopped"})
                    if recording and (current_time - last_partial_time) > MIN_REFRESH_SECS:
                        # Send an interim transcript while speech is still ongoing.
                        text = pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
                        if last_output != text:
                            last_output = text
                            await websocket.send_json({"type": "partial", "transcript": text})
                        last_partial_time = current_time
    except WebSocketDisconnect:
        if recording and speech.size:
            text = pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
            await websocket.send_json({"type": "final", "transcript": text})
        print("WebSocket disconnected")
@app.get("/", response_class=HTMLResponse)
async def get_home():
    return """
    <!DOCTYPE html>
    <html>
    <head>
      <meta charset="UTF-8">
      <title>Realtime Transcription</title>
      <link href="https://cdn.jsdelivr.net/npm/tailwindcss@2/dist/tailwind.min.css" rel="stylesheet">
    </head>
    <body class="bg-gray-100 p-6">
      <div class="max-w-3xl mx-auto bg-white p-6 rounded-lg shadow-md">
        <h1 class="text-2xl font-bold mb-4">Realtime Transcription</h1>
        <button onclick="startTranscription()" class="bg-blue-500 text-white px-4 py-2 rounded mb-4">Start Transcription</button>
        <select id="modelSelect" onchange="switchModel()" class="bg-gray-200 px-4 py-2 rounded mb-4">
          <option value="tiny">Tiny Model</option>
          <option value="base">Base Model</option>
        </select>
        <p id="status" class="text-gray-600 mb-4">Click start to begin transcription.</p>
        <p id="speakingStatus" class="text-gray-600 mb-4"></p>
        <div id="transcription" class="border p-4 rounded mb-4 h-64 overflow-auto"></div>
        <div id="visualizer" class="border p-4 rounded h-64">
          <canvas id="audioCanvas" class="w-full h-full"></canvas>
        </div>
      </div>
      <script>
        let ws;
        let audioContext;
        let scriptProcessor;
        let mediaStream;
        let currentLine = document.createElement('span');
        let analyser;
        let canvas, canvasContext;
        document.getElementById('transcription').appendChild(currentLine);
        canvas = document.getElementById('audioCanvas');
        canvasContext = canvas.getContext('2d');
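
        // Audio path: getUserMedia -> 16 kHz AudioContext -> ScriptProcessor with
        // 512-sample frames (matching CHUNK_SIZE on the server) -> little-endian
        // 16-bit PCM sent as binary WebSocket frames to /ws/transcribe.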
        async function startTranscription() {
          document.getElementById("status").innerText = "Connecting...";
          ws = new WebSocket((location.protocol === "https:" ? "wss://" : "ws://") + location.host + "/ws/transcribe");
          ws.binaryType = 'arraybuffer';
          ws.onopen = async function() {
            document.getElementById("status").innerText = "Connected";
            try {
              mediaStream = await navigator.mediaDevices.getUserMedia({ audio: true });
              audioContext = new AudioContext({ sampleRate: 16000 });
              const source = audioContext.createMediaStreamSource(mediaStream);
              analyser = audioContext.createAnalyser();
              analyser.fftSize = 2048;
              const bufferLength = analyser.frequencyBinCount;
              const dataArray = new Uint8Array(bufferLength);
              source.connect(analyser);
              scriptProcessor = audioContext.createScriptProcessor(512, 1, 1);
              scriptProcessor.onaudioprocess = function(event) {
                const inputData = event.inputBuffer.getChannelData(0);
                const pcm16 = floatTo16BitPCM(inputData);
                if (ws.readyState === WebSocket.OPEN) {
                  ws.send(pcm16);
                }
                analyser.getByteTimeDomainData(dataArray);
                canvasContext.fillStyle = 'rgb(200, 200, 200)';
                canvasContext.fillRect(0, 0, canvas.width, canvas.height);
                canvasContext.lineWidth = 2;
                canvasContext.strokeStyle = 'rgb(0, 0, 0)';
                canvasContext.beginPath();
                let sliceWidth = canvas.width * 1.0 / bufferLength;
                let x = 0;
                for (let i = 0; i < bufferLength; i++) {
                  let v = dataArray[i] / 128.0;
                  let y = v * canvas.height / 2;
                  if (i === 0) {
                    canvasContext.moveTo(x, y);
                  } else {
                    canvasContext.lineTo(x, y);
                  }
                  x += sliceWidth;
                }
                canvasContext.lineTo(canvas.width, canvas.height / 2);
                canvasContext.stroke();
              };
              source.connect(scriptProcessor);
              scriptProcessor.connect(audioContext.destination);
            } catch (err) {
              document.getElementById("status").innerText = "Error: " + err;
            }
          };
          ws.onmessage = function(event) {
            const data = JSON.parse(event.data);
            if (data.type === 'partial') {
              currentLine.style.color = 'gray';
              currentLine.textContent = data.transcript + ' ';
            } else if (data.type === 'final') {
              currentLine.style.color = 'black';
              currentLine.textContent = data.transcript;
              currentLine = document.createElement('span');
              document.getElementById('transcription').appendChild(document.createElement('br'));
              document.getElementById('transcription').appendChild(currentLine);
            } else if (data.type === 'status') {
              if (data.message === 'speaking_started') {
                document.getElementById("speakingStatus").innerText = "Speaking Started";
                document.getElementById("speakingStatus").style.color = "green";
              } else if (data.message === 'speaking_stopped') {
                document.getElementById("speakingStatus").innerText = "Speaking Stopped";
                document.getElementById("speakingStatus").style.color = "red";
              }
            }
          };
          ws.onclose = function() {
            if (audioContext && audioContext.state !== 'closed') {
              audioContext.close();
            }
            document.getElementById("status").innerText = "Closed";
          };
        }
        function switchModel() {
          const model = document.getElementById("modelSelect").value;
          if (ws && ws.readyState === WebSocket.OPEN) {
            if (model === "tiny") {
              ws.send("switch_to_tiny");
            } else if (model === "base") {
              ws.send("switch_to_base");
            }
          }
        }

        // Clamp each float sample to [-1, 1] and pack it as little-endian int16,
        // the format expected by pcm16_to_float32 on the server.
        function floatTo16BitPCM(input) {
          const buffer = new ArrayBuffer(input.length * 2);
          const output = new DataView(buffer);
          for (let i = 0; i < input.length; i++) {
            let s = Math.max(-1, Math.min(1, input[i]));
            output.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
          }
          return buffer;
        }
      </script>
    </body>
    </html>
    """
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
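# To run locally (assuming this file is saved as app.py): `python app.py`, then
# open http://localhost:7860. Browsers only expose the microphone on secure
# contexts, so any non-localhost deployment must be served over HTTPS.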