# WhisperLiveKitDiarization / whisper_fastapi_online_server.py
import io
import argparse
import asyncio
import numpy as np
import ffmpeg
from time import time
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
from src.whisper_streaming.timed_objects import ASRToken
import math
import logging
from datetime import timedelta
def format_time(seconds):
return str(timedelta(seconds=int(seconds)))
# Verbose logging for this module, WARNING for everything else (e.g. third-party libraries).
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
##### LOAD ARGS #####
parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
parser.add_argument(
"--host",
type=str,
default="localhost",
help="The host address to bind the server to.",
)
parser.add_argument(
"--port", type=int, default=8000, help="The port number to bind the server to."
)
parser.add_argument(
"--warmup-file",
type=str,
dest="warmup_file",
help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
)
parser.add_argument(
    "--diarization",
    type=lambda v: str(v).lower() in ("true", "1", "yes"),
    default=True,
    help="Whether to enable speaker diarization (argparse's type=bool would treat any non-empty string as True, so an explicit parser is used).",
)
parser.add_argument(
    "--transcription",
    type=lambda v: str(v).lower() in ("true", "1", "yes"),
    default=True,
    help="Whether to enable transcription; disable to see only live diarization results.",
)
add_shared_args(parser)
args = parser.parse_args()
SAMPLE_RATE = 16000
CHANNELS = 1
SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)  # samples per processing chunk
BYTES_PER_SAMPLE = 2  # s16le = 2 bytes per sample
BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE  # bytes per processing chunk
MAX_BYTES_PER_SEC = 32000 * 5  # 5 seconds of 16 kHz s16le audio (32,000 bytes/sec)
##### LOAD APP #####
@asynccontextmanager
async def lifespan(app: FastAPI):
global asr, tokenizer, diarization
if args.transcription:
asr, tokenizer = backend_factory(args)
else:
asr, tokenizer = None, None
if args.diarization:
from src.diarization.diarization_online import DiartDiarization
diarization = DiartDiarization(SAMPLE_RATE)
    else:
        diarization = None
yield
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Load demo HTML for the root endpoint
with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
html = f.read()
async def start_ffmpeg_decoder():
"""
Start an FFmpeg process in async streaming mode that reads WebM from stdin
and outputs raw s16le PCM on stdout. Returns the process object.
"""
process = (
ffmpeg.input("pipe:0", format="webm")
.output(
"pipe:1",
format="s16le",
acodec="pcm_s16le",
ac=CHANNELS,
ar=str(SAMPLE_RATE),
)
.run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
)
return process
##### ENDPOINTS #####
@app.get("/")
async def get():
return HTMLResponse(html)
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
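    """
    Receive WebM audio chunks from the client, pipe them through FFmpeg for
    decoding, and stream transcription/diarization results back as JSON.
    """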
await websocket.accept()
logger.info("WebSocket connection opened.")
ffmpeg_process = None
pcm_buffer = bytearray()
online = online_factory(args, asr, tokenizer) if args.transcription else None
async def restart_ffmpeg():
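        """Kill any existing FFmpeg process and start a fresh decoder with clean state."""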
nonlocal ffmpeg_process, online, pcm_buffer
if ffmpeg_process:
try:
ffmpeg_process.kill()
await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait)
except Exception as e:
logger.warning(f"Error killing FFmpeg process: {e}")
ffmpeg_process = await start_ffmpeg_decoder()
pcm_buffer = bytearray()
online = online_factory(args, asr, tokenizer) if args.transcription else None
logger.info("FFmpeg process started.")
await restart_ffmpeg()
async def ffmpeg_stdout_reader():
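        """
        Continuously read decoded PCM from FFmpeg stdout, run transcription
        and/or diarization on accumulated audio, and push results to the client.
        """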
nonlocal ffmpeg_process, online, pcm_buffer
loop = asyncio.get_event_loop()
full_transcription = ""
beg = time()
beg_loop = time()
tokens = []
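        # Timestamp (in stream seconds) up to which tokens have an assigned speaker.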
end_attributed_speaker = 0
        sep = online.asr.sep if online else " "  # online is None when transcription is disabled
while True:
try:
                elapsed_time = math.floor((time() - beg) * 10) / 10  # round down to 0.1 sec
                # Read roughly as many bytes as the elapsed wall time represents
                # (32,000 bytes/sec of 16 kHz s16le audio), with a 4 KiB floor.
                ffmpeg_buffer_from_duration = max(int(32000 * elapsed_time), 4096)
beg = time()
# Read chunk with timeout
try:
chunk = await asyncio.wait_for(
loop.run_in_executor(
None, ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration
),
timeout=5.0
)
except asyncio.TimeoutError:
logger.warning("FFmpeg read timeout. Restarting...")
await restart_ffmpeg()
full_transcription = ""
beg = time()
continue # Skip processing and read from new process
if not chunk:
logger.info("FFmpeg stdout closed.")
break
pcm_buffer.extend(chunk)
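            # Process only once at least one chunk's worth of audio has accumulated.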
if len(pcm_buffer) >= BYTES_PER_SEC:
                if len(pcm_buffer) > MAX_BYTES_PER_SEC:
                    logger.warning(
                        f"Audio buffer is too large: {len(pcm_buffer) / BYTES_PER_SEC:.2f} seconds. "
                        "The model probably struggles to keep up. Consider using a smaller model."
                    )
# Convert int16 -> float32
pcm_array = (
np.frombuffer(pcm_buffer[:MAX_BYTES_PER_SEC], dtype=np.int16).astype(np.float32)
/ 32768.0
)
pcm_buffer = pcm_buffer[MAX_BYTES_PER_SEC:]
if args.transcription:
logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
online.insert_audio_chunk(pcm_array)
new_tokens = online.process_iter()
tokens.extend(new_tokens)
full_transcription += sep.join([t.text for t in new_tokens])
_buffer = online.get_buffer()
buffer = _buffer.text
end_buffer = _buffer.end if _buffer.end else tokens[-1].end if tokens else 0
if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
buffer = ""
else:
tokens.append(
ASRToken(
start = time() - beg_loop,
end = time() - beg_loop + 0.5))
sleep(0.5)
buffer = ''
if args.diarization:
await diarization.diarize(pcm_array)
end_attributed_speaker = diarization.assign_speakers_to_tokens(end_attributed_speaker, tokens)
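                # Merge consecutive tokens from the same speaker into display lines.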
previous_speaker = -10
lines = []
last_end_diarized = 0
for token in tokens:
speaker = token.speaker
if args.diarization:
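                        # Speaker -1/0 means "not yet attributed": inherit the previous
                        # speaker for tokens diarization has already covered, otherwise
                        # leave them unattributed until diarization catches up.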
if speaker == -1 or speaker == 0:
if token.end < end_attributed_speaker:
speaker = previous_speaker
else:
speaker = 0
else:
last_end_diarized = max(token.end, last_end_diarized)
if speaker != previous_speaker:
lines.append(
{
"speaker": speaker,
"text": token.text,
"beg": format_time(token.start),
"end": format_time(token.end),
"diff": round(token.end - last_end_diarized, 2)
}
)
previous_speaker = speaker
else:
lines[-1]["text"] += sep + token.text
lines[-1]["end"] = format_time(token.end)
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
response = {"lines": lines, "buffer": buffer}
# response = {"lines": lines, "buffer": buffer, "time_buffer_transcription": time() + beg_loop - end_buffer, "time_buffer_diarization": time() + beg_loop - end_attributed_speaker}
await websocket.send_json(response)
except Exception as e:
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
break
logger.info("Exiting ffmpeg_stdout_reader...")
stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader())
try:
while True:
# Receive incoming WebM audio chunks from the client
message = await websocket.receive_bytes()
try:
ffmpeg_process.stdin.write(message)
ffmpeg_process.stdin.flush()
except (BrokenPipeError, AttributeError) as e:
logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
await restart_ffmpeg()
ffmpeg_process.stdin.write(message)
ffmpeg_process.stdin.flush()
except WebSocketDisconnect:
logger.warning("WebSocket disconnected.")
finally:
stdout_reader_task.cancel()
        try:
            ffmpeg_process.stdin.close()
            ffmpeg_process.wait()
        except Exception:
            pass
if args.diarization:
diarization.close()
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True,
log_level="info"
)
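
# A minimal client sketch for the /asr endpoint (hypothetical example: it assumes
# the third-party `websockets` package and a local `sample.webm` file, neither of
# which this file provides). The server expects binary WebM chunks and pushes back
# JSON {"lines": ..., "buffer": ...} messages as audio gets processed, so a client
# should send and receive concurrently rather than in lockstep:
#
#   import asyncio
#   import websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:8000/asr") as ws:
#           async def sender():
#               with open("sample.webm", "rb") as f:
#                   while chunk := f.read(4096):
#                       await ws.send(chunk)
#                       await asyncio.sleep(0.1)  # pace roughly like a live stream
#
#           async def receiver():
#               async for message in ws:  # runs until the server closes the socket
#                   print(message)
#
#           await asyncio.gather(sender(), receiver())
#
#   asyncio.run(main())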