import json
import logging
import threading
import time

import numpy as np


class ServeClientBase(object):
    RATE = 16000
    SERVER_READY = "SERVER_READY"
    DISCONNECT = "DISCONNECT"

    client_uid: str
    """A unique identifier for the client."""
    websocket: object
    """The WebSocket connection for the client."""
    send_last_n_segments: int
    """Number of most recent segments to send to the client."""
    no_speech_thresh: float
    """Segments with a no-speech probability above this threshold are discarded."""
    clip_audio: bool
    """Whether to clip audio that yields no valid segments."""
    same_output_threshold: int
    """Number of consecutive identical outputs before the in-progress segment is finalized."""

    def __init__(
        self,
        client_uid,
        websocket,
        send_last_n_segments=10,
        no_speech_thresh=0.45,
        clip_audio=False,
        same_output_threshold=10,
    ):
        self.client_uid = client_uid
        self.websocket = websocket
        self.send_last_n_segments = send_last_n_segments
        self.no_speech_thresh = no_speech_thresh
        self.clip_audio = clip_audio
        self.same_output_threshold = same_output_threshold

        self.frames = b""
        self.timestamp_offset = 0.0
        self.frames_np = None
        self.frames_offset = 0.0
        self.text = []
        self.current_out = ""
        self.prev_out = ""
        self.exit = False
        self.same_output_count = 0
        self.transcript = []
        self.end_time_for_same_output = None
        # Set by subclasses once the transcription language is known or detected;
        # speech_to_text() skips output handling until then.
        self.language = None

        self.lock = threading.Lock()

    def speech_to_text(self):
        """
        Process the audio stream in an infinite loop, continuously transcribing the speech.

        This method continuously receives audio frames, performs real-time transcription, and sends
        transcribed segments to the client via the WebSocket connection.

        If the client's language has not been detected, it waits for 30 seconds of audio input
        before making a language prediction. It uses the Whisper ASR model to transcribe the audio,
        continuously processing and streaming results. Segments are sent to the client in real time,
        and a history of segments is maintained to provide context.

        Raises:
            Exception: If there is an issue with audio processing or WebSocket communication.
        """
        while True:
            if self.exit:
                logging.info("Exiting speech to text thread")
                break

            # No audio received yet; keep polling.
            if self.frames_np is None:
                continue

            if self.clip_audio:
                self.clip_audio_if_no_valid_segment()

            input_bytes, duration = self.get_audio_chunk_for_processing()
            if duration < 1.0:
                time.sleep(0.1)  # wait for enough audio to accumulate
                continue
            try:
                input_sample = input_bytes.copy()
                result = self.transcribe_audio(input_sample)

                if result is None or self.language is None:
                    # No usable output yet (e.g. language still being detected);
                    # skip past this chunk and try again.
                    self.timestamp_offset += duration
                    time.sleep(0.25)
                    continue
                self.handle_transcription_output(result, duration)

            except Exception as e:
                logging.error(f"[ERROR]: Failed to transcribe audio chunk: {e}")
                time.sleep(0.01)

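    # This loop is typically run on a dedicated thread per client; a minimal
    # sketch (the surrounding server code is assumed, not shown here):
    #
    #     client = SomeServeClientSubclass(uid, websocket)
    #     threading.Thread(target=client.speech_to_text, daemon=True).start()
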
    def transcribe_audio(self, input_sample):
        """Transcribe a chunk of audio. Must be implemented by subclasses."""
        raise NotImplementedError

    def handle_transcription_output(self, result, duration):
        """Process a transcription result. Must be implemented by subclasses."""
        raise NotImplementedError

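    # A minimal sketch of how a concrete subclass might wire these hooks
    # together (`self.model` and its `transcribe` method are hypothetical
    # stand-ins for a real Whisper backend):
    #
    #     class ServeClientExample(ServeClientBase):
    #         def transcribe_audio(self, input_sample):
    #             return self.model.transcribe(input_sample)
    #
    #         def handle_transcription_output(self, result, duration):
    #             if not result:
    #                 return
    #             last_segment = self.update_segments(result, duration)
    #             segments = self.prepare_segments(last_segment)
    #             if segments:
    #                 self.send_transcription_to_client(segments)
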
    def format_segment(self, start, end, text, completed=False):
        """
        Formats a transcription segment with precise start and end times alongside the transcribed text.

        Args:
            start (float): The start time of the transcription segment in seconds.
            end (float): The end time of the transcription segment in seconds.
            text (str): The transcribed text corresponding to the segment.
            completed (bool, optional): Whether the segment is finalized. Defaults to False.

        Returns:
            dict: A dictionary representing the formatted transcription segment, including
                'start' and 'end' times as strings with three decimal places, the 'text'
                of the transcription, and the 'completed' flag.
        """
        return {
            'start': "{:.3f}".format(start),
            'end': "{:.3f}".format(end),
            'text': text,
            'completed': completed
        }

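    # For example, format_segment(1.0, 2.5, "hello", completed=True) returns
    # {'start': '1.000', 'end': '2.500', 'text': 'hello', 'completed': True}.
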
    def add_frames(self, frame_np):
        """
        Add audio frames to the ongoing audio stream buffer.

        This method is responsible for maintaining the audio stream buffer, allowing the continuous
        addition of audio frames as they are received. It also ensures that the buffer does not exceed
        a specified size to prevent excessive memory usage.

        If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest
        30 seconds of audio data to maintain a reasonable buffer size. If the buffer is empty, it
        initializes it with the provided audio frame. The audio stream buffer is used for real-time
        processing of audio data for transcription.

        Args:
            frame_np (numpy.ndarray): The audio frame data as a NumPy array.
        """
        with self.lock:
            if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
                # Drop the oldest 30 seconds and advance the buffer offset accordingly.
                self.frames_offset += 30.0
                self.frames_np = self.frames_np[int(30 * self.RATE):]
                # Keep the transcription pointer inside the trimmed buffer.
                if self.timestamp_offset < self.frames_offset:
                    self.timestamp_offset = self.frames_offset
            if self.frames_np is None:
                self.frames_np = frame_np.copy()
            else:
                self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)

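    # With RATE = 16000, the 45-second cap is 720,000 samples; once exceeded,
    # the oldest 480,000 samples (30 s) are dropped and frames_offset advances
    # by 30.0 so absolute timestamps remain consistent.
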
    def clip_audio_if_no_valid_segment(self):
        """
        Update the timestamp offset based on audio buffer status.

        If more than 25 seconds of audio remain unprocessed, Whisper has produced no valid
        segment for that stretch; jump the timestamp offset forward so that only the most
        recent 5 seconds of audio are kept for processing.
        """
        with self.lock:
            if self.frames_np[int((self.timestamp_offset - self.frames_offset) * self.RATE):].shape[0] > 25 * self.RATE:
                duration = self.frames_np.shape[0] / self.RATE
                self.timestamp_offset = self.frames_offset + duration - 5

    def get_audio_chunk_for_processing(self):
        """
        Retrieves the next chunk of audio data for processing based on the current offsets.

        Calculates which part of the audio data should be processed next, based on the
        difference between the current timestamp offset and the frames offset, scaled by
        the audio sample rate (RATE). It then returns this chunk of audio data along with
        its duration in seconds.

        Returns:
            tuple: A tuple containing:
                - input_bytes (np.ndarray): The next chunk of audio data to be processed.
                - duration (float): The duration of the audio chunk in seconds.
        """
        with self.lock:
            samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
            input_bytes = self.frames_np[int(samples_take):].copy()
        duration = input_bytes.shape[0] / self.RATE
        return input_bytes, duration

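    # Worked example: with timestamp_offset = 42.0 and frames_offset = 30.0,
    # the first (42.0 - 30.0) * 16000 = 192,000 samples of the buffer are
    # skipped as already transcribed, and everything after them is returned.
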
    def prepare_segments(self, last_segment=None):
        """
        Prepares the segments of transcribed text to be sent to the client.

        This method compiles the recent segments of transcribed text, ensuring that only the
        specified number of the most recent segments are included. It also appends the most
        recent segment if provided (which is considered incomplete because the last word may
        be truncated in the audio chunk).

        Args:
            last_segment (dict, optional): The most recent segment of transcribed text to be
                added to the list of segments. Defaults to None.

        Returns:
            list: A list of transcribed text segments to be sent to the client.
        """
        segments = []
        if len(self.transcript) >= self.send_last_n_segments:
            segments = self.transcript[-self.send_last_n_segments:].copy()
        else:
            segments = self.transcript.copy()
        if last_segment is not None:
            segments = segments + [last_segment]
        return segments

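    # For example, with send_last_n_segments = 10 and 12 completed segments in
    # self.transcript, only the last 10 are sent, plus the in-progress segment
    # when one is supplied.
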
    def get_audio_chunk_duration(self, input_bytes):
        """
        Calculates the duration of the provided audio chunk.

        Args:
            input_bytes (numpy.ndarray): The audio chunk for which to calculate the duration.

        Returns:
            float: The duration of the audio chunk in seconds.
        """
        return input_bytes.shape[0] / self.RATE

    def send_transcription_to_client(self, segments):
        """
        Sends the specified transcription segments to the client over the websocket connection.

        This method formats the transcription segments into a JSON object and attempts to send
        this object to the client. If an error occurs during the send operation, it logs the error.

        Args:
            segments (list): A list of transcription segments to be sent to the client.
        """
        try:
            self.websocket.send(
                json.dumps({
                    "uid": self.client_uid,
                    "segments": segments,
                })
            )
        except Exception as e:
            logging.error(f"[ERROR]: Sending data to client: {e}")

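    # The resulting wire format is a JSON object such as:
    #     {"uid": "<client_uid>",
    #      "segments": [{"start": "0.000", "end": "1.500",
    #                    "text": "hello", "completed": true}]}
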
    def disconnect(self):
        """
        Notify the client of disconnection and send a disconnect message.

        This method sends a disconnect message to the client via the WebSocket connection to
        notify them that the transcription service is disconnecting gracefully.
        """
        self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.DISCONNECT
        }))

    def cleanup(self):
        """
        Perform cleanup tasks before exiting the transcription service.

        This method sets the exit flag so that the transcription thread exits gracefully;
        resources associated with the transcription process are then released by that thread.
        """
        logging.info("Cleaning up.")
        self.exit = True

    # Whisper backends differ in segment attribute names (e.g. start vs. start_ts),
    # so these helpers normalize access with sensible defaults.
    def get_segment_no_speech_prob(self, segment):
        return getattr(segment, "no_speech_prob", 0)

    def get_segment_start(self, segment):
        return getattr(segment, "start", getattr(segment, "start_ts", 0))

    def get_segment_end(self, segment):
        return getattr(segment, "end", getattr(segment, "end_ts", 0))

    def update_segments(self, segments, duration):
        """
        Processes the segments from Whisper and updates the transcript.
        Uses helper methods to account for differences between backends.

        Args:
            segments (list): List of segments returned by the transcriber.
            duration (float): Duration of the current audio chunk.

        Returns:
            dict or None: The last processed segment (if any).
        """
        offset = None
        self.current_out = ''
        last_segment = None

        # All segments except the last are treated as complete, provided the
        # last one passes the no-speech threshold.
        if len(segments) > 1 and self.get_segment_no_speech_prob(segments[-1]) <= self.no_speech_thresh:
            for s in segments[:-1]:
                text_ = s.text
                self.text.append(text_)
                with self.lock:
                    start = self.timestamp_offset + self.get_segment_start(s)
                    end = self.timestamp_offset + min(duration, self.get_segment_end(s))
                if start >= end:
                    continue
                if self.get_segment_no_speech_prob(s) > self.no_speech_thresh:
                    continue
                self.transcript.append(self.format_segment(start, end, text_, completed=True))
                offset = min(duration, self.get_segment_end(s))

        # The last segment stays in progress, since its final words may still
        # change as more audio arrives.
        if self.get_segment_no_speech_prob(segments[-1]) <= self.no_speech_thresh:
            self.current_out += segments[-1].text
            with self.lock:
                last_segment = self.format_segment(
                    self.timestamp_offset + self.get_segment_start(segments[-1]),
                    self.timestamp_offset + min(duration, self.get_segment_end(segments[-1])),
                    self.current_out,
                    completed=False
                )

        # Track how many times in a row the in-progress text has been identical.
        if self.current_out.strip() == self.prev_out.strip() and self.current_out != '':
            self.same_output_count += 1

            # Record when the repetition started, so clipping later does not
            # discard audio that has not been transcribed yet.
            if self.end_time_for_same_output is None:
                self.end_time_for_same_output = self.get_segment_end(segments[-1])
            # Wait briefly for voice activity in case the speaker merely paused.
            time.sleep(0.1)
        else:
            self.same_output_count = 0
            self.end_time_for_same_output = None

        # If the same output has repeated beyond the threshold, promote it to a
        # completed segment and advance the transcription pointer past it.
        if self.same_output_count > self.same_output_threshold:
            if not self.text or self.text[-1].strip().lower() != self.current_out.strip().lower():
                self.text.append(self.current_out)
                with self.lock:
                    self.transcript.append(self.format_segment(
                        self.timestamp_offset,
                        self.timestamp_offset + min(duration, self.end_time_for_same_output),
                        self.current_out,
                        completed=True
                    ))
            self.current_out = ''
            offset = min(duration, self.end_time_for_same_output)
            self.same_output_count = 0
            last_segment = None
            self.end_time_for_same_output = None
        else:
            self.prev_out = self.current_out

        if offset is not None:
            with self.lock:
                self.timestamp_offset += offset

        return last_segment