# ai-server/whisper_live/backend/trt_backend.py
import json
import logging
import threading
import time
from whisper_live.backend.base import ServeClientBase
from whisper_live.transcriber.transcriber_tensorrt import WhisperTRTLLM


class ServeClientTensorRT(ServeClientBase):
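    # Class-level state for the optional shared-model mode: when single_model=True,
    # the first client creates the engine and subsequent clients reuse it; the lock
    # serializes inference on the shared engine.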
SINGLE_MODEL = None
SINGLE_MODEL_LOCK = threading.Lock()

    def __init__(
self,
websocket,
task="transcribe",
multilingual=False,
language=None,
client_uid=None,
model=None,
single_model=False,
use_py_session=False,
max_new_tokens=225,
send_last_n_segments=10,
no_speech_thresh=0.45,
clip_audio=False,
same_output_threshold=10,
):
"""
        Initialize a ServeClientTensorRT instance.
The Whisper model is initialized based on the client's language and device availability.
The transcription thread is started upon initialization. A "SERVER_READY" message is sent
to the client to indicate that the server is ready.
Args:
websocket (WebSocket): The WebSocket connection for the client.
            task (str, optional): The task type, e.g., "transcribe". Defaults to "transcribe".
            multilingual (bool, optional): Whether the client supports multilingual transcription. Defaults to False.
            language (str, optional): The language for transcription. Defaults to None.
            client_uid (str, optional): A unique identifier for the client. Defaults to None.
            model (str, optional): Path to the TensorRT-LLM Whisper engine directory. Defaults to None.
            single_model (bool, optional): Whether to share a single model instance across all client connections instead of instantiating one per client. Defaults to False.
            use_py_session (bool, optional): Whether to use the TensorRT-LLM Python session instead of the C++ session. Defaults to False (C++ session).
            max_new_tokens (int, optional): Maximum number of tokens to generate. Defaults to 225.
            send_last_n_segments (int, optional): Number of most recent segments to send to the client. Defaults to 10.
            no_speech_thresh (float, optional): Segments with a no-speech probability above this threshold are discarded. Defaults to 0.45.
            clip_audio (bool, optional): Whether to clip audio with no valid segments. Defaults to False.
            same_output_threshold (int, optional): Number of repeated identical outputs before a segment is considered final. Defaults to 10.
"""
super().__init__(
client_uid,
websocket,
send_last_n_segments,
no_speech_thresh,
clip_audio,
same_output_threshold,
)
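        # ServeClientBase initializes the shared state used below, including
        # self.lock, self.frames_np, self.transcript, self.exit and self.RATE.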
self.language = language if multilingual else "en"
self.task = task
self.eos = False
self.max_new_tokens = max_new_tokens
if single_model:
if ServeClientTensorRT.SINGLE_MODEL is None:
self.create_model(model, multilingual, use_py_session=use_py_session)
ServeClientTensorRT.SINGLE_MODEL = self.transcriber
else:
self.transcriber = ServeClientTensorRT.SINGLE_MODEL
else:
self.create_model(model, multilingual, use_py_session=use_py_session)
        # Start the background thread that transcribes audio and streams results to the client.
self.trans_thread = threading.Thread(target=self.speech_to_text)
self.trans_thread.start()
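        # Tell the client the backend is ready to accept audio.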
self.websocket.send(json.dumps({
"uid": self.client_uid,
"message": self.SERVER_READY,
"backend": "tensorrt"
}))

    def create_model(self, model, multilingual, warmup=True, use_py_session=False):
        """
        Instantiate a new model, set it as the transcriber, and optionally warm it up.
        """
self.transcriber = WhisperTRTLLM(
model,
assets_dir="assets",
device="cuda",
is_multilingual=multilingual,
language=self.language,
task=self.task,
use_py_session=use_py_session,
max_output_len=self.max_new_tokens,
)
if warmup:
self.warmup()

    def warmup(self, warmup_steps=10):
"""
Warmup TensorRT since first few inferences are slow.
Args:
warmup_steps (int): Number of steps to warm up the model for.
"""
logging.info("[INFO:] Warming up TensorRT engine..")
mel, _ = self.transcriber.log_mel_spectrogram("assets/jfk.flac")
for i in range(warmup_steps):
self.transcriber.transcribe(mel)

    def set_eos(self, eos):
"""
Sets the End of Speech (EOS) flag.
Args:
eos (bool): The value to set for the EOS flag.
"""
        # self.lock is shared with update_timestamp_offset().
        with self.lock:
            self.eos = eos

    def handle_transcription_output(self, last_segment, duration):
"""
Handle the transcription output, updating the transcript and sending data to the client.
Args:
            last_segment (str): The most recent segment from the whisper output, treated as incomplete
                because the final word may be truncated.
duration (float): Duration of the transcribed audio chunk.
"""
segments = self.prepare_segments({"text": last_segment})
self.send_transcription_to_client(segments)
if self.eos:
self.update_timestamp_offset(last_segment, duration)

    def transcribe_audio(self, input_bytes):
"""
Transcribe the audio chunk and send the results to the client.
Args:
            input_bytes (np.ndarray): The audio chunk to transcribe.
"""
        if ServeClientTensorRT.SINGLE_MODEL:
            ServeClientTensorRT.SINGLE_MODEL_LOCK.acquire()
        try:
            logging.info(f"[WhisperTensorRT:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
            mel, duration = self.transcriber.log_mel_spectrogram(input_bytes)
            # Decode with a fixed prompt of Whisper special tokens selecting the
            # language and task, with timestamps disabled.
            last_segment = self.transcriber.transcribe(
                mel,
                text_prefix=f"<|startoftranscript|><|{self.language}|><|{self.task}|><|notimestamps|>",
            )
        finally:
            # Release the shared-model lock even if transcription raises.
            if ServeClientTensorRT.SINGLE_MODEL:
                ServeClientTensorRT.SINGLE_MODEL_LOCK.release()
        if last_segment:
            self.handle_transcription_output(last_segment, duration)

    def update_timestamp_offset(self, last_segment, duration):
"""
Update timestamp offset and transcript.
Args:
last_segment (str): Last transcribed audio from the whisper model.
duration (float): Duration of the last audio chunk.
"""
        if not self.transcript or self.transcript[-1]["text"].strip() != last_segment:
            self.transcript.append({"text": last_segment + " "})
with self.lock:
self.timestamp_offset += duration

    def speech_to_text(self):
"""
        Process the audio stream in an infinite loop, continuously transcribing the speech.

        This method repeatedly pulls buffered audio frames, performs real-time transcription, and
        streams the transcribed segments to the client over the WebSocket connection. A history of
        segments is maintained to provide context.

Raises:
Exception: If there is an issue with audio processing or WebSocket communication.
"""
while True:
if self.exit:
logging.info("Exiting speech to text thread")
break
if self.frames_np is None:
time.sleep(0.02) # wait for any audio to arrive
continue
self.clip_audio_if_no_valid_segment()
input_bytes, duration = self.get_audio_chunk_for_processing()
if duration < 0.4:
continue
try:
input_sample = input_bytes.copy()
logging.info(f"[WhisperTensorRT:] Processing audio with duration: {duration}")
self.transcribe_audio(input_sample)
except Exception as e:
logging.error(f"[ERROR]: {e}")
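

# A minimal usage sketch (hypothetical server glue, not part of this module):
# the real server would create one ServeClientTensorRT per WebSocket connection,
# roughly along these lines:
#
#   client = ServeClientTensorRT(
#       websocket,                      # live WebSocket connection
#       language="en",
#       client_uid=uid,                 # uid assigned by the server (hypothetical)
#       model="assets/whisper_engine",  # path to a built TensorRT-LLM engine (assumed path)
#       single_model=True,              # share one engine across clients
#   )
#
# Incoming audio frames are then appended to the client's buffer (handled by
# ServeClientBase), and the speech_to_text() thread streams results back.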