File size: 6,148 Bytes
7222c68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import json
import os
import scipy
import websocket
import copy
import unittest
from unittest.mock import patch, MagicMock
from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient
from whisper_live.utils import resample
from pathlib import Path


class BaseTestCase(unittest.TestCase):
    """Common fixture: a client built while PyAudio and WebSocketApp are mocked.

    The ``patch`` decorators wrap ``setUp`` only, so the real classes are
    restored as soon as ``setUp`` returns. The tests keep working because the
    objects created during ``setUp`` are stored on ``self`` (and, presumably,
    on the client itself -- see ``client.client_socket`` in the subclasses).
    """

    @patch('whisper_live.client.websocket.WebSocketApp')
    @patch('whisper_live.client.pyaudio.PyAudio')
    def setUp(self, mock_pyaudio, mock_websocket):
        # Handles to the patched classes and a tiny payload shared by the tests.
        self.mock_pyaudio = mock_pyaudio
        self.mock_websocket = mock_websocket
        self.mock_audio_packet = b'\x00\x01\x02\x03'

        # PyAudio() -> mock instance; instance.open(...) -> mock stream.
        audio_instance = MagicMock()
        audio_stream = MagicMock()
        audio_instance.open.return_value = audio_stream
        mock_pyaudio.return_value = audio_instance
        self.mock_pyaudio_instance = audio_instance
        self.mock_stream = audio_stream

        # WebSocketApp(...) -> one shared mock app with an explicit send mock.
        ws_app = mock_websocket.return_value
        ws_app.send = MagicMock()
        self.mock_ws_app = ws_app

        # Built last so the constructor sees the fully configured mocks.
        self.client = TranscriptionClient(host='localhost', port=9090, lang="en").client

    def tearDown(self):
        self.client.close_websocket()
        # NOTE: these are the mocks injected by ``patch``, not patcher objects,
        # so ``stop()`` is an ordinary recorded mock call here.
        for patched in (self.mock_pyaudio, self.mock_websocket):
            patched.stop()
        del self.client

class TestClientWebSocketCommunication(BaseTestCase):
    """Check how the websocket app was constructed during the fixture setup."""

    def test_websocket_communication(self):
        """The patched ``WebSocketApp`` was called with the server URL first."""
        self.mock_websocket.assert_called()
        # call_args[0] is the tuple of positional arguments of the last call.
        positional_args = self.mock_websocket.call_args[0]
        self.assertEqual(positional_args[0], 'ws://localhost:9090')


class TestClientCallbacks(BaseTestCase):
    """Invoke the client's websocket callbacks directly with the mocked app."""

    def test_on_open(self):
        """``on_open`` sends the JSON-encoded session options."""
        client = self.client
        # Key order matters: the assertion compares the serialized string.
        # The literal values are hard-coded here (presumably the client's
        # defaults -- confirm against the client if they change).
        handshake = {
            "uid": client.uid,
            "language": client.language,
            "task": client.task,
            "model": client.model,
            "use_vad": True,
            "max_clients": 4,
            "max_connection_time": 600,
            "send_last_n_segments": 10,
            "no_speech_thresh": 0.45,
            "clip_audio": False,
            "same_output_threshold": 10,
        }
        serialized = json.dumps(handshake)

        client.on_open(self.mock_ws_app)

        self.mock_ws_app.send.assert_called_with(serialized)

    def test_on_message(self):
        """A segments message after SERVER_READY fills ``client.transcript``."""
        ready_notice = json.dumps(
            {"uid": self.client.uid, "message": "SERVER_READY", "backend": "faster_whisper"}
        )
        self.client.on_message(self.mock_ws_app, ready_notice)

        segments = [
            {"start": 0, "end": 1, "text": "Test transcript", "completed": True},
            {"start": 1, "end": 2, "text": "Test transcript 2", "completed": True},
            {"start": 2, "end": 3, "text": "Test transcript 3", "completed": True},
        ]
        segments_notice = json.dumps({"uid": self.client.uid, "segments": segments})
        self.client.on_message(self.mock_ws_app, segments_notice)

        # All three segments are expected in the transcript, in order.
        transcript = self.client.transcript
        self.assertEqual(len(transcript), 3)
        self.assertEqual(transcript[1]['text'], "Test transcript 2")

    def test_on_close(self):
        """After ``on_close`` the recording, error and waiting flags are falsy."""
        self.client.on_close(self.mock_ws_app, 1000, "Normal closure")

        for flag_name in ("recording", "server_error", "waiting"):
            self.assertFalse(getattr(self.client, flag_name))

    def test_on_error(self):
        """``on_error`` sets the error flag and stores the reported message."""
        reported = "Test Error"
        self.client.on_error(self.mock_ws_app, reported)

        self.assertTrue(self.client.server_error)
        self.assertEqual(self.client.error_message, reported)


