File size: 2,193 Bytes
6cf18f3
104f7bd
 
 
 
 
ff49b3c
d5886b3
b9f09f7
 
104f7bd
d5886b3
 
 
 
 
b9f09f7
104f7bd
 
6cf18f3
 
5f66658
d920423
 
fd90ec3
d920423
 
5f66658
 
0cf8b89
ff49b3c
5f66658
 
6cf18f3
5fdb08e
6cf18f3
 
 
 
 
 
 
 
5fdb08e
104f7bd
6cf18f3
0cf8b89
6cf18f3
5fdb08e
b9f09f7
 
 
6cccf9e
1cea20a
104f7bd
b9f09f7
 
104f7bd
d5886b3
ff49b3c
 
104f7bd
 
 
ff49b3c
104f7bd
d5886b3
104f7bd
ff49b3c
b0d49ce
5fdb08e
104f7bd
 
5fdb08e
 
df1de84
 
d5886b3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from contextlib import asynccontextmanager

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware

from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
import logging
from parse_args import parse_args
from audio import AudioProcessor

# Root logging: timestamped one-line records. The root level is then pinned to
# WARNING, so loggers that don't set their own level (e.g. third-party
# libraries) only emit warnings and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logging.getLogger().setLevel(logging.WARNING)

# This module's own logger is opened up to DEBUG; its records still propagate
# to the root handlers installed above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Command-line options, parsed once at import time and read by every handler.
args = parse_args()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared model backends once, before the app starts serving.

    Assigns the module-level globals ``asr``, ``tokenizer`` and ``diarization``
    that ``websocket_endpoint`` reads for every connection. A backend whose
    command-line flag (``args.transcription`` / ``args.diarization``) is off
    is set to ``None`` instead.
    """
    global asr, tokenizer, diarization

    if not args.transcription:
        asr = tokenizer = None
    else:
        asr, tokenizer = backend_factory(args)
        # Warm the ASR backend up on args.warmup_file during startup.
        warmup_asr(asr, args.warmup_file)

    if not args.diarization:
        diarization = None
    else:
        # Imported lazily so the diarization dependency is only required
        # when the feature is enabled.
        from diarization.diarization_online import DiartDiarization
        diarization = DiartDiarization()

    # NOTE(review): nothing is torn down after the yield; confirm whether the
    # backends hold resources that should be released on shutdown.
    yield

# ASGI application; shared backends are created at startup by `lifespan`.
app = FastAPI(lifespan=lifespan)

# Fully open CORS policy: every origin, method and header, credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider restricting allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)


# Demo page for the root endpoint, read once at import time and served verbatim.
# The path is resolved against the current working directory, so the process
# must be started from the directory that contains web/.
with open("web/live_transcription.html", encoding="utf-8") as demo_page_file:
    html = demo_page_file.read()

@app.get("/")
async def get():
    """Return the bundled live-transcription demo page as an HTML response."""
    page = HTMLResponse(content=html)
    return page

@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    """Feed binary audio frames from one client into its own AudioProcessor.

    A fresh processor is built per connection from the shared ``asr`` /
    ``tokenizer`` backends. ``create_tasks`` is given the socket and the shared
    ``diarization`` backend (presumably so its tasks can push results back to
    the client; confirm in AudioProcessor). Every frame received is passed to
    ``process_audio`` until the client disconnects. Once the receive loop has
    started, ``cleanup()`` always runs when it ends, whether by disconnect or
    by error.
    """
    processor = AudioProcessor(args, asr, tokenizer)

    await websocket.accept()
    logger.info("WebSocket connection opened.")

    # NOTE(review): this call sits outside the try/finally, so if it raises,
    # processor.cleanup() is never called; confirm that nothing leaks then.
    await processor.create_tasks(websocket, diarization)

    try:
        while True:
            chunk = await websocket.receive_bytes()
            processor.process_audio(chunk)
    except WebSocketDisconnect:
        logger.warning("WebSocket disconnected.")
    finally:
        processor.cleanup()
        logger.info("WebSocket endpoint cleaned up.")

if __name__ == "__main__":
    # Imported here: uvicorn is only needed when this file is run as a script.
    import uvicorn

    # The app is given as an import string, which uvicorn needs for reload=True.
    # Auto-reload is always enabled in this entry point.
    uvicorn.run(
        "whisper_fastapi_online_server:app",
        host=args.host,
        port=args.port,
        log_level="info",
        reload=True,
    )