|
import json |
|
import os |
|
import scipy |
|
import websocket |
|
import copy |
|
import unittest |
|
from unittest.mock import patch, MagicMock |
|
from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient |
|
from whisper_live.utils import resample |
|
from pathlib import Path |
|
|
|
|
|
class BaseTestCase(unittest.TestCase):
    """Shared fixture that builds a client against mocked audio and websocket classes.

    Subclasses get ``self.client`` (the ``client`` attribute of a
    ``TranscriptionClient``), the mocks that were injected while it was being
    constructed, and a small binary payload to use as an audio packet.
    """

    @patch('whisper_live.client.websocket.WebSocketApp')
    @patch('whisper_live.client.pyaudio.PyAudio')
    def setUp(self, mock_pyaudio, mock_websocket):
        """Construct the client under test while PyAudio and WebSocketApp are patched.

        The ``@patch`` decorators are active only for the duration of this call
        and are undone automatically when it returns, so nothing has to be
        stopped in ``tearDown``. The mocks are kept on ``self`` purely so tests
        can inspect how they were called.
        """
        # Stand-in PyAudio instance whose open() hands back a mock stream.
        self.mock_pyaudio_instance = MagicMock()
        mock_pyaudio.return_value = self.mock_pyaudio_instance
        self.mock_stream = MagicMock()
        self.mock_pyaudio_instance.open.return_value = self.mock_stream

        # Whatever WebSocketApp(...) returns during construction is this object.
        self.mock_ws_app = mock_websocket.return_value
        self.mock_ws_app.send = MagicMock()

        self.client = TranscriptionClient(host='localhost', port=9090, lang="en").client

        self.mock_pyaudio = mock_pyaudio
        self.mock_websocket = mock_websocket
        self.mock_audio_packet = b'\x00\x01\x02\x03'

    def tearDown(self):
        """Close the client's websocket and drop the reference to the client."""
        # NOTE: self.mock_pyaudio / self.mock_websocket are MagicMock objects,
        # not patchers; calling .stop() on them would only record a call on the
        # mock, so no such call is made here.
        self.client.close_websocket()
        del self.client
|
|
|
class TestClientWebSocketCommunication(BaseTestCase):
    """Checks how the patched ``WebSocketApp`` was invoked during fixture setup."""

    def test_websocket_communication(self):
        """The websocket class must have been called with the server URL first."""
        self.mock_websocket.assert_called()

        # call_args[0] holds the positional arguments of the most recent call.
        positional_args = self.mock_websocket.call_args[0]
        self.assertEqual(positional_args[0], 'ws://localhost:9090')
|
|
|
|
|
class TestClientCallbacks(BaseTestCase):
    """Invokes the client's websocket callbacks directly and checks the outcome."""

    def test_on_open(self):
        """on_open should send one JSON string with the client's identity and options."""
        client = self.client
        handshake = {
            "uid": client.uid,
            "language": client.language,
            "task": client.task,
            "model": client.model,
            "use_vad": True,
            "max_clients": 4,
            "max_connection_time": 600,
            "send_last_n_segments": 10,
            "no_speech_thresh": 0.45,
            "clip_audio": False,
            "same_output_threshold": 10,
        }
        # Key order matters: the comparison is on the serialized string.
        expected_message = json.dumps(handshake)

        client.on_open(self.mock_ws_app)

        self.mock_ws_app.send.assert_called_with(expected_message)

    def test_on_message(self):
        """A SERVER_READY message followed by three segments yields three transcript entries."""
        server_ready = {
            "uid": self.client.uid,
            "message": "SERVER_READY",
            "backend": "faster_whisper",
        }
        self.client.on_message(self.mock_ws_app, json.dumps(server_ready))

        segments = [
            {"start": 0, "end": 1, "text": "Test transcript", "completed": True},
            {"start": 1, "end": 2, "text": "Test transcript 2", "completed": True},
            {"start": 2, "end": 3, "text": "Test transcript 3", "completed": True},
        ]
        segment_payload = json.dumps({"uid": self.client.uid, "segments": segments})
        self.client.on_message(self.mock_ws_app, segment_payload)

        transcript = self.client.transcript
        self.assertEqual(len(transcript), 3)
        self.assertEqual(transcript[1]['text'], "Test transcript 2")

    def test_on_close(self):
        """After on_close the recording, server_error and waiting flags are all falsy."""
        self.client.on_close(self.mock_ws_app, 1000, "Normal closure")

        for flag_name in ("recording", "server_error", "waiting"):
            self.assertFalse(getattr(self.client, flag_name))

    def test_on_error(self):
        """on_error flags a server error and stores the error it was given."""
        reported_error = "Test Error"

        self.client.on_error(self.mock_ws_app, reported_error)

        self.assertTrue(self.client.server_error)
        self.assertEqual(self.client.error_message, reported_error)
