File size: 8,584 Bytes
7222c68 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
import json
import logging
import threading
import time
import torch
from whisper_live.transcriber.transcriber_faster_whisper import WhisperModel
from whisper_live.backend.base import ServeClientBase
class ServeClientFasterWhisper(ServeClientBase):
    """Per-connection transcription client backed by a faster-whisper model.

    When a client is created with ``single_model=True``, all such clients share
    one model instance stored in ``SINGLE_MODEL``; inference on that shared
    instance is serialized with ``SINGLE_MODEL_LOCK``.
    """

    # Model instance shared by every client created with single_model=True.
    SINGLE_MODEL = None
    # Guards creation of SINGLE_MODEL and serializes inference on it.
    SINGLE_MODEL_LOCK = threading.Lock()

    def __init__(
        self,
        websocket,
        task="transcribe",
        device=None,
        language=None,
        client_uid=None,
        model="small.en",
        initial_prompt=None,
        vad_parameters=None,
        use_vad=True,
        single_model=False,
        send_last_n_segments=10,
        no_speech_thresh=0.45,
        clip_audio=False,
        same_output_threshold=10,
    ):
        """
        Initialize a ServeClient instance.

        The Whisper model is loaded for the selected device, the transcription
        thread is started, and a "SERVER_READY" message is sent to the client
        to indicate that the server is ready. If the model fails to load, an
        ERROR message is sent and the websocket is closed instead.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.
            task (str, optional): The task type, e.g., "transcribe". Defaults to "transcribe".
            device (str, optional): The device to run on, "cuda" or "cpu". When None,
                "cuda" is used if available, otherwise "cpu". Defaults to None.
            language (str, optional): The language for transcription. Ignored (forced to
                "en") when the model name ends with "en". Defaults to None.
            client_uid (str, optional): A unique identifier for the client. Defaults to None.
            model (str, optional): The whisper model size or path. If None, no model is
                loaded and no transcription thread is started. Defaults to 'small.en'.
            initial_prompt (str, optional): Prompt for whisper inference. Defaults to None.
            vad_parameters (dict, optional): Parameters forwarded to the transcriber's VAD
                filter when use_vad is True. Defaults to {"onset": 0.5}.
            use_vad (bool, optional): Whether to enable the transcriber's VAD filter.
                Defaults to True.
            single_model (bool, optional): Whether to share one model instance across all
                client connections created with this flag, instead of loading a new model
                per connection. Defaults to False.
            send_last_n_segments (int, optional): Number of most recent segments to send to the client. Defaults to 10.
            no_speech_thresh (float, optional): Segments with no speech probability above this threshold will be discarded. Defaults to 0.45.
            clip_audio (bool, optional): Whether to clip audio with no valid segments. Defaults to False.
            same_output_threshold (int, optional): Number of repeated outputs before considering it as a valid segment. Defaults to 10.
        """
        super().__init__(
            client_uid,
            websocket,
            send_last_n_segments,
            no_speech_thresh,
            clip_audio,
            same_output_threshold,
        )
        self.model_sizes = [
            "tiny", "tiny.en", "base", "base.en", "small", "small.en",
            "medium", "medium.en", "large-v2", "large-v3", "distil-small.en",
            "distil-medium.en", "distil-large-v2", "distil-large-v3",
            "large-v3-turbo", "turbo"
        ]
        self.model_size_or_path = model
        # English-only checkpoints force English. Check for None first so that
        # model=None reaches the early return below instead of crashing here.
        if model is not None and model.endswith("en"):
            self.language = "en"
        else:
            self.language = language
        self.task = task
        self.initial_prompt = initial_prompt
        self.vad_parameters = vad_parameters or {"onset": 0.5}
        # Set with the other options so the attribute exists on every path.
        self.use_vad = use_vad

        # Honour an explicitly requested device; otherwise auto-select.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cuda":
            # float16 only on GPUs whose compute-capability major version is >= 7.
            major, _ = torch.cuda.get_device_capability(device)
            self.compute_type = "float16" if major >= 7 else "float32"
        else:
            self.compute_type = "int8"

        if self.model_size_or_path is None:
            return

        logging.info(f"Using Device={device} with precision {self.compute_type}")

        try:
            if single_model:
                # Check-and-create under the lock so two clients connecting at
                # the same time do not both load (and overwrite) the shared model.
                with ServeClientFasterWhisper.SINGLE_MODEL_LOCK:
                    if ServeClientFasterWhisper.SINGLE_MODEL is None:
                        self.create_model(device)
                        ServeClientFasterWhisper.SINGLE_MODEL = self.transcriber
                    else:
                        self.transcriber = ServeClientFasterWhisper.SINGLE_MODEL
            else:
                self.create_model(device)
        except Exception as e:
            logging.error(f"Failed to load model: {e}")
            self.websocket.send(json.dumps({
                "uid": self.client_uid,
                "status": "ERROR",
                "message": f"Failed to load model: {str(self.model_size_or_path)}"
            }))
            self.websocket.close()
            return

        # threading
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()
        self.websocket.send(
            json.dumps(
                {
                    "uid": self.client_uid,
                    "message": self.SERVER_READY,
                    "backend": "faster_whisper"
                }
            )
        )

    def create_model(self, device):
        """
        Instantiate a new WhisperModel and set it as this client's transcriber.

        Args:
            device (str): Device passed through to WhisperModel ("cuda" or "cpu").
        """
        self.transcriber = WhisperModel(
            self.model_size_or_path,
            device=device,
            compute_type=self.compute_type,
            local_files_only=False,
        )

    def check_valid_model(self, model_size):
        """
        Check if it's a valid whisper model size.

        Sends an ERROR message to the client when the size is not recognised.

        Args:
            model_size (str): The name of the model size to check.

        Returns:
            str: The model size if valid, None otherwise.
        """
        if model_size not in self.model_sizes:
            self.websocket.send(
                json.dumps(
                    {
                        "uid": self.client_uid,
                        "status": "ERROR",
                        "message": f"Invalid model size {model_size}. Available choices: {self.model_sizes}"
                    }
                )
            )
            return None
        return model_size

    def set_language(self, info):
        """
        Update the language attribute from detected language information.

        The detection is accepted, logged and reported to the client only when
        its probability exceeds 0.5; otherwise the language stays unset, so
        detection is attempted again on a later chunk.

        Args:
            info (object): An object containing the detected language and its probability. This object
                must have at least two attributes: `language`, a string indicating the detected
                language, and `language_probability`, a float representing the confidence level
                of the language detection.
        """
        if info.language_probability > 0.5:
            self.language = info.language
            logging.info(f"Detected language {self.language} with probability {info.language_probability}")
            self.websocket.send(json.dumps(
                {"uid": self.client_uid, "language": self.language, "language_prob": info.language_probability}))

    def _run_transcriber(self, input_sample):
        """Call the transcriber with this client's decoding options; return (result, info)."""
        return self.transcriber.transcribe(
            input_sample,
            initial_prompt=self.initial_prompt,
            language=self.language,
            task=self.task,
            vad_filter=self.use_vad,
            vad_parameters=self.vad_parameters if self.use_vad else None,
        )

    def transcribe_audio(self, input_sample):
        """
        Transcribe the provided audio sample using the configured transcriber instance.

        Inference on the shared SINGLE_MODEL instance is serialized with
        SINGLE_MODEL_LOCK; clients with their own model do not take the lock.
        If the language has not been set, it is updated from the transcription
        information.

        Args:
            input_sample (np.array): The audio chunk to be transcribed. This should be a NumPy
                array representing the audio data.

        Returns:
            The transcription result from the transcriber. The exact format of this result
            depends on the implementation of the `transcriber.transcribe` method but typically
            includes the transcribed text.
        """
        # Decide once whether to lock. Re-reading SINGLE_MODEL for the release
        # could release a lock that was never acquired, if another client set
        # SINGLE_MODEL in between.
        shared = ServeClientFasterWhisper.SINGLE_MODEL
        if shared is not None and self.transcriber is shared:
            # `with` releases the lock even if transcription raises; a leaked
            # lock would deadlock every client sharing the model.
            with ServeClientFasterWhisper.SINGLE_MODEL_LOCK:
                result, info = self._run_transcriber(input_sample)
        else:
            result, info = self._run_transcriber(input_sample)

        if self.language is None and info is not None:
            self.set_language(info)
        return result

    def handle_transcription_output(self, result, duration):
        """
        Handle the transcription output, updating the transcript and sending data to the client.

        Args:
            result (list): The result from whisper inference, i.e. the list of segments.
            duration (float): Duration of the transcribed audio chunk.
        """
        segments = []
        if len(result):
            self.t_start = None
            last_segment = self.update_segments(result, duration)
            segments = self.prepare_segments(last_segment)

        if len(segments):
            self.send_transcription_to_client(segments)
|