|
from contextlib import asynccontextmanager |
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
|
from fastapi.responses import HTMLResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args |
|
import asyncio |
|
import logging |
|
|
|
# Configure the root logger with a timestamped format at INFO, then raise the
# root threshold to WARNING. The explicit setLevel() call is kept separate from
# basicConfig() so the WARNING threshold is applied unconditionally.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logging.getLogger().setLevel(logging.WARNING)

# This module's own logger is more verbose than the root: it accepts DEBUG.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
|
|
|
# Command-line options are parsed once, when the module is imported; the
# lifespan hook and main() both read them from this module-level name.
args = parse_args()

# Shared engine instance. It stays None until the lifespan hook has run.
transcription_engine = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared transcription engine when the application starts.

    Every parsed command-line option is forwarded to ``TranscriptionEngine``
    as a keyword argument, and the resulting object is stored in the
    module-level ``transcription_engine``. No work is done after the
    ``yield`` (i.e. on shutdown).

    Args:
        app: The FastAPI application the hook is attached to (not used).
    """
    global transcription_engine
    engine_options = vars(args)
    transcription_engine = TranscriptionEngine(**engine_options)
    yield
|
|
|
# The application object; engine start-up is delegated to the lifespan hook.
app = FastAPI(lifespan=lifespan)

# Fully permissive CORS: every origin, method and header is allowed.
# NOTE(review): a wildcard origin together with allow_credentials=True is very
# broad - confirm this is intended before exposing the server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
|
@app.get("/")
async def get():
    """Return the output of ``get_web_interface_html()`` as an HTML response."""
    page = get_web_interface_html()
    return HTMLResponse(page)
|
|
|
|
|
async def handle_websocket_results(websocket, results_generator):
    """Forward every item of ``results_generator`` to the client as JSON.

    Once the generator is exhausted, a final ``{"type": "ready_to_stop"}``
    message is sent. Errors never propagate to the caller: a
    ``WebSocketDisconnect`` is logged at INFO, anything else at WARNING.

    Args:
        websocket: Connection object providing an awaitable ``send_json``.
        results_generator: Async iterable of JSON-serialisable payloads.
    """
    try:
        async for payload in results_generator:
            await websocket.send_json(payload)

        # The generator finished normally: tell the client nothing more is coming.
        logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
        stop_notice = {"type": "ready_to_stop"}
        await websocket.send_json(stop_notice)
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected while handling results (client likely closed connection).")
    except Exception as exc:
        # Best effort: a failed send must not crash the background task.
        logger.warning(f"Error in WebSocket results handler: {exc}")
|
|
|
|
|
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one ``/asr`` WebSocket session.

    Flow:
      1. Create an ``AudioProcessor`` bound to the shared transcription engine
         and accept the connection.
      2. Start a background task that forwards the processor's results to the
         client (``handle_websocket_results``).
      3. Pass every received binary message to ``process_audio`` until
         receiving raises.
      4. Always stop the forwarding task and call the processor's
         ``cleanup()``.

    Exceptions raised inside the receive loop are logged, not re-raised.

    Args:
        websocket: The incoming WebSocket connection.
    """
    processor = AudioProcessor(transcription_engine=transcription_engine)
    await websocket.accept()
    logger.info("WebSocket connection opened.")

    results = await processor.create_tasks()
    forwarder = asyncio.create_task(handle_websocket_results(websocket, results))

    try:
        while True:
            await processor.process_audio(await websocket.receive_bytes())
    except KeyError as exc:
        # A KeyError mentioning 'bytes' is treated as the client going away
        # (presumably the received message carried no binary payload - confirm).
        if 'bytes' not in str(exc):
            logger.exception(f"Unexpected KeyError in websocket_endpoint: {exc}")
        else:
            logger.warning("Client has closed the connection.")
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected by client during message receiving loop.")
    except Exception as exc:
        logger.exception(f"Unexpected error in websocket_endpoint main loop: {exc}")
    finally:
        logger.info("Cleaning up WebSocket endpoint...")

        # Stop the forwarder if it is still running, then wait for it so that
        # it has fully finished before the processor is torn down.
        if not forwarder.done():
            forwarder.cancel()
        try:
            await forwarder
        except asyncio.CancelledError:
            logger.info("WebSocket results handler task was cancelled.")
        except Exception as exc:
            logger.warning(f"Exception while awaiting websocket_task completion: {exc}")

        await processor.cleanup()
        logger.info("WebSocket endpoint cleaned up successfully.")
|
|
|
def main():
    """Entry point for the CLI command.

    Starts uvicorn for ``whisperlivekit.basic_server:app`` on the host and
    port taken from the parsed command-line arguments. TLS options are passed
    through only when both a certificate file and a key file were supplied.

    Raises:
        ValueError: If exactly one of ``--ssl-certfile`` / ``--ssl-keyfile``
            was given.
    """
    import uvicorn

    uvicorn_kwargs = dict(
        app="whisperlivekit.basic_server:app",
        host=args.host,
        port=args.port,
        reload=False,
        log_level="info",
        lifespan="on",
    )

    certfile = args.ssl_certfile
    keyfile = args.ssl_keyfile

    # TLS needs both halves of the pair; reject a lone certificate or key.
    if bool(certfile) != bool(keyfile):
        raise ValueError("Both --ssl-certfile and --ssl-keyfile must be specified together.")

    if certfile and keyfile:
        uvicorn_kwargs.update(ssl_certfile=certfile, ssl_keyfile=keyfile)

    uvicorn.run(**uvicorn_kwargs)


if __name__ == "__main__":
    main()
|
|