|
|
|
|
|
class TestAudioResampling(unittest.TestCase):
    """Checks that ``resample`` produces a WAV file at the requested sample rate."""

    def test_resample_audio(self):
        """Resample the bundled FLAC clip and verify the output's sample rate.

        The input path is relative, so the test expects to be run from a
        directory that contains ``assets/jfk.flac``.
        """
        # Import the submodule explicitly: a bare ``import scipy`` does not by
        # itself guarantee that ``scipy.io.wavfile`` has been loaded.
        from scipy.io import wavfile

        original_audio = "assets/jfk.flac"
        expected_sr = 16000
        resampled_audio = resample(original_audio, expected_sr)

        try:
            sample_rate, _ = wavfile.read(resampled_audio)
            self.assertEqual(sample_rate, expected_sr)
        finally:
            # Always remove the generated file, even when the read or the
            # assertion above fails, so failed runs leave nothing behind.
            if os.path.exists(resampled_audio):
                os.remove(resampled_audio)
|
|
|
|
|
class TestSendingAudioPacket(BaseTestCase):
    """Checks that an audio packet is forwarded to the client's socket."""

    def test_send_packet(self):
        """send_packet_to_server must pass the packet to socket.send as a binary frame."""
        packet = self.mock_audio_packet

        self.client.send_packet_to_server(packet)

        socket_send = self.client.client_socket.send
        socket_send.assert_called_with(packet, websocket.ABNF.OPCODE_BINARY)
|
|
|
class TestTee(BaseTestCase):
    """Tests for ``TranscriptionTeeClient`` driving two ``Client`` instances."""

    # SRT paths handed to the two clients; test_write_all_srt creates these
    # files in the current working directory and removes them afterwards.
    _TRANSCRIPT_SRT = "transcript.srt"
    _TRANSLATION_SRT = "translation.srt"

    @patch('whisper_live.client.websocket.WebSocketApp')
    @patch('whisper_live.client.pyaudio.PyAudio')
    def setUp(self, mock_audio, mock_websocket):
        """Build the base fixture, then two extra clients wrapped in a tee client.

        The patches stay active for the whole of this method so that the extra
        clients and the tee client are constructed against mocks as well.
        """
        super().setUp()
        self.client2 = Client(
            host='localhost', port=9090, lang="es", translate=False,
            srt_file_path=self._TRANSCRIPT_SRT,
        )
        self.client3 = Client(
            host='localhost', port=9090, lang="es", translate=True,
            srt_file_path=self._TRANSLATION_SRT,
        )

        # Give client3 its own socket mock. Presumably both clients otherwise
        # share the patched WebSocketApp's single return_value, which would make
        # per-client assertions such as assert_not_called() meaningless.
        self.client3.client_socket = copy.deepcopy(self.client3.client_socket)
        self.tee = TranscriptionTeeClient([self.client2, self.client3])

    def tearDown(self):
        """Close every client owned by the tee, then run the base teardown."""
        self.tee.close_all_clients()
        del self.tee
        super().tearDown()

    def _remove_srt_outputs(self):
        """Delete the SRT files the fixture's clients write to, if they exist."""
        for srt_name in (self._TRANSCRIPT_SRT, self._TRANSLATION_SRT):
            srt_path = Path(srt_name)
            if srt_path.is_file():
                srt_path.unlink()

    def test_invalid_constructor(self):
        """Constructing a tee client with no clients must raise."""
        with self.assertRaises(Exception):
            TranscriptionTeeClient([])

    def test_multicast_unconditional(self):
        """With the unconditional flag set, every client receives the packet."""
        self.tee.multicast_packet(self.mock_audio_packet, True)
        for client in self.tee.clients:
            client.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)

    def test_multicast_conditional(self):
        """Without the unconditional flag, only the recording client receives the packet."""
        self.client2.recording = False
        self.client3.recording = True
        self.tee.multicast_packet(self.mock_audio_packet, False)
        self.client2.client_socket.send.assert_not_called()
        self.client3.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)

    def test_close_all(self):
        """close_all_clients must close the socket of every client."""
        self.tee.close_all_clients()
        for client in self.tee.clients:
            client.client_socket.close.assert_called()

    def test_write_all_srt(self):
        """write_all_clients_srt must create one SRT file per client."""
        # Clear leftovers from earlier runs so the is_file() checks below prove
        # that this call created the files, and clean up again afterwards so the
        # test leaves nothing in the working directory.
        self._remove_srt_outputs()
        self.addCleanup(self._remove_srt_outputs)

        for client in self.tee.clients:
            client.server_backend = "faster_whisper"
        self.tee.write_all_clients_srt()

        self.assertTrue(Path(self._TRANSCRIPT_SRT).is_file())
        self.assertTrue(Path(self._TRANSLATION_SRT).is_file())
|
|