ai-server / tests /test_server.py
nuernie
initial commit
7222c68
raw
history blame
6 kB
import subprocess
import time
import json
import unittest
from unittest import mock
import numpy as np
import jiwer
from websockets.exceptions import ConnectionClosed
from whisper_live.server import TranscriptionServer, BackendType, ClientManager
from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient
from whisper.normalizers import EnglishTextNormalizer
class TestTranscriptionServerInitialization(unittest.TestCase):
def test_initialization(self):
server = TranscriptionServer()
server.client_manager = ClientManager(max_clients=4, max_connection_time=600)
self.assertEqual(server.client_manager.max_clients, 4)
self.assertEqual(server.client_manager.max_connection_time, 600)
self.assertDictEqual(server.client_manager.clients, {})
self.assertDictEqual(server.client_manager.start_times, {})
class TestGetWaitTime(unittest.TestCase):
def setUp(self):
self.server = TranscriptionServer()
self.server.client_manager = ClientManager(max_clients=4, max_connection_time=600)
self.server.client_manager.start_times = {
'client1': time.time() - 120,
'client2': time.time() - 300
}
self.server.client_manager.max_connection_time = 600
def test_get_wait_time(self):
expected_wait_time = (600 - (time.time() - self.server.client_manager.start_times['client2'])) / 60
print(self.server.client_manager.get_wait_time(), expected_wait_time)
self.assertAlmostEqual(self.server.client_manager.get_wait_time(), expected_wait_time, places=2)
class TestServerConnection(unittest.TestCase):
def setUp(self):
self.server = TranscriptionServer()
@mock.patch('websockets.WebSocketCommonProtocol')
def test_connection(self, mock_websocket):
mock_websocket.recv.return_value = json.dumps({
'uid': 'test_client',
'language': 'en',
'task': 'transcribe',
'model': 'tiny.en'
})
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
@mock.patch('websockets.WebSocketCommonProtocol')
def test_recv_audio_exception_handling(self, mock_websocket):
mock_websocket.recv.side_effect = [json.dumps({
'uid': 'test_client',
'language': 'en',
'task': 'transcribe',
'model': 'tiny.en'
}), np.array([1, 2, 3]).tobytes()]
with self.assertLogs(level="ERROR"):
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
self.assertNotIn(mock_websocket, self.server.client_manager.clients)
class TestServerInferenceAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mock_pyaudio_patch = mock.patch('pyaudio.PyAudio')
cls.mock_pyaudio = cls.mock_pyaudio_patch.start()
cls.mock_pyaudio.return_value.open.return_value = mock.MagicMock()
cls.server_process = subprocess.Popen(["python", "run_server.py"])
time.sleep(2)
@classmethod
def tearDownClass(cls):
cls.server_process.terminate()
cls.server_process.wait()
def setUp(self):
self.normalizer = EnglishTextNormalizer()
def check_prediction(self, srt_path):
gt = "And so my fellow Americans, ask not, what your country can do for you. Ask what you can do for your country!"
with open(srt_path, "r") as f:
lines = f.readlines()
prediction = " ".join([line.strip() for line in lines[2::4]])
prediction_normalized = self.normalizer(prediction)
gt_normalized = self.normalizer(gt)
# calculate WER
wer_score = jiwer.wer(gt_normalized, prediction_normalized)
self.assertLess(wer_score, 0.05)
def test_inference(self):
client = TranscriptionClient(
"localhost", "9090", model="base.en", lang="en",
)
client("assets/jfk.flac")
self.check_prediction("output.srt")
def test_simultaneous_inference(self):
client1 = Client(
"localhost", "9090", model="base.en", lang="en", srt_file_path="transcript1.srt")
client2 = Client(
"localhost", "9090", model="base.en", lang="en", srt_file_path="transcript2.srt")
tee = TranscriptionTeeClient([client1, client2])
tee("assets/jfk.flac")
self.check_prediction("transcript1.srt")
self.check_prediction("transcript2.srt")
class TestExceptionHandling(unittest.TestCase):
def setUp(self):
self.server = TranscriptionServer()
@mock.patch('websockets.WebSocketCommonProtocol')
def test_connection_closed_exception(self, mock_websocket):
mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed", rcvd_then_sent=mock.Mock())
with self.assertLogs(level="INFO") as log:
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
self.assertTrue(any("Connection closed by client" in message for message in log.output))
@mock.patch('websockets.WebSocketCommonProtocol')
def test_json_decode_exception(self, mock_websocket):
mock_websocket.recv.return_value = "invalid json"
with self.assertLogs(level="ERROR") as log:
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
self.assertTrue(any("Failed to decode JSON from client" in message for message in log.output))
@mock.patch('websockets.WebSocketCommonProtocol')
def test_unexpected_exception_handling(self, mock_websocket):
mock_websocket.recv.side_effect = RuntimeError("Unexpected error")
with self.assertLogs(level="ERROR") as log:
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
for message in log.output:
print(message)
print()
self.assertTrue(any("Unexpected error" in message for message in log.output))