"""Unit tests for the whisper_live client, tee client, and resampling utility."""

import copy
import json
import os
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

import scipy.io.wavfile  # explicit submodule import; also binds the name `scipy`
import websocket

from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient
from whisper_live.utils import resample


class BaseTestCase(unittest.TestCase):
    """Build a TranscriptionClient whose PyAudio and WebSocketApp are mocked."""

    @patch('whisper_live.client.websocket.WebSocketApp')
    @patch('whisper_live.client.pyaudio.PyAudio')
    def setUp(self, mock_pyaudio, mock_websocket):
        # Decorators apply bottom-up, so PyAudio is the first injected mock.
        self.mock_pyaudio_instance = MagicMock()
        mock_pyaudio.return_value = self.mock_pyaudio_instance
        self.mock_stream = MagicMock()
        self.mock_pyaudio_instance.open.return_value = self.mock_stream

        self.mock_ws_app = mock_websocket.return_value
        self.mock_ws_app.send = MagicMock()

        self.client = TranscriptionClient(host='localhost', port=9090, lang="en").client

        # Kept so tests can inspect how the patched classes were called; the
        # patches themselves are only active for the duration of setUp.
        self.mock_pyaudio = mock_pyaudio
        self.mock_websocket = mock_websocket
        self.mock_audio_packet = b'\x00\x01\x02\x03'

    def tearDown(self):
        # The @patch decorators on setUp already restored the real classes when
        # setUp returned; calling .stop() on the stored mocks would be a no-op,
        # so only the client itself needs closing here.
        self.client.close_websocket()
        del self.client


class TestClientWebSocketCommunication(BaseTestCase):
    """Check the websocket URL the client connects to."""

    def test_websocket_communication(self):
        expected_url = 'ws://localhost:9090'
        self.mock_websocket.assert_called()
        self.assertEqual(self.mock_websocket.call_args[0][0], expected_url)


class TestClientCallbacks(BaseTestCase):
    """Exercise the websocket callback handlers on the client."""

    def test_on_open(self):
        expected_message = json.dumps({
            "uid": self.client.uid,
            "language": self.client.language,
            "task": self.client.task,
            "model": self.client.model,
            "use_vad": True,
            "max_clients": 4,
            "max_connection_time": 600,
            "send_last_n_segments": 10,
            "no_speech_thresh": 0.45,
            "clip_audio": False,
            "same_output_threshold": 10,
        })
        self.client.on_open(self.mock_ws_app)
        self.mock_ws_app.send.assert_called_with(expected_message)

    def test_on_message(self):
        message = json.dumps(
            {
                "uid": self.client.uid,
                "message": "SERVER_READY",
                "backend": "faster_whisper"
            }
        )
        self.client.on_message(self.mock_ws_app, message)

        message = json.dumps({
            "uid": self.client.uid,
            "segments": [
                {"start": 0, "end": 1, "text": "Test transcript", "completed": True},
                {"start": 1, "end": 2, "text": "Test transcript 2", "completed": True},
                {"start": 2, "end": 3, "text": "Test transcript 3", "completed": True}
            ]
        })
        self.client.on_message(self.mock_ws_app, message)

        # Assert that the transcript was updated correctly
        self.assertEqual(len(self.client.transcript), 3)
        self.assertEqual(self.client.transcript[1]['text'], "Test transcript 2")

    def test_on_close(self):
        close_status_code = 1000
        close_msg = "Normal closure"
        self.client.on_close(self.mock_ws_app, close_status_code, close_msg)

        self.assertFalse(self.client.recording)
        self.assertFalse(self.client.server_error)
        self.assertFalse(self.client.waiting)

    def test_on_error(self):
        error_message = "Test Error"
        self.client.on_error(self.mock_ws_app, error_message)

        self.assertTrue(self.client.server_error)
        self.assertEqual(self.client.error_message, error_message)


class TestAudioResampling(unittest.TestCase):
    """Check that resample() writes a file at the requested sample rate."""

    def test_resample_audio(self):
        # NOTE(review): path is relative to the working directory the tests are run from.
        original_audio = "assets/jfk.flac"
        expected_sr = 16000
        resampled_audio = resample(original_audio, expected_sr)
        # Register removal before asserting so a failure does not leak the file.
        self.addCleanup(os.remove, resampled_audio)

        sr, _ = scipy.io.wavfile.read(resampled_audio)
        self.assertEqual(sr, expected_sr)


class TestSendingAudioPacket(BaseTestCase):
    """Check that raw audio is forwarded to the socket as a binary frame."""

    def test_send_packet(self):
        self.client.send_packet_to_server(self.mock_audio_packet)
        self.client.client_socket.send.assert_called_with(
            self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)


class TestTee(BaseTestCase):
    """Exercise TranscriptionTeeClient fan-out across two mocked clients."""

    @patch('whisper_live.client.websocket.WebSocketApp')
    @patch('whisper_live.client.pyaudio.PyAudio')
    def setUp(self, mock_audio, mock_websocket):
        # The injected mocks are unused directly; the patches only need to be
        # active while client2/client3 are constructed.
        super().setUp()
        self.client2 = Client(host='localhost', port=9090, lang="es",
                              translate=False, srt_file_path="transcript.srt")
        self.client3 = Client(host='localhost', port=9090, lang="es",
                              translate=True, srt_file_path="translation.srt")
        # need a separate mock for each websocket: both clients are built under
        # the same patch, so presumably they share one return_value mock.
        self.client3.client_socket = copy.deepcopy(self.client3.client_socket)
        self.tee = TranscriptionTeeClient([self.client2, self.client3])

    def tearDown(self):
        self.tee.close_all_clients()
        del self.tee
        super().tearDown()

    def test_invalid_constructor(self):
        with self.assertRaises(Exception):
            TranscriptionTeeClient([])

    def test_multicast_unconditional(self):
        self.tee.multicast_packet(self.mock_audio_packet, True)
        for client in self.tee.clients:
            client.client_socket.send.assert_called_with(
                self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)

    def test_multicast_conditional(self):
        self.client2.recording = False
        self.client3.recording = True
        self.tee.multicast_packet(self.mock_audio_packet, False)
        self.client2.client_socket.send.assert_not_called()
        self.client3.client_socket.send.assert_called_with(
            self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)

    def test_close_all(self):
        self.tee.close_all_clients()
        for client in self.tee.clients:
            client.client_socket.close.assert_called()

    def test_write_all_srt(self):
        # Remove the SRT files this test creates in the working directory.
        for srt_path in ("transcript.srt", "translation.srt"):
            self.addCleanup(Path(srt_path).unlink, missing_ok=True)

        for client in self.tee.clients:
            client.server_backend = "faster_whisper"

        self.tee.write_all_clients_srt()
        self.assertTrue(Path("transcript.srt").is_file())
        self.assertTrue(Path("translation.srt").is_file())


if __name__ == "__main__":
    unittest.main()