|
import subprocess |
|
import time |
|
import json |
|
import unittest |
|
from unittest import mock |
|
|
|
import numpy as np |
|
import jiwer |
|
|
|
from websockets.exceptions import ConnectionClosed |
|
from whisper_live.server import TranscriptionServer, BackendType, ClientManager |
|
from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient |
|
from whisper.normalizers import EnglishTextNormalizer |
|
|
|
|
|
class TestTranscriptionServerInitialization(unittest.TestCase): |
|
def test_initialization(self): |
|
server = TranscriptionServer() |
|
server.client_manager = ClientManager(max_clients=4, max_connection_time=600) |
|
self.assertEqual(server.client_manager.max_clients, 4) |
|
self.assertEqual(server.client_manager.max_connection_time, 600) |
|
self.assertDictEqual(server.client_manager.clients, {}) |
|
self.assertDictEqual(server.client_manager.start_times, {}) |
|
|
|
|
|
class TestGetWaitTime(unittest.TestCase): |
|
def setUp(self): |
|
self.server = TranscriptionServer() |
|
self.server.client_manager = ClientManager(max_clients=4, max_connection_time=600) |
|
self.server.client_manager.start_times = { |
|
'client1': time.time() - 120, |
|
'client2': time.time() - 300 |
|
} |
|
self.server.client_manager.max_connection_time = 600 |
|
|
|
def test_get_wait_time(self): |
|
expected_wait_time = (600 - (time.time() - self.server.client_manager.start_times['client2'])) / 60 |
|
print(self.server.client_manager.get_wait_time(), expected_wait_time) |
|
self.assertAlmostEqual(self.server.client_manager.get_wait_time(), expected_wait_time, places=2) |
|
|
|
|
|
class TestServerConnection(unittest.TestCase): |
|
def setUp(self): |
|
self.server = TranscriptionServer() |
|
|
|
@mock.patch('websockets.WebSocketCommonProtocol') |
|
def test_connection(self, mock_websocket): |
|
mock_websocket.recv.return_value = json.dumps({ |
|
'uid': 'test_client', |
|
'language': 'en', |
|
'task': 'transcribe', |
|
'model': 'tiny.en' |
|
}) |
|
self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) |
|
|
|
@mock.patch('websockets.WebSocketCommonProtocol') |
|
def test_recv_audio_exception_handling(self, mock_websocket): |
|
mock_websocket.recv.side_effect = [json.dumps({ |
|
'uid': 'test_client', |
|
'language': 'en', |
|
'task': 'transcribe', |
|
'model': 'tiny.en' |
|
}), np.array([1, 2, 3]).tobytes()] |
|
|
|
with self.assertLogs(level="ERROR"): |
|
self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) |
|
|
|
self.assertNotIn(mock_websocket, self.server.client_manager.clients) |
|
|
|
|
|
class TestServerInferenceAccuracy(unittest.TestCase): |
|
@classmethod |
|
def setUpClass(cls): |
|
cls.mock_pyaudio_patch = mock.patch('pyaudio.PyAudio') |
|
cls.mock_pyaudio = cls.mock_pyaudio_patch.start() |
|
cls.mock_pyaudio.return_value.open.return_value = mock.MagicMock() |
|
|
|
cls.server_process = subprocess.Popen(["python", "run_server.py"]) |
|
time.sleep(2) |
|
|
|
@classmethod |
|
def tearDownClass(cls): |
|
cls.server_process.terminate() |
|
cls.server_process.wait() |
|
|
|
def setUp(self): |
|
self.normalizer = EnglishTextNormalizer() |
|
|
|
def check_prediction(self, srt_path): |
|
gt = "And so my fellow Americans, ask not, what your country can do for you. Ask what you can do for your country!" |
|
with open(srt_path, "r") as f: |
|
lines = f.readlines() |
|
prediction = " ".join([line.strip() for line in lines[2::4]]) |
|
prediction_normalized = self.normalizer(prediction) |
|
gt_normalized = self.normalizer(gt) |
|
|
|
|
|
wer_score = jiwer.wer(gt_normalized, prediction_normalized) |
|
self.assertLess(wer_score, 0.05) |
|
|
|
def test_inference(self): |
|
client = TranscriptionClient( |
|
"localhost", "9090", model="base.en", lang="en", |
|
) |
|
client("assets/jfk.flac") |
|
self.check_prediction("output.srt") |
|
|
|
def test_simultaneous_inference(self): |
|
client1 = Client( |
|
"localhost", "9090", model="base.en", lang="en", srt_file_path="transcript1.srt") |
|
client2 = Client( |
|
"localhost", "9090", model="base.en", lang="en", srt_file_path="transcript2.srt") |
|
tee = TranscriptionTeeClient([client1, client2]) |
|
tee("assets/jfk.flac") |
|
self.check_prediction("transcript1.srt") |
|
self.check_prediction("transcript2.srt") |
|
|
|
|
|
class TestExceptionHandling(unittest.TestCase): |
|
def setUp(self): |
|
self.server = TranscriptionServer() |
|
|
|
@mock.patch('websockets.WebSocketCommonProtocol') |
|
def test_connection_closed_exception(self, mock_websocket): |
|
mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed", rcvd_then_sent=mock.Mock()) |
|
|
|
with self.assertLogs(level="INFO") as log: |
|
self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) |
|
self.assertTrue(any("Connection closed by client" in message for message in log.output)) |
|
|
|
@mock.patch('websockets.WebSocketCommonProtocol') |
|
def test_json_decode_exception(self, mock_websocket): |
|
mock_websocket.recv.return_value = "invalid json" |
|
|
|
with self.assertLogs(level="ERROR") as log: |
|
self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) |
|
self.assertTrue(any("Failed to decode JSON from client" in message for message in log.output)) |
|
|
|
@mock.patch('websockets.WebSocketCommonProtocol') |
|
def test_unexpected_exception_handling(self, mock_websocket): |
|
mock_websocket.recv.side_effect = RuntimeError("Unexpected error") |
|
|
|
with self.assertLogs(level="ERROR") as log: |
|
self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) |
|
for message in log.output: |
|
print(message) |
|
print() |
|
self.assertTrue(any("Unexpected error" in message for message in log.output)) |
|
|