import os
import logging
import json
import torch

import gradio as gr
import numpy as np
from dotenv import load_dotenv
from packaging.version import Version
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    AlgoOptions,
    SileroVadOptions,
    audio_to_bytes,
    get_cloudflare_turn_credentials_async,
)
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available

from utils.device import get_device, get_torch_and_np_dtypes

load_dotenv()
logger = logging.getLogger(__name__)


UI_MODE = os.getenv("UI_MODE", "fastapi").lower() # gradio | fastapi
UI_TYPE = os.getenv("UI_TYPE", "base").lower() # base | screen
APP_MODE = os.getenv("APP_MODE", "local").lower() # local | deployed
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
LANGUAGE = os.getenv("LANGUAGE", "english").lower()
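
# A sample .env for the variables above (illustrative values, not requirements;
# HF_TOKEN is only read when APP_MODE=deployed):
#
#   UI_MODE=fastapi
#   UI_TYPE=base
#   APP_MODE=local
#   MODEL_ID=openai/whisper-large-v3-turbo
#   LANGUAGE=english
#   HF_TOKEN=<your Hugging Face token>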

device = get_device(force_cpu=False)
use_device_map = device == "cuda"
# Compare parsed versions: a plain string comparison breaks at e.g. "2.10.0" < "2.7.0"
try_compile_model = device == "cuda" or (device == "mps" and Version(torch.__version__) >= Version("2.7.0"))
try_use_flash_attention = False
#try_use_flash_attention = device == "cuda" and is_flash_attn_2_available()

torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")

logger.info(f"Loading Whisper model: {MODEL_ID}")
logger.info(f"Using language: {LANGUAGE}")

# Initialize the model, falling back progressively: flash attention (if enabled)
# with device_map="auto" first, then device_map=None, then plain sdpa attention
try:
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="flash_attention_2" if try_use_flash_attention else "sdpa",
        device_map="auto" if use_device_map else None,
    )
    if not use_device_map:
        model.to(device)
except RuntimeError as e:
    logger.warning(f"Initial model load failed: {e}")
    try:
        logger.warning("Falling back to device_map=None")
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            MODEL_ID,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation="flash_attention_2" if try_use_flash_attention else "sdpa",
            device_map=None,
        )
        model.to(device)
    except RuntimeError as e:
        logger.warning(f"Fallback load failed: {e}")
        try:
            logger.warning("Disabling flash attention")
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                MODEL_ID,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                attn_implementation="sdpa",
            )
            model.to(device)
        except Exception as e:
            logger.error(f"Error loading ASR model: {e}")
            logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
            raise

processor = AutoProcessor.from_pretrained(MODEL_ID)

transcribe_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype
)

# Try to compile the model
try:
    if try_compile_model:
        transcribe_pipeline.model = torch.compile(transcribe_pipeline.model, mode="max-autotune")
    else:
        logger.warning("Proceeding without compiling the model (requirements not met)")
except Exception as e:
    logger.warning(f"Error compiling model: {e}")
    logger.warning("Proceeding without compiling the model")

# Warm up the model with 1 s of random audio at 16 kHz (a real forward pass, so
# any torch.compile work happens here instead of on the first user request)
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.random.rand(16000).astype(np_dtype)
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")

async def transcribe(audio: tuple[int, np.ndarray]):
    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")
    
    outputs = transcribe_pipeline(
        audio_to_bytes(audio), # pass bytes
        #audio_array, # pass numpy array
        chunk_length_s=5,
        batch_size=1,
        generate_kwargs={
            #"compression_ratio_threshold": 1.35,
            #"no_speech_threshold": 0.6,
            #"logprob_threshold": -1.0,
            #"num_beams": 1,
            #"condition_on_prev_tokens": False,
            #"temperature": (0.0, 0.2, 0.4, 0.6),
            #"return_timestamps": True,
            #"task": "transcribe",
            "task": "translate",
            "language": LANGUAGE,
        }
    )
    yield AdditionalOutputs(outputs["text"].strip())
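
# A minimal sketch (not part of the app) for smoke-testing the handler without
# WebRTC; assumes 1 s of silence at 16 kHz in the (sample_rate, array) tuple
# shape that ReplyOnPause delivers:
#
#   import asyncio
#   async def _smoke_test():
#       async for out in transcribe((16000, np.zeros(16000, dtype=np_dtype))):
#           print(out.args[0])
#   asyncio.run(_smoke_test())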

logger.info("Initializing FastRTC stream")
stream = Stream(
    handler=ReplyOnPause(
        transcribe,
        algo_options=AlgoOptions(
            # Duration in seconds of audio chunks passed to the VAD model (default 0.6) 
            audio_chunk_duration=0.5,
            # If the chunk has more than started_talking_threshold seconds of speech, the user started talking (default 0.2)
            started_talking_threshold=0.1,
            # If, after the user started speaking, there is a chunk with less than speech_threshold seconds of speech, the user stopped speaking. (default 0.1)
            speech_threshold=0.1,
            # Max duration of speech chunks before the handler is triggered, even if a pause is not detected by the VAD model. (default -inf)
            max_continuous_speech_s=15
        ),
        model_options=SileroVadOptions(
            # Threshold for what is considered speech (default 0.5)
            threshold=0.5,
            # Final speech chunks shorter than min_speech_duration_ms are thrown out (default 250)
            min_speech_duration_ms=250,
            # Max duration of speech chunks; longer chunks are split at the timestamp of the last
            # silence lasting more than 100 ms (if any), or just before max_speech_duration_s.
            # Used internally by the VAD algorithm to split the audio passed to it. (default float('inf'))
            max_speech_duration_s=5,
            # Wait this many ms at the end of each speech chunk before separating it (default 2000)
            min_silence_duration_ms=200,
            # Chunk size for the VAD model; can be 512, 1024, or 1536 at a 16 kHz sample rate (default 1024)
            window_size_samples=1024,
            # Final speech chunks are padded by speech_pad_ms each side (default 400)
            speech_pad_ms=200,
        ),
    ),
    modality="audio",
    # send-receive: bidirectional streaming (default)
    # send: client to server only
    # receive: server to client only
    mode="send",
    additional_outputs=[
        gr.Textbox(label="Transcript"),
    ],
    additional_outputs_handler=lambda current, new: current + " " + new,
    rtc_configuration=get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None,
    concurrency_limit=6
)

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
stream.mount(app)

@app.get("/")
async def index():
    if UI_TYPE == "screen":
        with open("static/index-screen.html") as f:
            html_content = f.read()
    else:  # default to the base UI so html_content is always bound
        with open("static/index.html") as f:
            html_content = f.read()
    
    rtc_configuration = await get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None
    logger.info(f"RTC configuration: {rtc_configuration}")
    html_content = html_content.replace("__INJECTED_RTC_CONFIG__", json.dumps(rtc_configuration))
    return HTMLResponse(content=html_content)

@app.get("/transcript")
def transcript_endpoint(webrtc_id: str):
    logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")
    async def output_stream():
        try:
            async for output in stream.output_stream(webrtc_id):
                transcript = output.args[0]
                logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
                yield f"event: output\ndata: {transcript}\n\n"
        except Exception as e:
            logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
            raise

    return StreamingResponse(output_stream(), media_type="text/event-stream")
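
# The response above is a standard Server-Sent Events stream; given the
# webrtc_id of a live session (hypothetical value below), it can be watched
# from a terminal:
#
#   curl -N "http://localhost:7860/transcript?webrtc_id=abc123"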

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="localhost", port=7860)
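
# Run with `python app.py` (assumes this file is named app.py) and open
# http://localhost:7860; alternatively `uvicorn app:app --port 7860` works,
# since `app` is a module-level FastAPI instance.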