import os
import logging
import json
import torch
import asyncio
import subprocess
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastrtc import (
AdditionalOutputs,
ReplyOnPause,
Stream,
AlgoOptions,
SileroVadOptions,
audio_to_bytes,
get_cloudflare_turn_credentials_async,
)
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
pipeline,
)
from transformers.utils import is_flash_attn_2_available
from utils.logger_config import setup_logging
from utils.device import get_device, get_torch_and_np_dtypes
load_dotenv()
setup_logging()
logger = logging.getLogger(__name__)
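# App configuration, read from environment variables (a local .env file is honored via load_dotenv above)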
UI_MODE = os.getenv("UI_MODE", "fastapi").lower() # gradio | fastapi
UI_TYPE = os.getenv("UI_TYPE", "base").lower() # base | screen
APP_MODE = os.getenv("APP_MODE", "local").lower() # local | deployed
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
LANGUAGE = os.getenv("LANGUAGE", "english").lower()
device = get_device(force_cpu=False)
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
logger.info(f"Using attention: {attention}")
logger.info(f"Loading Whisper model: {MODEL_ID}")
logger.info(f"Using language: {LANGUAGE}")
try:
model = AutoModelForSpeechSeq2Seq.from_pretrained(
MODEL_ID,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation=attention,
device_map="auto" if device == "cuda" else None
)
#model.to(device)
except Exception as e:
logger.error(f"Error loading ASR model: {e}")
logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
raise
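# The processor bundles the tokenizer and feature extractor needed to build the ASR pipeline.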
processor = AutoProcessor.from_pretrained(MODEL_ID)
transcribe_pipeline = pipeline(
task="automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
torch_dtype=torch_dtype,
#device=device,
)
#if device == "cuda":
# transcribe_pipeline.model = torch.compile(transcribe_pipeline.model, mode="max-autotune")
# Warm up the model with empty audio
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.zeros((16000,), dtype=np_dtype) # 1s of silence
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")
async def transcribe(audio: tuple[int, np.ndarray]):
sample_rate, audio_array = audio
logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")
outputs = transcribe_pipeline(
#audio_to_bytes(audio), # pass bytes
audio_array, # pass numpy array
chunk_length_s=3,
batch_size=1,
generate_kwargs={
'task': 'transcribe',
'language': LANGUAGE,
},
#return_timestamps="word"
)
yield AdditionalOutputs(outputs["text"].strip())
logger.info("Initializing FastRTC stream")
stream = Stream(
handler=ReplyOnPause(
transcribe,
algo_options=AlgoOptions(
# Duration in seconds of audio chunks passed to the VAD model (default 0.6)
audio_chunk_duration=0.6,
# If the chunk has more than started_talking_threshold seconds of speech, the user started talking (default 0.2)
started_talking_threshold=0.1,
# If, after the user started speaking, there is a chunk with less than speech_threshold seconds of speech, the user stopped speaking. (default 0.1)
speech_threshold=0.1,
            # Max duration of speech before the handler is triggered, even if the VAD model does not detect a pause (default inf)
max_continuous_speech_s=6
),
model_options=SileroVadOptions(
# Threshold for what is considered speech (default 0.5)
threshold=0.5,
# Final speech chunks shorter min_speech_duration_ms are thrown out (default 250)
min_speech_duration_ms=250,
            # Max duration of speech chunks; longer chunks are split at the last silence that
            # lasts more than 100ms (if any), or just before max_speech_duration_s (default float('inf'))
max_speech_duration_s=3,
            # Wait this many ms of silence at the end of each speech chunk before separating it (default 2000)
min_silence_duration_ms=100,
            # Chunk size for the VAD model; can be 512, 1024, or 1536 for a 16kHz sample rate (default 1024)
window_size_samples=512,
# Final speech chunks are padded by speech_pad_ms each side (default 400)
speech_pad_ms=200,
),
),
    modality="audio",
    # mode options:
    #   send-receive: bidirectional streaming (default)
    #   send: client to server only
    #   receive: server to client only
    mode="send",
additional_outputs=[
gr.Textbox(label="Transcript"),
],
additional_outputs_handler=lambda current, new: current + " " + new,
rtc_configuration=get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None,
concurrency_limit=6
)
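# FastAPI app: serves the static frontend, and stream.mount(app) adds the FastRTC routes
# the browser client uses to negotiate the WebRTC connection.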
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
stream.mount(app)
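# Index route: serve the selected HTML page and inject the TURN/RTC configuration by
# replacing the __INJECTED_RTC_CONFIG__ placeholder the static pages expect.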
@app.get("/")
async def index():
    if UI_TYPE == "screen":
        html_path = "static/index-screen.html"
    else:
        # Default to the base UI for "base" or any unrecognized UI_TYPE
        html_path = "static/index.html"
    with open(html_path) as f:
        html_content = f.read()
rtc_configuration = await get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None
logger.info(f"RTC configuration: {rtc_configuration}")
html_content = html_content.replace("__INJECTED_RTC_CONFIG__", json.dumps(rtc_configuration))
return HTMLResponse(content=html_content)
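# Server-Sent Events endpoint: relays each transcript produced for the given webrtc_id
# to the browser as "output" events.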
@app.get("/transcript")
def _(webrtc_id: str):
logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")
async def output_stream():
try:
async for output in stream.output_stream(webrtc_id):
transcript = output.args[0]
logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
yield f"event: output\ndata: {transcript}\n\n"
except Exception as e:
logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
raise
return StreamingResponse(output_stream(), media_type="text/event-stream")
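# Local entry point; the UI is served at http://localhost:7860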
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="localhost", port=7860)