import os
import logging
import json

import torch
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    AlgoOptions,
    SileroVadOptions,
    audio_to_bytes,
    get_cloudflare_turn_credentials_async,
)
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available

from utils.device import get_device, get_torch_and_np_dtypes

load_dotenv()

logger = logging.getLogger(__name__)

UI_MODE = os.getenv("UI_MODE", "fastapi").lower()  # gradio | fastapi
UI_TYPE = os.getenv("UI_TYPE", "base").lower()  # base | screen
APP_MODE = os.getenv("APP_MODE", "local").lower()  # local | deployed
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
LANGUAGE = os.getenv("LANGUAGE", "english").lower()

device = get_device(force_cpu=False)
use_device_map = device == "cuda"
try_compile_model = device == "cuda" or (device == "mps" and torch.__version__ >= "2.7.0")
try_use_flash_attention = False
#try_use_flash_attention = True if device == "cuda" and is_flash_attn_2_available() else False

torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")

logger.info(f"Loading Whisper model: {MODEL_ID}")
logger.info(f"Using language: {LANGUAGE}")

# Initialize the model (use flash attention on cuda if possible)
try:
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="flash_attention_2" if try_use_flash_attention else "sdpa",
        device_map="auto" if use_device_map else None,
    )
    if not use_device_map:
        model.to(device)
except RuntimeError:
    try:
        logger.warning("Falling back to device_map=None")
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            MODEL_ID,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation="flash_attention_2" if try_use_flash_attention else "sdpa",
            device_map=None,
        )
        model.to(device)
    except RuntimeError:
        try:
            logger.warning("Disabling flash attention")
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                MODEL_ID,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                attn_implementation="sdpa",
            )
            model.to(device)
        except Exception as e:
            logger.error(f"Error loading ASR model: {e}")
            logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
            raise

processor = AutoProcessor.from_pretrained(MODEL_ID)

transcribe_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype
)

# Try to compile the model
try:
    if try_compile_model:
        transcribe_pipeline.model = torch.compile(transcribe_pipeline.model, mode="max-autotune")
    else:
        logger.warning("Proceeding without compiling the model (requirements not met)")
except Exception as e:
    logger.warning(f"Error compiling model: {e}")
    logger.warning("Proceeding without compiling the model")

# Warm up the model with dummy audio
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.random.rand(16000).astype(np_dtype)
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")


async def transcribe(audio: tuple[int, np.ndarray]):
    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")

    outputs = transcribe_pipeline(
        audio_to_bytes(audio),  # pass bytes
        #audio_array,  # pass numpy array
        chunk_length_s=5,
        batch_size=1,
        generate_kwargs={
            #"compression_ratio_threshold": 1.35,
            #"no_speech_threshold": 0.6,
            #"logprob_threshold": -1.0,
            #"num_beams": 1,
            #"condition_on_prev_tokens": False,
            #"temperature": (0.0, 0.2, 0.4, 0.6),
            #"return_timestamps": True,
            #"task": "transcribe",
            "task": "translate",
            "language": LANGUAGE,
        }
    )
    yield AdditionalOutputs(outputs["text"].strip())


logger.info("Initializing FastRTC stream")
stream = Stream(
    handler=ReplyOnPause(
        transcribe,
        algo_options=AlgoOptions(
            # Duration in seconds of audio chunks passed to the VAD model (default 0.6)
            audio_chunk_duration=0.5,
            # If a chunk has more than started_talking_threshold seconds of speech,
            # the user started talking (default 0.2)
            started_talking_threshold=0.1,
            # If, after the user started speaking, a chunk has less than speech_threshold
            # seconds of speech, the user stopped speaking (default 0.1)
            speech_threshold=0.1,
            # Max duration of speech before the handler is triggered, even if the VAD
            # model does not detect a pause (default -inf)
            max_continuous_speech_s=15
        ),
        model_options=SileroVadOptions(
            # Threshold for what is considered speech (default 0.5)
            threshold=0.5,
            # Final speech chunks shorter than min_speech_duration_ms are thrown out (default 250)
            min_speech_duration_ms=250,
            # Max duration of speech chunks; longer chunks are split at the timestamp of the
            # last silence longer than 100 ms (if any) or just before max_speech_duration_s
            # (default float('inf')). Used internally by the VAD algorithm to split the audio
            # passed to it.
            max_speech_duration_s=5,
            # Wait this many ms at the end of each speech chunk before separating it (default 2000)
            min_silence_duration_ms=200,
            # Chunk size for the VAD model; can be 512, 1024, or 1536 for a 16 kHz
            # sample rate (default 1024)
            window_size_samples=1024,
            # Final speech chunks are padded by speech_pad_ms on each side (default 400)
            speech_pad_ms=200,
        ),
    ),
    # send-receive: bidirectional streaming (default)
    # send: client to server only
    # receive: server to client only
    modality="audio",
    mode="send",
    additional_outputs=[
        gr.Textbox(label="Transcript"),
    ],
    additional_outputs_handler=lambda current, new: current + " " + new,
    rtc_configuration=get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None,
    concurrency_limit=6
)

app = FastAPI()

app.mount("/static", StaticFiles(directory="static"), name="static")
stream.mount(app)


@app.get("/")
async def index():
    # Serve the UI matching UI_TYPE; fall back to the base page for unknown values
    if UI_TYPE == "screen":
        html_content = open("static/index-screen.html").read()
    else:  # "base"
        html_content = open("static/index.html").read()

    rtc_configuration = (
        await get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN"))
        if APP_MODE == "deployed"
        else None
    )
    logger.info(f"RTC configuration: {rtc_configuration}")
    html_content = html_content.replace("__INJECTED_RTC_CONFIG__", json.dumps(rtc_configuration))
    return HTMLResponse(content=html_content)


@app.get("/transcript")
def _(webrtc_id: str):
    logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")

    async def output_stream():
        try:
            async for output in stream.output_stream(webrtc_id):
                transcript = output.args[0]
                logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
                yield f"event: output\ndata: {transcript}\n\n"
        except Exception as e:
            logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
            raise

    return StreamingResponse(output_stream(), media_type="text/event-stream")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="localhost", port=7860)
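# ---------------------------------------------------------------------------
# Illustrative sketch of a client for the /transcript SSE endpoint above.
# It is not part of the app: it assumes the server is running on
# localhost:7860, that the `requests` package is installed, and that you
# already have the webrtc_id of an active FastRTC connection.
#
#   import requests
#
#   webrtc_id = "..."  # provided by the browser client when the stream starts
#   with requests.get(
#       "http://localhost:7860/transcript",
#       params={"webrtc_id": webrtc_id},
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines(decode_unicode=True):
#           if line and line.startswith("data:"):
#               print(line[len("data:"):].strip())
# ---------------------------------------------------------------------------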