import json
import logging
import threading
import time
from whisper_live.backend.base import ServeClientBase
from whisper_live.transcriber.transcriber_tensorrt import WhisperTRTLLM
class ServeClientTensorRT(ServeClientBase):
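    # When single_model is enabled, SINGLE_MODEL holds the one engine shared by
    # all clients and SINGLE_MODEL_LOCK serializes inference on it.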
SINGLE_MODEL = None
SINGLE_MODEL_LOCK = threading.Lock()
def __init__(
self,
websocket,
task="transcribe",
multilingual=False,
language=None,
client_uid=None,
model=None,
single_model=False,
use_py_session=False,
max_new_tokens=225,
send_last_n_segments=10,
no_speech_thresh=0.45,
clip_audio=False,
same_output_threshold=10,
):
"""
        Initialize a ServeClientTensorRT instance.
The Whisper model is initialized based on the client's language and device availability.
The transcription thread is started upon initialization. A "SERVER_READY" message is sent
to the client to indicate that the server is ready.
Args:
websocket (WebSocket): The WebSocket connection for the client.
task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe".
            multilingual (bool, optional): Whether the client supports multilingual transcription. Defaults to False.
            language (str, optional): The language for transcription. Defaults to None.
            client_uid (str, optional): A unique identifier for the client. Defaults to None.
            model (str, optional): Path to the Whisper TensorRT-LLM engine to load. Defaults to None.
            single_model (bool, optional): Whether to share a single model instance across all client connections
                instead of creating one per client. Defaults to False.
            use_py_session (bool, optional): Whether to use the TensorRT-LLM Python session instead of the C++
                session. Defaults to False (C++ session).
            max_new_tokens (int, optional): Maximum number of tokens to generate per inference. Defaults to 225.
send_last_n_segments (int, optional): Number of most recent segments to send to the client. Defaults to 10.
no_speech_thresh (float, optional): Segments with no speech probability above this threshold will be discarded. Defaults to 0.45.
clip_audio (bool, optional): Whether to clip audio with no valid segments. Defaults to False.
same_output_threshold (int, optional): Number of repeated outputs before considering it as a valid segment. Defaults to 10.
"""
super().__init__(
client_uid,
websocket,
send_last_n_segments,
no_speech_thresh,
clip_audio,
same_output_threshold,
)
self.language = language if multilingual else "en"
self.task = task
self.eos = False
self.max_new_tokens = max_new_tokens
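        # Share one engine across every client when single_model is set;
        # otherwise each client builds its own instance.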
if single_model:
if ServeClientTensorRT.SINGLE_MODEL is None:
self.create_model(model, multilingual, use_py_session=use_py_session)
ServeClientTensorRT.SINGLE_MODEL = self.transcriber
else:
self.transcriber = ServeClientTensorRT.SINGLE_MODEL
else:
self.create_model(model, multilingual, use_py_session=use_py_session)
        # Run transcription in a background thread so the websocket handler can keep receiving audio.
self.trans_thread = threading.Thread(target=self.speech_to_text)
self.trans_thread.start()
self.websocket.send(json.dumps({
"uid": self.client_uid,
"message": self.SERVER_READY,
"backend": "tensorrt"
}))
def create_model(self, model, multilingual, warmup=True, use_py_session=False):
"""
        Instantiates a new model, sets it as the transcriber, and optionally warms it up.
"""
self.transcriber = WhisperTRTLLM(
model,
assets_dir="assets",
device="cuda",
is_multilingual=multilingual,
language=self.language,
task=self.task,
use_py_session=use_py_session,
max_output_len=self.max_new_tokens,
)
if warmup:
self.warmup()
def warmup(self, warmup_steps=10):
"""
Warmup TensorRT since first few inferences are slow.
Args:
warmup_steps (int): Number of steps to warm up the model for.
"""
logging.info("[INFO:] Warming up TensorRT engine..")
mel, _ = self.transcriber.log_mel_spectrogram("assets/jfk.flac")
for i in range(warmup_steps):
self.transcriber.transcribe(mel)
def set_eos(self, eos):
"""
Sets the End of Speech (EOS) flag.
Args:
eos (bool): The value to set for the EOS flag.
"""
        with self.lock:
            self.eos = eos
def handle_transcription_output(self, last_segment, duration):
"""
Handle the transcription output, updating the transcript and sending data to the client.
Args:
            last_segment (str): The most recent segment from the whisper output, treated as incomplete because
                its trailing word may be truncated.
duration (float): Duration of the transcribed audio chunk.
"""
segments = self.prepare_segments({"text": last_segment})
self.send_transcription_to_client(segments)
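        # On end of speech, commit the pending segment to the transcript and advance the timestamp offset.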
if self.eos:
self.update_timestamp_offset(last_segment, duration)
def transcribe_audio(self, input_bytes):
"""
Transcribe the audio chunk and send the results to the client.
Args:
            input_bytes (np.ndarray): The audio chunk to transcribe.
"""
        if ServeClientTensorRT.SINGLE_MODEL:
            ServeClientTensorRT.SINGLE_MODEL_LOCK.acquire()
        try:
            logging.info(f"[WhisperTensorRT:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
            mel, duration = self.transcriber.log_mel_spectrogram(input_bytes)
            last_segment = self.transcriber.transcribe(
                mel,
                text_prefix=f"<|startoftranscript|><|{self.language}|><|{self.task}|><|notimestamps|>",
            )
        finally:
            # Release in a finally block so a failed inference cannot leave the shared lock held.
            if ServeClientTensorRT.SINGLE_MODEL:
                ServeClientTensorRT.SINGLE_MODEL_LOCK.release()
if last_segment:
self.handle_transcription_output(last_segment, duration)
def update_timestamp_offset(self, last_segment, duration):
"""
Update timestamp offset and transcript.
Args:
last_segment (str): Last transcribed audio from the whisper model.
duration (float): Duration of the last audio chunk.
"""
        if not self.transcript or self.transcript[-1]["text"].strip() != last_segment:
            self.transcript.append({"text": last_segment + " "})
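        # Advance the stream offset so the audio that produced this segment is not transcribed again.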
with self.lock:
self.timestamp_offset += duration
def speech_to_text(self):
"""
Process an audio stream in an infinite loop, continuously transcribing the speech.
This method continuously receives audio frames, performs real-time transcription, and sends
transcribed segments to the client via a WebSocket connection.
        It utilizes the Whisper TensorRT-LLM engine to transcribe the audio, continuously processing and streaming
        results. A history of recent segments is maintained to provide context to the client.
Raises:
Exception: If there is an issue with audio processing or WebSocket communication.
"""
while True:
if self.exit:
logging.info("Exiting speech to text thread")
break
if self.frames_np is None:
time.sleep(0.02) # wait for any audio to arrive
continue
self.clip_audio_if_no_valid_segment()
input_bytes, duration = self.get_audio_chunk_for_processing()
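            # Wait until at least 0.4 s of new audio has accumulated before running inference.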
            if duration < 0.4:
                time.sleep(0.01)  # not enough new audio yet; back off briefly instead of spinning
                continue
try:
input_sample = input_bytes.copy()
logging.info(f"[WhisperTensorRT:] Processing audio with duration: {duration}")
self.transcribe_audio(input_sample)
except Exception as e:
logging.error(f"[ERROR]: {e}")
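
# Minimal usage sketch (illustrative, not part of the module): assumes an already
# accepted websocket object exposing `send`, and a Whisper TensorRT-LLM engine
# built on disk -- the engine path below is a placeholder, not a bundled default.
#
#   client = ServeClientTensorRT(
#       websocket,
#       task="transcribe",
#       multilingual=True,
#       language="en",
#       client_uid="client-1",
#       model="path/to/whisper_trt_engine",
#       single_model=True,
#   )
#   # Audio frames are fed through the ServeClientBase buffer; to stop the
#   # worker, set the exit flag checked by speech_to_text():
#   client.exit = True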