class TestAudioResampling(unittest.TestCase):
    """Check that ``resample`` produces audio at the requested sample rate."""

    def test_resample_audio(self):
        """Resample the bundled sample and verify the rate of the output file.

        The input path is relative, so the test expects to be run from a
        directory that contains ``assets/jfk.flac``.
        """
        # Import the submodule explicitly: a bare ``import scipy`` is not
        # guaranteed to make ``scipy.io.wavfile`` available on every version.
        from scipy.io import wavfile

        original_audio = "assets/jfk.flac"
        expected_sr = 16000
        resampled_audio = resample(original_audio, expected_sr)
        # Register the removal before asserting so that a failing check does
        # not leave the generated file behind.
        self.addCleanup(os.remove, resampled_audio)

        sr, _ = wavfile.read(resampled_audio)
        self.assertEqual(sr, expected_sr)


class TestSendingAudioPacket(BaseTestCase):
    """Check that audio packets are forwarded to the client's socket."""

    def test_send_packet(self):
        """The packet is passed to ``client_socket.send`` with the binary opcode."""
        packet = self.mock_audio_packet
        self.client.send_packet_to_server(packet)

        socket_send = self.client.client_socket.send
        socket_send.assert_called_with(packet, websocket.ABNF.OPCODE_BINARY)

class TestTee(BaseTestCase):
    """Tests for ``TranscriptionTeeClient`` fanning out to two ``Client`` objects."""

    # SRT paths handed to the two clients; kept in one place so the SRT test
    # and its cleanup always agree on which files are involved.
    TRANSCRIPT_SRT = "transcript.srt"
    TRANSLATION_SRT = "translation.srt"

    @patch('whisper_live.client.websocket.WebSocketApp')
    @patch('whisper_live.client.pyaudio.PyAudio')
    def setUp(self, mock_audio, mock_websocket):
        # The injected mocks are unused by name, but the patches must be
        # active while the two extra clients are constructed below.
        super().setUp()
        self.client2 = Client(host='localhost', port=9090, lang="es", translate=False, srt_file_path=self.TRANSCRIPT_SRT)
        self.client3 = Client(host='localhost', port=9090, lang="es", translate=True, srt_file_path=self.TRANSLATION_SRT)
        # need a separate mock for each websocket: a patched class hands back
        # the same return_value on every call, so without the copy both
        # clients would share one socket mock and per-client assertions
        # could not tell them apart.
        self.client3.client_socket = copy.deepcopy(self.client3.client_socket)
        self.tee = TranscriptionTeeClient([self.client2, self.client3])

    def tearDown(self):
        self.tee.close_all_clients()
        del self.tee
        super().tearDown()

    @classmethod
    def _remove_srt_files(cls):
        """Delete the SRT files used by this test class, ignoring missing ones."""
        for srt_path in (cls.TRANSCRIPT_SRT, cls.TRANSLATION_SRT):
            try:
                os.remove(srt_path)
            except FileNotFoundError:
                pass

    def test_invalid_constructor(self):
        """Constructing the tee with an empty client list raises."""
        with self.assertRaises(Exception):
            TranscriptionTeeClient([])

    def test_multicast_unconditional(self):
        """With the flag set to True, every client's socket receives the packet."""
        self.tee.multicast_packet(self.mock_audio_packet, True)
        for client in self.tee.clients:
            client.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)

    def test_multicast_conditional(self):
        """With the flag set to False, only the recording client receives the packet."""
        self.client2.recording = False
        self.client3.recording = True
        self.tee.multicast_packet(self.mock_audio_packet, False)
        self.client2.client_socket.send.assert_not_called()
        self.client3.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)

    def test_close_all(self):
        """``close_all_clients`` closes the socket of every client."""
        self.tee.close_all_clients()
        for client in self.tee.clients:
            client.client_socket.close.assert_called()

    def test_write_all_srt(self):
        """``write_all_clients_srt`` creates one SRT file per client."""
        # Stale files from an earlier run would make the is_file() checks pass
        # even if nothing was written, so start clean and clean up afterwards.
        self._remove_srt_files()
        self.addCleanup(self._remove_srt_files)

        for client in self.tee.clients:
            client.server_backend = "faster_whisper"
        self.tee.write_all_clients_srt()

        self.assertTrue(Path(self.TRANSCRIPT_SRT).is_file())
        self.assertTrue(Path(self.TRANSLATION_SRT).is_file())