import os
import logging
import json
import torch
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    AlgoOptions,
    SileroVadOptions,
    audio_to_bytes,
    get_cloudflare_turn_credentials_async,
)
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available
from utils.device import get_device, get_torch_and_np_dtypes
load_dotenv()
logger = logging.getLogger(__name__)
UI_MODE = os.getenv("UI_MODE", "fastapi").lower() # gradio | fastapi
UI_TYPE = os.getenv("UI_TYPE", "base").lower() # base | screen
APP_MODE = os.getenv("APP_MODE", "local").lower() # local | deployed
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
LANGUAGE = os.getenv("LANGUAGE", "english").lower()
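# Example .env for local development (illustrative values only; every setting has a default above):
#   UI_MODE=fastapi
#   UI_TYPE=base
#   APP_MODE=local
#   MODEL_ID=openai/whisper-large-v3-turbo
#   LANGUAGE=english
#   HF_TOKEN=hf_...   # only needed when APP_MODE=deployed (Cloudflare TURN credentials)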
device = get_device(force_cpu=False)
use_device_map = device == "cuda"
try_compile_model = device == "cuda" or (device == "mps" and torch.__version__ >= "2.7.0")
try_use_flash_attention = False
#try_use_flash_attention = True if device == "cuda" and is_flash_attn_2_available() else False
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
logger.info(f"Loading Whisper model: {MODEL_ID}")
logger.info(f"Using language: {LANGUAGE}")
# Initialize the model (use flash attention on cuda if possible)
try:
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="flash_attention_2" if try_use_flash_attention else "sdpa",
        device_map="auto" if use_device_map else None,
    )
    if not use_device_map:
        model.to(device)
except RuntimeError as e:
    try:
        logger.warning(f"Falling back to device_map=None ({e})")
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            MODEL_ID,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation="flash_attention_2" if try_use_flash_attention else "sdpa",
            device_map=None,
        )
        model.to(device)
    except RuntimeError as e:
        try:
            logger.warning(f"Disabling flash attention ({e})")
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                MODEL_ID,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                attn_implementation="sdpa",
            )
            model.to(device)
        except Exception as e:
            logger.error(f"Error loading ASR model: {e}")
            logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
            raise
processor = AutoProcessor.from_pretrained(MODEL_ID)
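# Assemble the ASR pipeline from the loaded model plus the processor's tokenizer and feature extractor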
transcribe_pipeline = pipeline(
task="automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
torch_dtype=torch_dtype
)
# Try to compile the model
try:
    if try_compile_model:
        transcribe_pipeline.model = torch.compile(transcribe_pipeline.model, mode="max-autotune")
    else:
        logger.warning("Proceeding without compiling the model (requirements not met)")
except Exception as e:
    logger.warning(f"Error compiling model: {e}")
    logger.warning("Proceeding without compiling the model")
# Warm up the model with one second of dummy audio
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.random.rand(16000).astype(np_dtype)
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")
async def transcribe(audio: tuple[int, np.ndarray]):
    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")
    outputs = transcribe_pipeline(
        audio_to_bytes(audio),  # pass bytes
        #audio_array,  # pass numpy array
        chunk_length_s=5,
        batch_size=1,
        generate_kwargs={
            #"compression_ratio_threshold": 1.35,
            #"no_speech_threshold": 0.6,
            #"logprob_threshold": -1.0,
            #"num_beams": 1,
            #"condition_on_prev_tokens": False,
            #"temperature": (0.0, 0.2, 0.4, 0.6),
            #"return_timestamps": True,
            #"task": "transcribe",
            "task": "translate",
            "language": LANGUAGE,
        },
    )
    yield AdditionalOutputs(outputs["text"].strip())
logger.info("Initializing FastRTC stream")
stream = Stream(
    handler=ReplyOnPause(
        transcribe,
        algo_options=AlgoOptions(
            # Duration in seconds of audio chunks passed to the VAD model (default 0.6)
            audio_chunk_duration=0.5,
            # If a chunk has more than started_talking_threshold seconds of speech, the user started talking (default 0.2)
            started_talking_threshold=0.1,
            # If, after the user started speaking, a chunk has less than speech_threshold seconds of speech, the user stopped speaking (default 0.1)
            speech_threshold=0.1,
            # Max duration of speech before the handler is triggered, even if the VAD model does not detect a pause (default -inf)
            max_continuous_speech_s=15,
        ),
        model_options=SileroVadOptions(
            # Threshold for what is considered speech (default 0.5)
            threshold=0.5,
            # Final speech chunks shorter than min_speech_duration_ms are thrown out (default 250)
            min_speech_duration_ms=250,
            # Speech chunks longer than max_speech_duration_s are split at the last silence longer than 100ms (if any),
            # or just before max_speech_duration_s; used internally by the VAD algorithm (default float('inf'))
            max_speech_duration_s=5,
            # Wait this many ms at the end of each speech chunk before separating it (default 2000)
            min_silence_duration_ms=200,
            # Chunk size for the VAD model; can be 512, 1024, or 1536 for a 16kHz sample rate (default 1024)
            window_size_samples=1024,
            # Final speech chunks are padded by speech_pad_ms on each side (default 400)
            speech_pad_ms=200,
        ),
    ),
    # mode options: send-receive = bidirectional streaming (default),
    # send = client to server only, receive = server to client only
    modality="audio",
    mode="send",
    additional_outputs=[
        gr.Textbox(label="Transcript"),
    ],
    # Append each new transcript chunk to the running transcript
    additional_outputs_handler=lambda current, new: current + " " + new,
    rtc_configuration=get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None,
    concurrency_limit=6,
)
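# HTTP layer: FastAPI serves the static assets, FastRTC mounts its WebRTC signalling routes on
# the same app, and the custom routes below provide the UI page and the transcript SSE stream.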
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
stream.mount(app)
@app.get("/")
async def index():
if UI_TYPE == "base":
html_content = open("static/index.html").read()
elif UI_TYPE == "screen":
html_content = open("static/index-screen.html").read()
rtc_configuration = await get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None
logger.info(f"RTC configuration: {rtc_configuration}")
html_content = html_content.replace("__INJECTED_RTC_CONFIG__", json.dumps(rtc_configuration))
return HTMLResponse(content=html_content)
@app.get("/transcript")
def _(webrtc_id: str):
logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")
async def output_stream():
try:
async for output in stream.output_stream(webrtc_id):
transcript = output.args[0]
logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
yield f"event: output\ndata: {transcript}\n\n"
except Exception as e:
logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
raise
return StreamingResponse(output_stream(), media_type="text/event-stream")
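# Illustrative only (not used by the app): a minimal client sketch for the /transcript stream,
# assuming the httpx package is available and `webrtc_id` comes from an active connection:
#
#   import httpx
#
#   async def follow_transcript(webrtc_id: str):
#       async with httpx.AsyncClient(timeout=None) as client:
#           async with client.stream("GET", "http://localhost:7860/transcript",
#                                     params={"webrtc_id": webrtc_id}) as resp:
#               async for line in resp.aiter_lines():
#                   if line.startswith("data: "):
#                       print(line[len("data: "):])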
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="localhost", port=7860) |