# ai-server/whisper_live/backend/faster_whisper_backend.py
import json
import logging
import os
import threading

import torch
from whisper_live.transcriber.transcriber_faster_whisper import WhisperModel
from whisper_live.backend.base import ServeClientBase


class ServeClientFasterWhisper(ServeClientBase):
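    # Shared model instance and lock used when clients opt into single_model:
    # every client then reuses one WhisperModel, and SINGLE_MODEL_LOCK
    # serializes concurrent transcribe() calls against it.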
SINGLE_MODEL = None
SINGLE_MODEL_LOCK = threading.Lock()

    def __init__(
self,
websocket,
task="transcribe",
device=None,
language=None,
client_uid=None,
model="small.en",
initial_prompt=None,
vad_parameters=None,
use_vad=True,
single_model=False,
send_last_n_segments=10,
no_speech_thresh=0.45,
clip_audio=False,
same_output_threshold=10,
):
"""
        Initialize a ServeClientFasterWhisper instance.
The Whisper model is initialized based on the client's language and device availability.
The transcription thread is started upon initialization. A "SERVER_READY" message is sent
to the client to indicate that the server is ready.
Args:
websocket (WebSocket): The WebSocket connection for the client.
task (str, optional): The task type, e.g., "transcribe". Defaults to "transcribe".
device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None.
language (str, optional): The language for transcription. Defaults to None.
client_uid (str, optional): A unique identifier for the client. Defaults to None.
            model (str, optional): The whisper model size. Defaults to 'small.en'.
            initial_prompt (str, optional): Prompt for whisper inference. Defaults to None.
            vad_parameters (dict, optional): Options for the voice activity detector. Defaults to {"onset": 0.5}.
            use_vad (bool, optional): Whether to filter out non-speech audio with VAD. Defaults to True.
            single_model (bool, optional): Whether to share a single model instance across all client connections instead of loading one per client. Defaults to False.
send_last_n_segments (int, optional): Number of most recent segments to send to the client. Defaults to 10.
            no_speech_thresh (float, optional): Segments with a no-speech probability above this threshold will be discarded. Defaults to 0.45.
clip_audio (bool, optional): Whether to clip audio with no valid segments. Defaults to False.
same_output_threshold (int, optional): Number of repeated outputs before considering it as a valid segment. Defaults to 10.
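
        Example (illustrative sketch; assumes `ws` is a websocket-like object
        exposing send() and close()):
            client = ServeClientFasterWhisper(ws, language="en", model="small.en")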
"""
super().__init__(
client_uid,
websocket,
send_last_n_segments,
no_speech_thresh,
clip_audio,
same_output_threshold,
)
self.model_sizes = [
"tiny", "tiny.en", "base", "base.en", "small", "small.en",
"medium", "medium.en", "large-v2", "large-v3", "distil-small.en",
"distil-medium.en", "distil-large-v2", "distil-large-v3",
"large-v3-turbo", "turbo"
]
        # Accept a local model path as-is; otherwise validate the model size
        # against the known choices (check_valid_model returns None and
        # notifies the client if the size is unknown).
        if model is not None and os.path.exists(model):
            self.model_size_or_path = model
        else:
            self.model_size_or_path = self.check_valid_model(model)
        self.language = "en" if self.model_size_or_path and self.model_size_or_path.endswith(".en") else language
self.task = task
self.initial_prompt = initial_prompt
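        # "onset" is the speech-onset probability threshold used by the
        # bundled Silero VAD; 0.5 mirrors the library default.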
self.vad_parameters = vad_parameters or {"onset": 0.5}
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
major, _ = torch.cuda.get_device_capability(device)
self.compute_type = "float16" if major >= 7 else "float32"
else:
self.compute_type = "int8"
if self.model_size_or_path is None:
return
logging.info(f"Using Device={device} with precision {self.compute_type}")
try:
if single_model:
if ServeClientFasterWhisper.SINGLE_MODEL is None:
self.create_model(device)
ServeClientFasterWhisper.SINGLE_MODEL = self.transcriber
else:
self.transcriber = ServeClientFasterWhisper.SINGLE_MODEL
else:
self.create_model(device)
except Exception as e:
logging.error(f"Failed to load model: {e}")
self.websocket.send(json.dumps({
"uid": self.client_uid,
"status": "ERROR",
"message": f"Failed to load model: {str(self.model_size_or_path)}"
}))
self.websocket.close()
return
self.use_vad = use_vad
        # Start the background transcription loop (speech_to_text is
        # inherited from ServeClientBase).
        self.trans_thread = threading.Thread(target=self.speech_to_text)
self.trans_thread.start()
self.websocket.send(
json.dumps(
{
"uid": self.client_uid,
"message": self.SERVER_READY,
"backend": "faster_whisper"
}
)
)

    def create_model(self, device):
"""
        Instantiates a new WhisperModel on the given device and sets it as the transcriber.
"""
self.transcriber = WhisperModel(
self.model_size_or_path,
device=device,
compute_type=self.compute_type,
local_files_only=False,
)

    def check_valid_model(self, model_size):
"""
Check if it's a valid whisper model size.
Args:
model_size (str): The name of the model size to check.
Returns:
str: The model size if valid, None otherwise.
"""
if model_size not in self.model_sizes:
self.websocket.send(
json.dumps(
{
"uid": self.client_uid,
"status": "ERROR",
"message": f"Invalid model size {model_size}. Available choices: {self.model_sizes}"
}
)
)
return None
return model_size

    def set_language(self, info):
"""
Updates the language attribute based on the detected language information.
Args:
info (object): An object containing the detected language and its probability. This object
must have at least two attributes: `language`, a string indicating the detected
language, and `language_probability`, a float representing the confidence level
of the language detection.
"""
if info.language_probability > 0.5:
self.language = info.language
logging.info(f"Detected language {self.language} with probability {info.language_probability}")
self.websocket.send(json.dumps(
{"uid": self.client_uid, "language": self.language, "language_prob": info.language_probability}))

    def transcribe_audio(self, input_sample):
"""
Transcribes the provided audio sample using the configured transcriber instance.
If the language has not been set, it updates the session's language based on the transcription
information.
Args:
input_sample (np.array): The audio chunk to be transcribed. This should be a NumPy
array representing the audio data.
Returns:
The transcription result from the transcriber. The exact format of this result
depends on the implementation of the `transcriber.transcribe` method but typically
includes the transcribed text.
"""
        if ServeClientFasterWhisper.SINGLE_MODEL:
            ServeClientFasterWhisper.SINGLE_MODEL_LOCK.acquire()
        try:
            result, info = self.transcriber.transcribe(
                input_sample,
                initial_prompt=self.initial_prompt,
                language=self.language,
                task=self.task,
                vad_filter=self.use_vad,
                vad_parameters=self.vad_parameters if self.use_vad else None)
        finally:
            # Always release the shared-model lock, even if transcription raises.
            if ServeClientFasterWhisper.SINGLE_MODEL:
                ServeClientFasterWhisper.SINGLE_MODEL_LOCK.release()
if self.language is None and info is not None:
self.set_language(info)
return result

    def handle_transcription_output(self, result, duration):
"""
Handle the transcription output, updating the transcript and sending data to the client.
Args:
            result (list): The list of segments returned by whisper inference.
duration (float): Duration of the transcribed audio chunk.
"""
segments = []
if len(result):
self.t_start = None
last_segment = self.update_segments(result, duration)
segments = self.prepare_segments(last_segment)
if len(segments):
self.send_transcription_to_client(segments)
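

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: exercise the
    # same WhisperModel this backend wraps on one second of silence. The
    # model name ("tiny.en") and the 16 kHz mono float32 input format are
    # assumptions chosen to keep the check small; Whisper models consume
    # 16 kHz audio. float16 assumes a reasonably modern GPU; see the
    # compute-capability check in __init__ above.
    import numpy as np

    device = "cuda" if torch.cuda.is_available() else "cpu"
    compute_type = "float16" if device == "cuda" else "int8"
    model = WhisperModel("tiny.en", device=device, compute_type=compute_type)
    silence = np.zeros(16000, dtype=np.float32)  # one second of 16 kHz silence
    segments, info = model.transcribe(silence, language="en", vad_filter=True)
    for segment in segments:
        print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")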