nuernie commited on
Commit
7222c68
·
1 Parent(s): d17f6c7

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Editors
2
+ .vscode/
3
+ .idea/
4
+
5
+ # Vagrant
6
+ .vagrant/
7
+
8
+ # Mac/OSX
9
+ .DS_Store
10
+
11
+ # Windows
12
+ Thumbs.db
13
+
14
+ # Source for the following rules: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore
15
+ # Byte-compiled / optimized / DLL files
16
+ __pycache__/
17
+ *.py[cod]
18
+ *$py.class
19
+
20
+ # C extensions
21
+ *.so
22
+
23
+ # Distribution / packaging
24
+ .Python
25
+ build/
26
+ develop-eggs/
27
+ dist/
28
+ downloads/
29
+ eggs/
30
+ .eggs/
31
+ lib/
32
+ lib64/
33
+ parts/
34
+ sdist/
35
+ var/
36
+ wheels/
37
+ *.egg-info/
38
+ .installed.cfg
39
+ *.egg
40
+ MANIFEST
41
+
42
+ # PyInstaller
43
+ # Usually these files are written by a python script from a template
44
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
45
+ *.manifest
46
+ *.spec
47
+
48
+ # Installer logs
49
+ pip-log.txt
50
+ pip-delete-this-directory.txt
51
+
52
+ # Unit test / coverage reports
53
+ htmlcov/
54
+ .tox/
55
+ .nox/
56
+ .coverage
57
+ .coverage.*
58
+ .cache
59
+ nosetests.xml
60
+ coverage.xml
61
+ *.cover
62
+ .hypothesis/
63
+ .pytest_cache/
64
+
65
+ # Translations
66
+ *.mo
67
+ *.pot
68
+
69
+ # Django stuff:
70
+ *.log
71
+ local_settings.py
72
+ db.sqlite3
73
+
74
+ # Flask stuff:
75
+ instance/
76
+ .webassets-cache
77
+
78
+ # Scrapy stuff:
79
+ .scrapy
80
+
81
+ # Sphinx documentation
82
+ docs/_build/
83
+
84
+ # PyBuilder
85
+ target/
86
+
87
+ # Jupyter Notebook
88
+ .ipynb_checkpoints
89
+
90
+ # IPython
91
+ profile_default/
92
+ ipython_config.py
93
+
94
+ # pyenv
95
+ .python-version
96
+
97
+ # celery beat schedule file
98
+ celerybeat-schedule
99
+
100
+ # SageMath parsed files
101
+ *.sage.py
102
+
103
+ # Environments
104
+ .env
105
+ .venv
106
+ env/
107
+ venv/
108
+ ENV/
109
+ env.bak/
110
+ venv.bak/
111
+
112
+ docs/
113
+
114
+ # Spyder project settings
115
+ .spyderproject
116
+ .spyproject
117
+
118
+ # Rope project settings
119
+ .ropeproject
120
+
121
+ # mkdocs documentation
122
+ /site
123
+
124
+ # mypy
125
+ .mypy_cache/
126
+ .dmypy.json
127
+ dmypy.json
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && \
6
+ apt-get install -y portaudio19-dev python3-dev gcc && \
7
+ apt-get clean && \
8
+ rm -rf /var/lib/apt/lists/*
9
+
10
+ COPY requirements.txt .
11
+ RUN pip install --no-cache-dir -r requirements.txt
12
+
13
+ COPY . .
14
+
15
+ EXPOSE 7860
16
+
17
+ CMD ["python", "app.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Vineet Suryan, Collabora Ltd.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,11 +1,209 @@
1
- ---
2
- title: Ai Server
3
- emoji: 🐨
4
- colorFrom: pink
5
- colorTo: green
6
- sdk: docker
7
- pinned: false
8
- short_description: Answers interview questions
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # WhisperLive
2
+
3
+ <h2 align="center">
4
+ <a href="https://www.youtube.com/watch?v=0PHWCApIcCI"><img
5
+ src="https://img.youtube.com/vi/0PHWCApIcCI/0.jpg" style="background-color:rgba(0,0,0,0);" height=300 alt="WhisperLive"></a>
6
+ <br><br>A nearly-live implementation of OpenAI's Whisper.
7
+ <br><br>
8
+ </h2>
9
+
10
+ This project is a real-time transcription application that uses the OpenAI Whisper model
11
+ to convert speech input into text output. It can be used to transcribe both live audio
12
+ input from microphone and pre-recorded audio files.
13
+
14
+ - [Installation](#installation)
15
+ - [Getting Started](#getting-started)
16
+ - [Running the Server](#running-the-server)
17
+ - [Running the Client](#running-the-client)
18
+ - [Browser Extensions](#browser-extensions)
19
+ - [Whisper Live Server in Docker](#whisper-live-server-in-docker)
20
+ - [Future Work](#future-work)
21
+ - [Contact](#contact)
22
+ - [Citations](#citations)
23
+
24
+ ## Installation
25
+ - Install PyAudio
26
+ ```bash
27
+ bash scripts/setup.sh
28
+ ```
29
+
30
+ - Install whisper-live from pip
31
+ ```bash
32
+ pip install whisper-live
33
+ ```
34
+
35
+ ### Setting up NVIDIA/TensorRT-LLM for TensorRT backend
36
+ - Please follow [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) for setup of [NVIDIA/TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) and for building Whisper-TensorRT engine.
37
+
38
+ ## Getting Started
39
+ The server supports 3 backends `faster_whisper`, `tensorrt` and `openvino`. If running `tensorrt` backend follow [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md)
40
+
41
+ ### Running the Server
42
+ - [Faster Whisper](https://github.com/SYSTRAN/faster-whisper) backend
43
+ ```bash
44
+ python3 run_server.py --port 9090 \
45
+ --backend faster_whisper
46
+
47
+ # running with custom model
48
+ python3 run_server.py --port 9090 \
49
+ --backend faster_whisper \
50
+ -fw "/path/to/custom/faster/whisper/model"
51
+ ```
52
+
53
+ - TensorRT backend. Currently, we recommend to only use the docker setup for TensorRT. Follow [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) which works as expected. Make sure to build your TensorRT Engines before running the server with TensorRT backend.
54
+ ```bash
55
+ # Run English only model
56
+ python3 run_server.py -p 9090 \
57
+ -b tensorrt \
58
+ -trt /home/TensorRT-LLM/examples/whisper/whisper_small_en
59
+
60
+ # Run Multilingual model
61
+ python3 run_server.py -p 9090 \
62
+ -b tensorrt \
63
+ -trt /home/TensorRT-LLM/examples/whisper/whisper_small \
64
+ -m
65
+ ```
66
+
67
+ - WhisperLive now supports the [OpenVINO](https://github.com/openvinotoolkit/openvino) backend for efficient inference on Intel CPUs, iGPU and dGPUs. Currently, we tested the models uploaded to [huggingface by OpenVINO](https://huggingface.co/OpenVINO?search_models=whisper).
68
+ - > **Docker Recommended:** Running WhisperLive with OpenVINO inside Docker automatically enables GPU support (iGPU/dGPU) without requiring additional host setup.
69
+ - > **Native (non-Docker) Use:** If you prefer running outside Docker, ensure the Intel drivers and OpenVINO runtime are installed and properly configured on your system. Refer to the documentation for [installing OpenVINO](https://docs.openvino.ai/2025/get-started/install-openvino.html?PACKAGE=OPENVINO_BASE&VERSION=v_2025_0_0&OP_SYSTEM=LINUX&DISTRIBUTION=PIP#).
70
+
71
+ ```
72
+ python3 run_server.py -p 9090 -b openvino
73
+ ```
74
+
75
+
76
+ #### Controlling OpenMP Threads
77
+ To control the number of threads used by OpenMP, you can set the `OMP_NUM_THREADS` environment variable. This is useful for managing CPU resources and ensuring consistent performance. If not specified, `OMP_NUM_THREADS` is set to `1` by default. You can change this by using the `--omp_num_threads` argument:
78
+ ```bash
79
+ python3 run_server.py --port 9090 \
80
+ --backend faster_whisper \
81
+ --omp_num_threads 4
82
+ ```
83
+
84
+ #### Single model mode
85
+ By default, when running the server without specifying a model, the server will instantiate a new whisper model for every client connection. This has the advantage, that the server can use different model sizes, based on the client's requested model size. On the other hand, it also means you have to wait for the model to be loaded upon client connection and you will have increased (V)RAM usage.
86
+
87
+ When serving a custom TensorRT model using the `-trt` or a custom faster_whisper model using the `-fw` option, the server will instead only instantiate the custom model once and then reuse it for all client connections.
88
+
89
+ If you don't want this, set `--no_single_model`.
90
+
91
+
92
+ ### Running the Client
93
+ - Initializing the client with below parameters:
94
+ - `lang`: Language of the input audio, applicable only if using a multilingual model.
95
+ - `translate`: If set to `True` then translate from any language to `en`.
96
+ - `model`: Whisper model size.
97
+ - `use_vad`: Whether to use `Voice Activity Detection` on the server.
98
+ - `save_output_recording`: Set to True to save the microphone input as a `.wav` file during live transcription. This option is helpful for recording sessions for later playback or analysis. Defaults to `False`.
99
+ - `output_recording_filename`: Specifies the `.wav` file path where the microphone input will be saved if `save_output_recording` is set to `True`.
100
+ - `max_clients`: Specifies the maximum number of clients the server should allow. Defaults to 4.
101
+ - `max_connection_time`: Maximum connection time for each client in seconds. Defaults to 600.
102
+ - `mute_audio_playback`: Whether to mute audio playback when transcribing an audio file. Defaults to False.
103
+
104
+ ```python
105
+ from whisper_live.client import TranscriptionClient
106
+ client = TranscriptionClient(
107
+ "localhost",
108
+ 9090,
109
+ lang="en",
110
+ translate=False,
111
+ model="small", # also support hf_model => `Systran/faster-whisper-small`
112
+ use_vad=False,
113
+ save_output_recording=True, # Only used for microphone input, False by Default
114
+ output_recording_filename="./output_recording.wav", # Only used for microphone input
115
+ max_clients=4,
116
+ max_connection_time=600,
117
+ mute_audio_playback=False, # Only used for file input, False by Default
118
+ )
119
+ ```
120
+ It connects to the server running on localhost at port 9090. Using a multilingual model, language for the transcription will be automatically detected. You can also use the language option to specify the target language for the transcription, in this case, English ("en"). The translate option should be set to `True` if we want to translate from the source language to English and `False` if we want to transcribe in the source language.
121
+
122
+ - Transcribe an audio file:
123
+ ```python
124
+ client("tests/jfk.wav")
125
+ ```
126
+
127
+ - To transcribe from microphone:
128
+ ```python
129
+ client()
130
+ ```
131
+
132
+ - To transcribe from a RTSP stream:
133
+ ```python
134
+ client(rtsp_url="rtsp://admin:[email protected]/rtsp")
135
+ ```
136
+
137
+ - To transcribe from a HLS stream:
138
+ ```python
139
+ client(hls_url="http://as-hls-ww-live.akamaized.net/pool_904/live/ww/bbc_1xtra/bbc_1xtra.isml/bbc_1xtra-audio%3d96000.norewind.m3u8")
140
+ ```
141
+
142
+ ## Browser Extensions
143
+ - Run the server with your desired backend as shown [here](https://github.com/collabora/WhisperLive?tab=readme-ov-file#running-the-server).
144
+ - Transcribe audio directly from your browser using our Chrome or Firefox extensions. Refer to [Audio-Transcription-Chrome](https://github.com/collabora/whisper-live/tree/main/Audio-Transcription-Chrome#readme) and https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md
145
+
146
+ ## Whisper Live Server in Docker
147
+ - GPU
148
+ - Faster-Whisper
149
+ ```bash
150
+ docker run -it --gpus all -p 9090:9090 ghcr.io/collabora/whisperlive-gpu:latest
151
+ ```
152
+
153
+ - TensorRT. Refer to [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) for setup and more tensorrt backend configurations.
154
+ ```bash
155
+ docker build . -f docker/Dockerfile.tensorrt -t whisperlive-tensorrt
156
+ docker run -p 9090:9090 --runtime=nvidia --gpus all --entrypoint /bin/bash -it whisperlive-tensorrt
157
+
158
+ # Build small.en engine
159
+ bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en # float16
160
+ bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en int8 # int8 weight only quantization
161
+ bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en int4 # int4 weight only quantization
162
+
163
+ # Run server with small.en
164
+ python3 run_server.py --port 9090 \
165
+ --backend tensorrt \
166
+ --trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_en_float16"
167
+ --trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_en_int8"
168
+ --trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_en_int4"
169
+ ```
170
+
171
+ - OpenVINO
172
+ ```
173
+ docker run -it --device=/dev/dri -p 9090:9090 ghcr.io/collabora/whisperlive-openvino
174
+ ```
175
+
176
+ - CPU
177
+ - Faster-whisper
178
+ ```bash
179
+ docker run -it -p 9090:9090 ghcr.io/collabora/whisperlive-cpu:latest
180
+ ```
181
+
182
+ ## Future Work
183
+ - [ ] Add translation to other languages on top of transcription.
184
+
185
+ ## Contact
186
+
187
+ We are available to help you with both Open Source and proprietary AI projects. You can reach us via the Collabora website or [[email protected]](mailto:[email protected]) and [[email protected]](mailto:[email protected]).
188
+
189
+ ## Citations
190
+ ```bibtex
191
+ @article{Whisper
192
+ title = {Robust Speech Recognition via Large-Scale Weak Supervision},
193
+ url = {https://arxiv.org/abs/2212.04356},
194
+ author = {Radford, Alec and Kim, Jong Wook and Xu, Tao and Brockman, Greg and McLeavey, Christine and Sutskever, Ilya},
195
+ publisher = {arXiv},
196
+ year = {2022},
197
+ }
198
+ ```
199
+
200
+ ```bibtex
201
+ @misc{Silero VAD,
202
+ author = {Silero Team},
203
+ title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
204
+ year = {2021},
205
+ publisher = {GitHub},
206
+ journal = {GitHub repository},
207
+ howpublished = {\url{https://github.com/snakers4/silero-vad}},
208
+ email = {[email protected]}
209
+ }
TensorRT_whisper.md ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # WhisperLive-TensorRT
2
+ We have only tested the TensorRT backend in docker so, we recommend docker for a smooth TensorRT backend setup.
3
+ **Note**: We use `tensorrt_llm==0.18.2`
4
+
5
+ ## Installation
6
+ - Install [docker](https://docs.docker.com/engine/install/)
7
+ - Install [nvidia-container-toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
8
+
9
+ - Run WhisperLive TensorRT in docker
10
+ ```bash
11
+ docker build . -f docker/Dockerfile.tensorrt -t whisperlive-tensorrt
12
+ docker run -p 9090:9090 --runtime=nvidia --gpus all --entrypoint /bin/bash -it whisperlive-tensorrt
13
+ ```
14
+
15
+ ## Whisper TensorRT Engine
16
+ - We build `small.en` and `small` multilingual TensorRT engine as examples below. The script logs the path of the directory with Whisper TensorRT engine. We need that model_path to run the server.
17
+ ```bash
18
+ # convert small.en
19
+ bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en # float16
20
+ bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en int8 # int8 weight only quantization
21
+ bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en int4 # int4 weight only quantization
22
+
23
+ # convert small multilingual model
24
+ bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small
25
+ ```
26
+
27
+ ## Run WhisperLive Server with TensorRT Backend
28
+ ```bash
29
+ # Run English only model
30
+ python3 run_server.py --port 9090 \
31
+ --backend tensorrt \
32
+ --trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_en_float16"
33
+
34
+ # Run Multilingual model
35
+ python3 run_server.py --port 9090 \
36
+ --backend tensorrt \
37
+ --trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_float16" \
38
+ --trt_multilingual
39
+ ```
40
+
41
+ By default trt_backend uses cpp_session, to use python session pass `--trt_py_session` to run_server.py
42
+ ```bash
43
+ python3 run_server.py --port 9090 \
44
+ --backend tensorrt \
45
+ --trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_float16" \
46
+ --trt_py_session
47
+ ```
app.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ import uvicorn
3
+ from whisper_live.server import TranscriptionServer
4
+
5
+ app = FastAPI(title="Whisper Live Server")
6
+
7
+ @app.on_event("startup")
8
+ async def startup_event():
9
+ # Start the transcription server in the background
10
+ server = TranscriptionServer()
11
+ server.run(
12
+ host="0.0.0.0",
13
+ port=7860, # Hugging Face Spaces uses port 7860
14
+ backend="faster_whisper", # Using faster_whisper as the backend
15
+ single_model=True # Use single model mode for better resource usage
16
+ )
17
+
18
+ @app.get("/health")
19
+ def health_check():
20
+ return {"status": "healthy"}
21
+
22
+ if __name__ == "__main__":
23
+ uvicorn.run(app, host="0.0.0.0", port=7860)
docker/Dockerfile.cpu ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-bookworm
2
+
3
+ ARG DEBIAN_FRONTEND=noninteractive
4
+
5
+ # install lib required for pyaudio
6
+ RUN apt update && apt install -y portaudio19-dev && apt-get clean && rm -rf /var/lib/apt/lists/*
7
+
8
+ # update pip to support for whl.metadata -> less downloading
9
+ RUN pip install --no-cache-dir -U "pip>=24"
10
+
11
+ # create a working directory
12
+ RUN mkdir /app
13
+ WORKDIR /app
14
+
15
+ # install pytorch, but without the nvidia-libs that are only necessary for gpu
16
+ RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
17
+
18
+ # install the requirements for running the whisper-live server
19
+ COPY requirements/server.txt /app/
20
+ RUN pip install --no-cache-dir -r server.txt && rm server.txt
21
+
22
+ COPY whisper_live /app/whisper_live
23
+ COPY run_server.py /app
24
+
25
+ CMD ["python", "run_server.py"]
docker/Dockerfile.gpu ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-bookworm
2
+
3
+ ARG DEBIAN_FRONTEND=noninteractive
4
+
5
+ # install lib required for pyaudio
6
+ RUN apt update && apt install -y portaudio19-dev && apt-get clean && rm -rf /var/lib/apt/lists/*
7
+
8
+ # update pip to support for whl.metadata -> less downloading
9
+ RUN pip install --no-cache-dir -U "pip>=24"
10
+
11
+ # create a working directory
12
+ RUN mkdir /app
13
+ WORKDIR /app
14
+
15
+ # install the requirements for running the whisper-live server
16
+ COPY requirements/server.txt /app/
17
+ RUN pip install --no-cache-dir -r server.txt && rm server.txt
18
+
19
+ # make the paths of the nvidia libs installed as wheels visible. equivalent to:
20
+ # export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'`
21
+ ENV LD_LIBRARY_PATH="/usr/local/lib/python3.10/site-packages/nvidia/cublas/lib:/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib"
22
+
23
+ COPY whisper_live /app/whisper_live
24
+ COPY run_server.py /app
25
+
26
+ CMD ["python", "run_server.py"]
docker/Dockerfile.openvino ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM openvino/ubuntu22_runtime:latest
2
+
3
+ ARG DEBIAN_FRONTEND=noninteractive
4
+
5
+ USER root
6
+
7
+ RUN apt update && apt install -y portaudio19-dev python-is-python3 && apt-get clean && rm -rf /var/lib/apt/lists/*
8
+
9
+ RUN pip install --no-cache-dir -U "pip>=24"
10
+
11
+ RUN mkdir /app
12
+ WORKDIR /app
13
+
14
+ COPY requirements/server.txt /app/
15
+ RUN pip install --no-cache-dir -r server.txt && rm server.txt
16
+
17
+ COPY whisper_live /app/whisper_live
18
+ COPY run_server.py /app
19
+ CMD ["python", "run_server.py", "--backend", "openvino"]
docker/Dockerfile.tensorrt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.8.1-base-ubuntu22.04 AS base
2
+
3
+ ARG DEBIAN_FRONTEND=noninteractive
4
+
5
+ RUN apt-get update && apt-get install -y \
6
+ python3.10 python3-pip openmpi-bin libopenmpi-dev git git-lfs wget \
7
+ && apt install python-is-python3 \
8
+ && pip install --upgrade pip setuptools \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ FROM base AS devel
12
+ RUN pip install --no-cache-dir -U tensorrt_llm==0.18.2 --extra-index-url https://pypi.nvidia.com
13
+ WORKDIR /app
14
+ RUN git clone -b v0.18.2 https://github.com/NVIDIA/TensorRT-LLM.git \
15
+ && mv TensorRT-LLM/examples ./TensorRT-LLM-examples \
16
+ && rm -rf TensorRT-LLM
17
+
18
+ FROM devel AS release
19
+ WORKDIR /app
20
+ COPY assets/ ./assets
21
+ RUN wget -nc -P assets/ https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz
22
+
23
+ COPY scripts/setup.sh ./
24
+ RUN apt update && bash setup.sh && rm setup.sh
25
+
26
+ COPY requirements/server.txt .
27
+ RUN pip install --no-cache-dir -r server.txt && rm server.txt
28
+ COPY whisper_live ./whisper_live
29
+ COPY scripts/build_whisper_tensorrt.sh .
30
+ COPY run_server.py .
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ faster-whisper
2
+ numpy
3
+ websockets
4
+ pyaudio
5
+ soundfile
6
+ torch
7
+ torchaudio
requirements/client.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ PyAudio
2
+ av
3
+ scipy
4
+ websocket-client
requirements/server.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ faster-whisper==1.1.0
2
+ websockets
3
+ onnxruntime==1.17.0
4
+ numba
5
+ kaldialign
6
+ soundfile
7
+ scipy
8
+ av
9
+ jiwer
10
+ evaluate
11
+ numpy<2
12
+ openai-whisper==20240930
13
+ tokenizers==0.20.3
14
+
15
+ # openvino
16
+ librosa
17
+ openvino
18
+ openvino-genai
19
+ openvino-tokenizers
20
+ optimum
21
+ optimum-intel
run_server.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ if __name__ == "__main__":
5
+ parser = argparse.ArgumentParser()
6
+ parser.add_argument('--port', '-p',
7
+ type=int,
8
+ default=9090,
9
+ help="Websocket port to run the server on.")
10
+ parser.add_argument('--backend', '-b',
11
+ type=str,
12
+ default='faster_whisper',
13
+ help='Backends from ["tensorrt", "faster_whisper", "openvino"]')
14
+ parser.add_argument('--faster_whisper_custom_model_path', '-fw',
15
+ type=str, default=None,
16
+ help="Custom Faster Whisper Model")
17
+ parser.add_argument('--trt_model_path', '-trt',
18
+ type=str,
19
+ default=None,
20
+ help='Whisper TensorRT model path')
21
+ parser.add_argument('--trt_multilingual', '-m',
22
+ action="store_true",
23
+ help='Boolean only for TensorRT model. True if multilingual.')
24
+ parser.add_argument('--trt_py_session',
25
+ action="store_true",
26
+ help='Boolean only for TensorRT model. Use python session or cpp session, By default uses Cpp.')
27
+ parser.add_argument('--omp_num_threads', '-omp',
28
+ type=int,
29
+ default=1,
30
+ help="Number of threads to use for OpenMP")
31
+ parser.add_argument('--no_single_model', '-nsm',
32
+ action='store_true',
33
+ help='Set this if every connection should instantiate its own model. Only relevant for custom model, passed using -trt or -fw.')
34
+ args = parser.parse_args()
35
+
36
+ if args.backend == "tensorrt":
37
+ if args.trt_model_path is None:
38
+ raise ValueError("Please Provide a valid tensorrt model path")
39
+
40
+ if "OMP_NUM_THREADS" not in os.environ:
41
+ os.environ["OMP_NUM_THREADS"] = str(args.omp_num_threads)
42
+
43
+ from whisper_live.server import TranscriptionServer
44
+ server = TranscriptionServer()
45
+ server.run(
46
+ "0.0.0.0",
47
+ port=args.port,
48
+ backend=args.backend,
49
+ faster_whisper_custom_model_path=args.faster_whisper_custom_model_path,
50
+ whisper_tensorrt_path=args.trt_model_path,
51
+ trt_multilingual=args.trt_multilingual,
52
+ trt_py_session=args.trt_py_session,
53
+ single_model=not args.no_single_model,
54
+ )
scripts/build_whisper_tensorrt.sh ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ download_and_build_model() {
4
+ local model_name="$1"
5
+ local model_url=""
6
+
7
+ case "$model_name" in
8
+ "tiny.en")
9
+ model_url="https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt"
10
+ ;;
11
+ "tiny")
12
+ model_url="https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt"
13
+ ;;
14
+ "base.en")
15
+ model_url="https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt"
16
+ ;;
17
+ "base")
18
+ model_url="https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt"
19
+ ;;
20
+ "small.en")
21
+ model_url="https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt"
22
+ ;;
23
+ "small")
24
+ model_url="https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt"
25
+ ;;
26
+ "medium.en")
27
+ model_url="https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt"
28
+ ;;
29
+ "medium")
30
+ model_url="https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt"
31
+ ;;
32
+ "large-v1")
33
+ model_url="https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt"
34
+ ;;
35
+ "large-v2")
36
+ model_url="https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt"
37
+ ;;
38
+ "large-v3" | "large")
39
+ model_url="https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt"
40
+ ;;
41
+ "large-v3-turbo" | "turbo")
42
+ model_url="https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt"
43
+ ;;
44
+ *)
45
+ echo "Invalid model name: $model_name"
46
+ exit 1
47
+ ;;
48
+ esac
49
+
50
+ if [ "$model_name" == "turbo" ]; then
51
+ model_name="large-v3-turbo"
52
+ fi
53
+
54
+ local inference_precision="float16"
55
+ local weight_only_precision="${2:-float16}"
56
+ local max_beam_width=4
57
+ local max_batch_size=4
58
+
59
+ echo "Downloading $model_name..."
60
+ # wget --directory-prefix=assets "$model_url"
61
+ # echo "Download completed: ${model_name}.pt"
62
+ if [ ! -f "assets/${model_name}.pt" ]; then
63
+ wget --directory-prefix=assets "$model_url"
64
+ echo "Download completed: ${model_name}.pt"
65
+ else
66
+ echo "${model_name}.pt already exists in assets directory."
67
+ fi
68
+
69
+ local sanitized_model_name="${model_name//./_}"
70
+ local checkpoint_dir="whisper_${sanitized_model_name}_weights_${weight_only_precision}"
71
+ local output_dir="whisper_${sanitized_model_name}_${weight_only_precision}"
72
+ echo "$output_dir"
73
+ echo "Converting model weights for $model_name..."
74
+ python3 convert_checkpoint.py \
75
+ $( [[ "$weight_only_precision" == "int8" || "$weight_only_precision" == "int4" ]] && echo "--use_weight_only --weight_only_precision $weight_only_precision" ) \
76
+ --output_dir "$checkpoint_dir" --model_name "$model_name"
77
+
78
+ echo "Building encoder for $model_name..."
79
+ trtllm-build \
80
+ --checkpoint_dir "${checkpoint_dir}/encoder" \
81
+ --output_dir "${output_dir}/encoder" \
82
+ --moe_plugin disable \
83
+ --max_batch_size "$max_batch_size" \
84
+ --gemm_plugin disable \
85
+ --bert_attention_plugin "$inference_precision" \
86
+ --max_input_len 3000 \
87
+ --max_seq_len 3000
88
+
89
+ echo "Building decoder for $model_name..."
90
+ trtllm-build \
91
+ --checkpoint_dir "${checkpoint_dir}/decoder" \
92
+ --output_dir "${output_dir}/decoder" \
93
+ --moe_plugin disable \
94
+ --max_beam_width "$max_beam_width" \
95
+ --max_batch_size "$max_batch_size" \
96
+ --max_seq_len 225 \
97
+ --max_input_len 32 \
98
+ --max_encoder_input_len 3000 \
99
+ --gemm_plugin "$inference_precision" \
100
+ --bert_attention_plugin "$inference_precision" \
101
+ --gpt_attention_plugin "$inference_precision"
102
+
103
+ echo "TensorRT LLM engine built for $model_name."
104
+ echo "========================================="
105
+ echo "Model is located at: $(pwd)/$output_dir"
106
+ }
107
+
108
+ if [ "$#" -lt 1 ]; then
109
+ echo "Usage: $0 <path-to-tensorrt-examples-dir> [model-name]"
110
+ exit 1
111
+ fi
112
+
113
+ tensorrt_examples_dir="$1"
114
+ model_name="${2:-small.en}"
115
+ weight_only_precision="${3:-float16}" # Default to float16 if not provided
116
+
117
+ cd $tensorrt_examples_dir/whisper
118
+ pip install --no-deps -r requirements.txt
119
+
120
+ download_and_build_model "$model_name" "$weight_only_precision"
scripts/setup.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #! /bin/bash
2
+
3
+ apt-get install portaudio19-dev wget -y
setup.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ from setuptools import find_packages, setup
3
+ from whisper_live.__version__ import __version__
4
+
5
+
6
+ # The directory containing this file
7
+ HERE = pathlib.Path(__file__).parent
8
+
9
+ # The text of the README file
10
+ README = (HERE / "README.md").read_text()
11
+
12
+ # This call to setup() does all the work
13
+ setup(
14
+ name="whisper_live",
15
+ version=__version__,
16
+ description="A nearly-live implementation of OpenAI's Whisper.",
17
+ long_description=README,
18
+ long_description_content_type="text/markdown",
19
+ include_package_data=True,
20
+ url="https://github.com/collabora/WhisperLive",
21
+ author="Collabora Ltd",
22
+ author_email="[email protected]",
23
+ license="MIT",
24
+ classifiers=[
25
+ "Development Status :: 4 - Beta",
26
+ "Intended Audience :: Developers",
27
+ "Intended Audience :: Science/Research",
28
+ "License :: OSI Approved :: MIT License",
29
+ "Programming Language :: Python :: 3",
30
+ "Programming Language :: Python :: 3 :: Only",
31
+ "Programming Language :: Python :: 3.8",
32
+ "Programming Language :: Python :: 3.9",
33
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
34
+ ],
35
+ packages=find_packages(
36
+ exclude=(
37
+ "examples",
38
+ "Audio-Transcription-Chrome",
39
+ "Audio-Transcription-Firefox",
40
+ "requirements",
41
+ "whisper-finetuning"
42
+ )
43
+ ),
44
+ install_requires=[
45
+ "PyAudio",
46
+ "faster-whisper==1.1.0",
47
+ "torch",
48
+ "torchaudio",
49
+ "websockets",
50
+ "onnxruntime==1.17.0",
51
+ "scipy",
52
+ "websocket-client",
53
+ "numba",
54
+ "openai-whisper==20240930",
55
+ "kaldialign",
56
+ "soundfile",
57
+ "tokenizers==0.20.3",
58
+ "librosa",
59
+ "numpy==1.26.4",
60
+ "openvino",
61
+ "openvino-genai",
62
+ "openvino-tokenizers",
63
+ "optimum",
64
+ "optimum-intel",
65
+ ],
66
+ python_requires=">=3.9"
67
+ )
tests/__init__.py ADDED
File without changes
tests/test_client.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import scipy
4
+ import websocket
5
+ import copy
6
+ import unittest
7
+ from unittest.mock import patch, MagicMock
8
+ from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient
9
+ from whisper_live.utils import resample
10
+ from pathlib import Path
11
+
12
+
13
+ class BaseTestCase(unittest.TestCase):
14
+ @patch('whisper_live.client.websocket.WebSocketApp')
15
+ @patch('whisper_live.client.pyaudio.PyAudio')
16
+ def setUp(self, mock_pyaudio, mock_websocket):
17
+ self.mock_pyaudio_instance = MagicMock()
18
+ mock_pyaudio.return_value = self.mock_pyaudio_instance
19
+ self.mock_stream = MagicMock()
20
+ self.mock_pyaudio_instance.open.return_value = self.mock_stream
21
+
22
+ self.mock_ws_app = mock_websocket.return_value
23
+ self.mock_ws_app.send = MagicMock()
24
+
25
+ self.client = TranscriptionClient(host='localhost', port=9090, lang="en").client
26
+
27
+ self.mock_pyaudio = mock_pyaudio
28
+ self.mock_websocket = mock_websocket
29
+ self.mock_audio_packet = b'\x00\x01\x02\x03'
30
+
31
+ def tearDown(self):
32
+ self.client.close_websocket()
33
+ self.mock_pyaudio.stop()
34
+ self.mock_websocket.stop()
35
+ del self.client
36
+
37
+ class TestClientWebSocketCommunication(BaseTestCase):
38
+ def test_websocket_communication(self):
39
+ expected_url = 'ws://localhost:9090'
40
+ self.mock_websocket.assert_called()
41
+ self.assertEqual(self.mock_websocket.call_args[0][0], expected_url)
42
+
43
+
44
+ class TestClientCallbacks(BaseTestCase):
45
+ def test_on_open(self):
46
+ expected_message = json.dumps({
47
+ "uid": self.client.uid,
48
+ "language": self.client.language,
49
+ "task": self.client.task,
50
+ "model": self.client.model,
51
+ "use_vad": True,
52
+ "max_clients": 4,
53
+ "max_connection_time": 600,
54
+ "send_last_n_segments": 10,
55
+ "no_speech_thresh": 0.45,
56
+ "clip_audio": False,
57
+ "same_output_threshold": 10,
58
+ })
59
+ self.client.on_open(self.mock_ws_app)
60
+ self.mock_ws_app.send.assert_called_with(expected_message)
61
+
62
+ def test_on_message(self):
63
+ message = json.dumps(
64
+ {
65
+ "uid": self.client.uid,
66
+ "message": "SERVER_READY",
67
+ "backend": "faster_whisper"
68
+ }
69
+ )
70
+ self.client.on_message(self.mock_ws_app, message)
71
+
72
+ message = json.dumps({
73
+ "uid": self.client.uid,
74
+ "segments": [
75
+ {"start": 0, "end": 1, "text": "Test transcript", "completed": True},
76
+ {"start": 1, "end": 2, "text": "Test transcript 2", "completed": True},
77
+ {"start": 2, "end": 3, "text": "Test transcript 3", "completed": True}
78
+ ]
79
+ })
80
+ self.client.on_message(self.mock_ws_app, message)
81
+
82
+ # Assert that the transcript was updated correctly
83
+ self.assertEqual(len(self.client.transcript), 3)
84
+ self.assertEqual(self.client.transcript[1]['text'], "Test transcript 2")
85
+
86
+ def test_on_close(self):
87
+ close_status_code = 1000
88
+ close_msg = "Normal closure"
89
+ self.client.on_close(self.mock_ws_app, close_status_code, close_msg)
90
+
91
+ self.assertFalse(self.client.recording)
92
+ self.assertFalse(self.client.server_error)
93
+ self.assertFalse(self.client.waiting)
94
+
95
+ def test_on_error(self):
96
+ error_message = "Test Error"
97
+ self.client.on_error(self.mock_ws_app, error_message)
98
+
99
+ self.assertTrue(self.client.server_error)
100
+ self.assertEqual(self.client.error_message, error_message)
101
+
102
+
103
+ class TestAudioResampling(unittest.TestCase):
104
+ def test_resample_audio(self):
105
+ original_audio = "assets/jfk.flac"
106
+ expected_sr = 16000
107
+ resampled_audio = resample(original_audio, expected_sr)
108
+
109
+ sr, _ = scipy.io.wavfile.read(resampled_audio)
110
+ self.assertEqual(sr, expected_sr)
111
+
112
+ os.remove(resampled_audio)
113
+
114
+
115
+ class TestSendingAudioPacket(BaseTestCase):
116
+ def test_send_packet(self):
117
+ self.client.send_packet_to_server(self.mock_audio_packet)
118
+ self.client.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)
119
+
120
+ class TestTee(BaseTestCase):
121
+ @patch('whisper_live.client.websocket.WebSocketApp')
122
+ @patch('whisper_live.client.pyaudio.PyAudio')
123
+ def setUp(self, mock_audio, mock_websocket):
124
+ super().setUp()
125
+ self.client2 = Client(host='localhost', port=9090, lang="es", translate=False, srt_file_path="transcript.srt")
126
+ self.client3 = Client(host='localhost', port=9090, lang="es", translate=True, srt_file_path="translation.srt")
127
+ # need a separate mock for each websocket
128
+ self.client3.client_socket = copy.deepcopy(self.client3.client_socket)
129
+ self.tee = TranscriptionTeeClient([self.client2, self.client3])
130
+
131
+ def tearDown(self):
132
+ self.tee.close_all_clients()
133
+ del self.tee
134
+ super().tearDown()
135
+
136
+ def test_invalid_constructor(self):
137
+ with self.assertRaises(Exception) as context:
138
+ TranscriptionTeeClient([])
139
+
140
+ def test_multicast_unconditional(self):
141
+ self.tee.multicast_packet(self.mock_audio_packet, True)
142
+ for client in self.tee.clients:
143
+ client.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)
144
+
145
+ def test_multicast_conditional(self):
146
+ self.client2.recording = False
147
+ self.client3.recording = True
148
+ self.tee.multicast_packet(self.mock_audio_packet, False)
149
+ self.client2.client_socket.send.assert_not_called()
150
+ self.client3.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)
151
+
152
+ def test_close_all(self):
153
+ self.tee.close_all_clients()
154
+ for client in self.tee.clients:
155
+ client.client_socket.close.assert_called()
156
+
157
+ def test_write_all_srt(self):
158
+ for client in self.tee.clients:
159
+ client.server_backend = "faster_whisper"
160
+ self.tee.write_all_clients_srt()
161
+ self.assertTrue(Path("transcript.srt").is_file())
162
+ self.assertTrue(Path("translation.srt").is_file())
tests/test_server.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import time
3
+ import json
4
+ import unittest
5
+ from unittest import mock
6
+
7
+ import numpy as np
8
+ import jiwer
9
+
10
+ from websockets.exceptions import ConnectionClosed
11
+ from whisper_live.server import TranscriptionServer, BackendType, ClientManager
12
+ from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient
13
+ from whisper.normalizers import EnglishTextNormalizer
14
+
15
+
16
+ class TestTranscriptionServerInitialization(unittest.TestCase):
17
+ def test_initialization(self):
18
+ server = TranscriptionServer()
19
+ server.client_manager = ClientManager(max_clients=4, max_connection_time=600)
20
+ self.assertEqual(server.client_manager.max_clients, 4)
21
+ self.assertEqual(server.client_manager.max_connection_time, 600)
22
+ self.assertDictEqual(server.client_manager.clients, {})
23
+ self.assertDictEqual(server.client_manager.start_times, {})
24
+
25
+
26
+ class TestGetWaitTime(unittest.TestCase):
27
+ def setUp(self):
28
+ self.server = TranscriptionServer()
29
+ self.server.client_manager = ClientManager(max_clients=4, max_connection_time=600)
30
+ self.server.client_manager.start_times = {
31
+ 'client1': time.time() - 120,
32
+ 'client2': time.time() - 300
33
+ }
34
+ self.server.client_manager.max_connection_time = 600
35
+
36
+ def test_get_wait_time(self):
37
+ expected_wait_time = (600 - (time.time() - self.server.client_manager.start_times['client2'])) / 60
38
+ print(self.server.client_manager.get_wait_time(), expected_wait_time)
39
+ self.assertAlmostEqual(self.server.client_manager.get_wait_time(), expected_wait_time, places=2)
40
+
41
+
42
+ class TestServerConnection(unittest.TestCase):
43
+ def setUp(self):
44
+ self.server = TranscriptionServer()
45
+
46
+ @mock.patch('websockets.WebSocketCommonProtocol')
47
+ def test_connection(self, mock_websocket):
48
+ mock_websocket.recv.return_value = json.dumps({
49
+ 'uid': 'test_client',
50
+ 'language': 'en',
51
+ 'task': 'transcribe',
52
+ 'model': 'tiny.en'
53
+ })
54
+ self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
55
+
56
+ @mock.patch('websockets.WebSocketCommonProtocol')
57
+ def test_recv_audio_exception_handling(self, mock_websocket):
58
+ mock_websocket.recv.side_effect = [json.dumps({
59
+ 'uid': 'test_client',
60
+ 'language': 'en',
61
+ 'task': 'transcribe',
62
+ 'model': 'tiny.en'
63
+ }), np.array([1, 2, 3]).tobytes()]
64
+
65
+ with self.assertLogs(level="ERROR"):
66
+ self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
67
+
68
+ self.assertNotIn(mock_websocket, self.server.client_manager.clients)
69
+
70
+
71
+ class TestServerInferenceAccuracy(unittest.TestCase):
72
+ @classmethod
73
+ def setUpClass(cls):
74
+ cls.mock_pyaudio_patch = mock.patch('pyaudio.PyAudio')
75
+ cls.mock_pyaudio = cls.mock_pyaudio_patch.start()
76
+ cls.mock_pyaudio.return_value.open.return_value = mock.MagicMock()
77
+
78
+ cls.server_process = subprocess.Popen(["python", "run_server.py"])
79
+ time.sleep(2)
80
+
81
+ @classmethod
82
+ def tearDownClass(cls):
83
+ cls.server_process.terminate()
84
+ cls.server_process.wait()
85
+
86
+ def setUp(self):
87
+ self.normalizer = EnglishTextNormalizer()
88
+
89
+ def check_prediction(self, srt_path):
90
+ gt = "And so my fellow Americans, ask not, what your country can do for you. Ask what you can do for your country!"
91
+ with open(srt_path, "r") as f:
92
+ lines = f.readlines()
93
+ prediction = " ".join([line.strip() for line in lines[2::4]])
94
+ prediction_normalized = self.normalizer(prediction)
95
+ gt_normalized = self.normalizer(gt)
96
+
97
+ # calculate WER
98
+ wer_score = jiwer.wer(gt_normalized, prediction_normalized)
99
+ self.assertLess(wer_score, 0.05)
100
+
101
+ def test_inference(self):
102
+ client = TranscriptionClient(
103
+ "localhost", "9090", model="base.en", lang="en",
104
+ )
105
+ client("assets/jfk.flac")
106
+ self.check_prediction("output.srt")
107
+
108
+ def test_simultaneous_inference(self):
109
+ client1 = Client(
110
+ "localhost", "9090", model="base.en", lang="en", srt_file_path="transcript1.srt")
111
+ client2 = Client(
112
+ "localhost", "9090", model="base.en", lang="en", srt_file_path="transcript2.srt")
113
+ tee = TranscriptionTeeClient([client1, client2])
114
+ tee("assets/jfk.flac")
115
+ self.check_prediction("transcript1.srt")
116
+ self.check_prediction("transcript2.srt")
117
+
118
+
119
+ class TestExceptionHandling(unittest.TestCase):
120
+ def setUp(self):
121
+ self.server = TranscriptionServer()
122
+
123
+ @mock.patch('websockets.WebSocketCommonProtocol')
124
+ def test_connection_closed_exception(self, mock_websocket):
125
+ mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed", rcvd_then_sent=mock.Mock())
126
+
127
+ with self.assertLogs(level="INFO") as log:
128
+ self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
129
+ self.assertTrue(any("Connection closed by client" in message for message in log.output))
130
+
131
+ @mock.patch('websockets.WebSocketCommonProtocol')
132
+ def test_json_decode_exception(self, mock_websocket):
133
+ mock_websocket.recv.return_value = "invalid json"
134
+
135
+ with self.assertLogs(level="ERROR") as log:
136
+ self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
137
+ self.assertTrue(any("Failed to decode JSON from client" in message for message in log.output))
138
+
139
+ @mock.patch('websockets.WebSocketCommonProtocol')
140
+ def test_unexpected_exception_handling(self, mock_websocket):
141
+ mock_websocket.recv.side_effect = RuntimeError("Unexpected error")
142
+
143
+ with self.assertLogs(level="ERROR") as log:
144
+ self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
145
+ for message in log.output:
146
+ print(message)
147
+ print()
148
+ self.assertTrue(any("Unexpected error" in message for message in log.output))
tests/test_vad.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import numpy as np
3
+ from whisper_live.transcriber.tensorrt_utils import load_audio
4
+ from whisper_live.vad import VoiceActivityDetector
5
+
6
+
7
+ class TestVoiceActivityDetection(unittest.TestCase):
8
+ def setUp(self):
9
+ self.vad = VoiceActivityDetector()
10
+ self.sample_rate = 16000
11
+
12
+ def generate_silence(self, duration_seconds):
13
+ return np.zeros(int(self.sample_rate * duration_seconds), dtype=np.float32)
14
+
15
+ def load_speech_segment(self, filepath):
16
+ return load_audio(filepath)
17
+
18
+ def test_vad_silence_detection(self):
19
+ silence = self.generate_silence(3)
20
+ is_speech_present = self.vad(silence.copy())
21
+ self.assertFalse(is_speech_present, "VAD incorrectly identified silence as speech.")
22
+
23
+ def test_vad_speech_detection(self):
24
+ audio_tensor = load_audio("assets/jfk.flac")
25
+ is_speech_present = self.vad(audio_tensor)
26
+ self.assertTrue(is_speech_present, "VAD failed to identify speech segment.")
whisper_live/__init__.py ADDED
File without changes
whisper_live/__version__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.7.1"
whisper_live/backend/__init__.py ADDED
File without changes
whisper_live/backend/base.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import threading
4
+ import time
5
+ import numpy as np
6
+
7
+
8
+ class ServeClientBase(object):
9
+ RATE = 16000
10
+ SERVER_READY = "SERVER_READY"
11
+ DISCONNECT = "DISCONNECT"
12
+
13
+ client_uid: str
14
+ """A unique identifier for the client."""
15
+ websocket: object
16
+ """The WebSocket connection for the client."""
17
+ send_last_n_segments: int
18
+ """Number of most recent segments to send to the client."""
19
+ no_speech_thresh: float
20
+ """Segments with no speech probability above this threshold will be discarded."""
21
+ clip_audio: bool
22
+ """Whether to clip audio with no valid segments."""
23
+ same_output_threshold: int
24
+ """Number of repeated outputs before considering it as a valid segment."""
25
+
26
+ def __init__(
27
+ self,
28
+ client_uid,
29
+ websocket,
30
+ send_last_n_segments=10,
31
+ no_speech_thresh=0.45,
32
+ clip_audio=False,
33
+ same_output_threshold=10,
34
+ ):
35
+ self.client_uid = client_uid
36
+ self.websocket = websocket
37
+ self.send_last_n_segments = send_last_n_segments
38
+ self.no_speech_thresh = no_speech_thresh
39
+ self.clip_audio = clip_audio
40
+ self.same_output_threshold = same_output_threshold
41
+
42
+ self.frames = b""
43
+ self.timestamp_offset = 0.0
44
+ self.frames_np = None
45
+ self.frames_offset = 0.0
46
+ self.text = []
47
+ self.current_out = ""
48
+ self.prev_out = ""
49
+ self.exit = False
50
+ self.same_output_count = 0
51
+ self.transcript = []
52
+ self.end_time_for_same_output = None
53
+
54
+ # threading
55
+ self.lock = threading.Lock()
56
+
57
+ def speech_to_text(self):
58
+ """
59
+ Process an audio stream in an infinite loop, continuously transcribing the speech.
60
+
61
+ This method continuously receives audio frames, performs real-time transcription, and sends
62
+ transcribed segments to the client via a WebSocket connection.
63
+
64
+ If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction.
65
+ It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments
66
+ are sent to the client in real-time, and a history of segments is maintained to provide context.
67
+
68
+ Raises:
69
+ Exception: If there is an issue with audio processing or WebSocket communication.
70
+
71
+ """
72
+ while True:
73
+ if self.exit:
74
+ logging.info("Exiting speech to text thread")
75
+ break
76
+
77
+ if self.frames_np is None:
78
+ continue
79
+
80
+ if self.clip_audio:
81
+ self.clip_audio_if_no_valid_segment()
82
+
83
+ input_bytes, duration = self.get_audio_chunk_for_processing()
84
+ if duration < 1.0:
85
+ time.sleep(0.1) # wait for audio chunks to arrive
86
+ continue
87
+ try:
88
+ input_sample = input_bytes.copy()
89
+ result = self.transcribe_audio(input_sample)
90
+
91
+ if result is None or self.language is None:
92
+ self.timestamp_offset += duration
93
+ time.sleep(0.25) # wait for voice activity, result is None when no voice activity
94
+ continue
95
+ self.handle_transcription_output(result, duration)
96
+
97
+ except Exception as e:
98
+ logging.error(f"[ERROR]: Failed to transcribe audio chunk: {e}")
99
+ time.sleep(0.01)
100
+
101
+ def transcribe_audio(self):
102
+ raise NotImplementedError
103
+
104
+ def handle_transcription_output(self, result, duration):
105
+ raise NotImplementedError
106
+
107
+ def format_segment(self, start, end, text, completed=False):
108
+ """
109
+ Formats a transcription segment with precise start and end times alongside the transcribed text.
110
+
111
+ Args:
112
+ start (float): The start time of the transcription segment in seconds.
113
+ end (float): The end time of the transcription segment in seconds.
114
+ text (str): The transcribed text corresponding to the segment.
115
+
116
+ Returns:
117
+ dict: A dictionary representing the formatted transcription segment, including
118
+ 'start' and 'end' times as strings with three decimal places and the 'text'
119
+ of the transcription.
120
+ """
121
+ return {
122
+ 'start': "{:.3f}".format(start),
123
+ 'end': "{:.3f}".format(end),
124
+ 'text': text,
125
+ 'completed': completed
126
+ }
127
+
128
+ def add_frames(self, frame_np):
129
+ """
130
+ Add audio frames to the ongoing audio stream buffer.
131
+
132
+ This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
133
+ of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
134
+ to prevent excessive memory usage.
135
+
136
+ If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
137
+ of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided
138
+ audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.
139
+
140
+ Args:
141
+ frame_np (numpy.ndarray): The audio frame data as a NumPy array.
142
+
143
+ """
144
+ self.lock.acquire()
145
+ if self.frames_np is not None and self.frames_np.shape[0] > 45*self.RATE:
146
+ self.frames_offset += 30.0
147
+ self.frames_np = self.frames_np[int(30*self.RATE):]
148
+ # check timestamp offset(should be >= self.frame_offset)
149
+ # this basically means that there is no speech as timestamp offset hasnt updated
150
+ # and is less than frame_offset
151
+ if self.timestamp_offset < self.frames_offset:
152
+ self.timestamp_offset = self.frames_offset
153
+ if self.frames_np is None:
154
+ self.frames_np = frame_np.copy()
155
+ else:
156
+ self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
157
+ self.lock.release()
158
+
159
+ def clip_audio_if_no_valid_segment(self):
160
+ """
161
+ Update the timestamp offset based on audio buffer status.
162
+ Clip audio if the current chunk exceeds 30 seconds, this basically implies that
163
+ no valid segment for the last 30 seconds from whisper
164
+ """
165
+ with self.lock:
166
+ if self.frames_np[int((self.timestamp_offset - self.frames_offset)*self.RATE):].shape[0] > 25 * self.RATE:
167
+ duration = self.frames_np.shape[0] / self.RATE
168
+ self.timestamp_offset = self.frames_offset + duration - 5
169
+
170
+ def get_audio_chunk_for_processing(self):
171
+ """
172
+ Retrieves the next chunk of audio data for processing based on the current offsets.
173
+
174
+ Calculates which part of the audio data should be processed next, based on
175
+ the difference between the current timestamp offset and the frame's offset, scaled by
176
+ the audio sample rate (RATE). It then returns this chunk of audio data along with its
177
+ duration in seconds.
178
+
179
+ Returns:
180
+ tuple: A tuple containing:
181
+ - input_bytes (np.ndarray): The next chunk of audio data to be processed.
182
+ - duration (float): The duration of the audio chunk in seconds.
183
+ """
184
+ with self.lock:
185
+ samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
186
+ input_bytes = self.frames_np[int(samples_take):].copy()
187
+ duration = input_bytes.shape[0] / self.RATE
188
+ return input_bytes, duration
189
+
190
+ def prepare_segments(self, last_segment=None):
191
+ """
192
+ Prepares the segments of transcribed text to be sent to the client.
193
+
194
+ This method compiles the recent segments of transcribed text, ensuring that only the
195
+ specified number of the most recent segments are included. It also appends the most
196
+ recent segment of text if provided (which is considered incomplete because of the possibility
197
+ of the last word being truncated in the audio chunk).
198
+
199
+ Args:
200
+ last_segment (str, optional): The most recent segment of transcribed text to be added
201
+ to the list of segments. Defaults to None.
202
+
203
+ Returns:
204
+ list: A list of transcribed text segments to be sent to the client.
205
+ """
206
+ segments = []
207
+ if len(self.transcript) >= self.send_last_n_segments:
208
+ segments = self.transcript[-self.send_last_n_segments:].copy()
209
+ else:
210
+ segments = self.transcript.copy()
211
+ if last_segment is not None:
212
+ segments = segments + [last_segment]
213
+ return segments
214
+
215
+ def get_audio_chunk_duration(self, input_bytes):
216
+ """
217
+ Calculates the duration of the provided audio chunk.
218
+
219
+ Args:
220
+ input_bytes (numpy.ndarray): The audio chunk for which to calculate the duration.
221
+
222
+ Returns:
223
+ float: The duration of the audio chunk in seconds.
224
+ """
225
+ return input_bytes.shape[0] / self.RATE
226
+
227
+ def send_transcription_to_client(self, segments):
228
+ """
229
+ Sends the specified transcription segments to the client over the websocket connection.
230
+
231
+ This method formats the transcription segments into a JSON object and attempts to send
232
+ this object to the client. If an error occurs during the send operation, it logs the error.
233
+
234
+ Returns:
235
+ segments (list): A list of transcription segments to be sent to the client.
236
+ """
237
+ try:
238
+ self.websocket.send(
239
+ json.dumps({
240
+ "uid": self.client_uid,
241
+ "segments": segments,
242
+ })
243
+ )
244
+ except Exception as e:
245
+ logging.error(f"[ERROR]: Sending data to client: {e}")
246
+
247
+ def disconnect(self):
248
+ """
249
+ Notify the client of disconnection and send a disconnect message.
250
+
251
+ This method sends a disconnect message to the client via the WebSocket connection to notify them
252
+ that the transcription service is disconnecting gracefully.
253
+
254
+ """
255
+ self.websocket.send(json.dumps({
256
+ "uid": self.client_uid,
257
+ "message": self.DISCONNECT
258
+ }))
259
+
260
+ def cleanup(self):
261
+ """
262
+ Perform cleanup tasks before exiting the transcription service.
263
+
264
+ This method performs necessary cleanup tasks, including stopping the transcription thread, marking
265
+ the exit flag to indicate the transcription thread should exit gracefully, and destroying resources
266
+ associated with the transcription process.
267
+
268
+ """
269
+ logging.info("Cleaning up.")
270
+ self.exit = True
271
+
272
+ def get_segment_no_speech_prob(self, segment):
273
+ return getattr(segment, "no_speech_prob", 0)
274
+
275
+ def get_segment_start(self, segment):
276
+ return getattr(segment, "start", getattr(segment, "start_ts", 0))
277
+
278
+ def get_segment_end(self, segment):
279
+ return getattr(segment, "end", getattr(segment, "end_ts", 0))
280
+
281
+ def update_segments(self, segments, duration):
282
+ """
283
+ Processes the segments from Whisper and updates the transcript.
284
+ Uses helper methods to account for differences between backends.
285
+
286
+ Args:
287
+ segments (list): List of segments returned by the transcriber.
288
+ duration (float): Duration of the current audio chunk.
289
+
290
+ Returns:
291
+ dict or None: The last processed segment (if any).
292
+ """
293
+ offset = None
294
+ self.current_out = ''
295
+ last_segment = None
296
+
297
+ # Process complete segments only if there are more than one
298
+ # and if the last segment's no_speech_prob is below the threshold.
299
+ if len(segments) > 1 and self.get_segment_no_speech_prob(segments[-1]) <= self.no_speech_thresh:
300
+ for s in segments[:-1]:
301
+ text_ = s.text
302
+ self.text.append(text_)
303
+ with self.lock:
304
+ start = self.timestamp_offset + self.get_segment_start(s)
305
+ end = self.timestamp_offset + min(duration, self.get_segment_end(s))
306
+ if start >= end:
307
+ continue
308
+ if self.get_segment_no_speech_prob(s) > self.no_speech_thresh:
309
+ continue
310
+ self.transcript.append(self.format_segment(start, end, text_, completed=True))
311
+ offset = min(duration, self.get_segment_end(s))
312
+
313
+ # Process the last segment if its no_speech_prob is acceptable.
314
+ if self.get_segment_no_speech_prob(segments[-1]) <= self.no_speech_thresh:
315
+ self.current_out += segments[-1].text
316
+ with self.lock:
317
+ last_segment = self.format_segment(
318
+ self.timestamp_offset + self.get_segment_start(segments[-1]),
319
+ self.timestamp_offset + min(duration, self.get_segment_end(segments[-1])),
320
+ self.current_out,
321
+ completed=False
322
+ )
323
+
324
+ # Handle repeated output logic.
325
+ if self.current_out.strip() == self.prev_out.strip() and self.current_out != '':
326
+ self.same_output_count += 1
327
+
328
+ # if we remove the audio because of same output on the nth reptition we might remove the
329
+ # audio thats not yet transcribed so, capturing the time when it was repeated for the first time
330
+ if self.end_time_for_same_output is None:
331
+ self.end_time_for_same_output = self.get_segment_end(segments[-1])
332
+ time.sleep(0.1) # wait briefly for any new voice activity
333
+ else:
334
+ self.same_output_count = 0
335
+ self.end_time_for_same_output = None
336
+
337
+ # If the same incomplete segment is repeated too many times,
338
+ # append it to the transcript and update the offset.
339
+ if self.same_output_count > self.same_output_threshold:
340
+ if not self.text or self.text[-1].strip().lower() != self.current_out.strip().lower():
341
+ self.text.append(self.current_out)
342
+ with self.lock:
343
+ self.transcript.append(self.format_segment(
344
+ self.timestamp_offset,
345
+ self.timestamp_offset + min(duration, self.end_time_for_same_output),
346
+ self.current_out,
347
+ completed=True
348
+ ))
349
+ self.current_out = ''
350
+ offset = min(duration, self.end_time_for_same_output)
351
+ self.same_output_count = 0
352
+ last_segment = None
353
+ self.end_time_for_same_output = None
354
+ else:
355
+ self.prev_out = self.current_out
356
+
357
+ if offset is not None:
358
+ with self.lock:
359
+ self.timestamp_offset += offset
360
+
361
+ return last_segment
whisper_live/backend/faster_whisper_backend.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import threading
4
+ import time
5
+ import torch
6
+
7
+ from whisper_live.transcriber.transcriber_faster_whisper import WhisperModel
8
+ from whisper_live.backend.base import ServeClientBase
9
+
10
+
11
+ class ServeClientFasterWhisper(ServeClientBase):
12
+ SINGLE_MODEL = None
13
+ SINGLE_MODEL_LOCK = threading.Lock()
14
+
15
+ def __init__(
16
+ self,
17
+ websocket,
18
+ task="transcribe",
19
+ device=None,
20
+ language=None,
21
+ client_uid=None,
22
+ model="small.en",
23
+ initial_prompt=None,
24
+ vad_parameters=None,
25
+ use_vad=True,
26
+ single_model=False,
27
+ send_last_n_segments=10,
28
+ no_speech_thresh=0.45,
29
+ clip_audio=False,
30
+ same_output_threshold=10,
31
+ ):
32
+ """
33
+ Initialize a ServeClient instance.
34
+ The Whisper model is initialized based on the client's language and device availability.
35
+ The transcription thread is started upon initialization. A "SERVER_READY" message is sent
36
+ to the client to indicate that the server is ready.
37
+
38
+ Args:
39
+ websocket (WebSocket): The WebSocket connection for the client.
40
+ task (str, optional): The task type, e.g., "transcribe". Defaults to "transcribe".
41
+ device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None.
42
+ language (str, optional): The language for transcription. Defaults to None.
43
+ client_uid (str, optional): A unique identifier for the client. Defaults to None.
44
+ model (str, optional): The whisper model size. Defaults to 'small.en'
45
+ initial_prompt (str, optional): Prompt for whisper inference. Defaults to None.
46
+ single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False.
47
+ send_last_n_segments (int, optional): Number of most recent segments to send to the client. Defaults to 10.
48
+ no_speech_thresh (float, optional): Segments with no speech probability above this threshold will be discarded. Defaults to 0.45.
49
+ clip_audio (bool, optional): Whether to clip audio with no valid segments. Defaults to False.
50
+ same_output_threshold (int, optional): Number of repeated outputs before considering it as a valid segment. Defaults to 10.
51
+
52
+ """
53
+ super().__init__(
54
+ client_uid,
55
+ websocket,
56
+ send_last_n_segments,
57
+ no_speech_thresh,
58
+ clip_audio,
59
+ same_output_threshold,
60
+ )
61
+ self.model_sizes = [
62
+ "tiny", "tiny.en", "base", "base.en", "small", "small.en",
63
+ "medium", "medium.en", "large-v2", "large-v3", "distil-small.en",
64
+ "distil-medium.en", "distil-large-v2", "distil-large-v3",
65
+ "large-v3-turbo", "turbo"
66
+ ]
67
+
68
+ self.model_size_or_path = model
69
+ self.language = "en" if self.model_size_or_path.endswith("en") else language
70
+ self.task = task
71
+ self.initial_prompt = initial_prompt
72
+ self.vad_parameters = vad_parameters or {"onset": 0.5}
73
+
74
+ device = "cuda" if torch.cuda.is_available() else "cpu"
75
+ if device == "cuda":
76
+ major, _ = torch.cuda.get_device_capability(device)
77
+ self.compute_type = "float16" if major >= 7 else "float32"
78
+ else:
79
+ self.compute_type = "int8"
80
+
81
+ if self.model_size_or_path is None:
82
+ return
83
+ logging.info(f"Using Device={device} with precision {self.compute_type}")
84
+
85
+ try:
86
+ if single_model:
87
+ if ServeClientFasterWhisper.SINGLE_MODEL is None:
88
+ self.create_model(device)
89
+ ServeClientFasterWhisper.SINGLE_MODEL = self.transcriber
90
+ else:
91
+ self.transcriber = ServeClientFasterWhisper.SINGLE_MODEL
92
+ else:
93
+ self.create_model(device)
94
+ except Exception as e:
95
+ logging.error(f"Failed to load model: {e}")
96
+ self.websocket.send(json.dumps({
97
+ "uid": self.client_uid,
98
+ "status": "ERROR",
99
+ "message": f"Failed to load model: {str(self.model_size_or_path)}"
100
+ }))
101
+ self.websocket.close()
102
+ return
103
+
104
+ self.use_vad = use_vad
105
+
106
+ # threading
107
+ self.trans_thread = threading.Thread(target=self.speech_to_text)
108
+ self.trans_thread.start()
109
+ self.websocket.send(
110
+ json.dumps(
111
+ {
112
+ "uid": self.client_uid,
113
+ "message": self.SERVER_READY,
114
+ "backend": "faster_whisper"
115
+ }
116
+ )
117
+ )
118
+
119
+ def create_model(self, device):
120
+ """
121
+ Instantiates a new model, sets it as the transcriber.
122
+ """
123
+ self.transcriber = WhisperModel(
124
+ self.model_size_or_path,
125
+ device=device,
126
+ compute_type=self.compute_type,
127
+ local_files_only=False,
128
+ )
129
+
130
+ def check_valid_model(self, model_size):
131
+ """
132
+ Check if it's a valid whisper model size.
133
+
134
+ Args:
135
+ model_size (str): The name of the model size to check.
136
+
137
+ Returns:
138
+ str: The model size if valid, None otherwise.
139
+ """
140
+ if model_size not in self.model_sizes:
141
+ self.websocket.send(
142
+ json.dumps(
143
+ {
144
+ "uid": self.client_uid,
145
+ "status": "ERROR",
146
+ "message": f"Invalid model size {model_size}. Available choices: {self.model_sizes}"
147
+ }
148
+ )
149
+ )
150
+ return None
151
+ return model_size
152
+
153
+ def set_language(self, info):
154
+ """
155
+ Updates the language attribute based on the detected language information.
156
+
157
+ Args:
158
+ info (object): An object containing the detected language and its probability. This object
159
+ must have at least two attributes: `language`, a string indicating the detected
160
+ language, and `language_probability`, a float representing the confidence level
161
+ of the language detection.
162
+ """
163
+ if info.language_probability > 0.5:
164
+ self.language = info.language
165
+ logging.info(f"Detected language {self.language} with probability {info.language_probability}")
166
+ self.websocket.send(json.dumps(
167
+ {"uid": self.client_uid, "language": self.language, "language_prob": info.language_probability}))
168
+
169
+ def transcribe_audio(self, input_sample):
170
+ """
171
+ Transcribes the provided audio sample using the configured transcriber instance.
172
+
173
+ If the language has not been set, it updates the session's language based on the transcription
174
+ information.
175
+
176
+ Args:
177
+ input_sample (np.array): The audio chunk to be transcribed. This should be a NumPy
178
+ array representing the audio data.
179
+
180
+ Returns:
181
+ The transcription result from the transcriber. The exact format of this result
182
+ depends on the implementation of the `transcriber.transcribe` method but typically
183
+ includes the transcribed text.
184
+ """
185
+ if ServeClientFasterWhisper.SINGLE_MODEL:
186
+ ServeClientFasterWhisper.SINGLE_MODEL_LOCK.acquire()
187
+ result, info = self.transcriber.transcribe(
188
+ input_sample,
189
+ initial_prompt=self.initial_prompt,
190
+ language=self.language,
191
+ task=self.task,
192
+ vad_filter=self.use_vad,
193
+ vad_parameters=self.vad_parameters if self.use_vad else None)
194
+ if ServeClientFasterWhisper.SINGLE_MODEL:
195
+ ServeClientFasterWhisper.SINGLE_MODEL_LOCK.release()
196
+
197
+ if self.language is None and info is not None:
198
+ self.set_language(info)
199
+ return result
200
+
201
+ def handle_transcription_output(self, result, duration):
202
+ """
203
+ Handle the transcription output, updating the transcript and sending data to the client.
204
+
205
+ Args:
206
+ result (str): The result from whisper inference i.e. the list of segments.
207
+ duration (float): Duration of the transcribed audio chunk.
208
+ """
209
+ segments = []
210
+ if len(result):
211
+ self.t_start = None
212
+ last_segment = self.update_segments(result, duration)
213
+ segments = self.prepare_segments(last_segment)
214
+
215
+ if len(segments):
216
+ self.send_transcription_to_client(segments)
whisper_live/backend/openvino_backend.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import threading
4
+ import time
5
+
6
+ from openvino import Core
7
+ from whisper_live.backend.base import ServeClientBase
8
+ from whisper_live.transcriber.transcriber_openvino import WhisperOpenVINO
9
+
10
+
11
+ class ServeClientOpenVINO(ServeClientBase):
12
+ SINGLE_MODEL = None
13
+ SINGLE_MODEL_LOCK = threading.Lock()
14
+
15
+ def __init__(
16
+ self,
17
+ websocket,
18
+ task="transcribe",
19
+ device=None,
20
+ language=None,
21
+ client_uid=None,
22
+ model="small.en",
23
+ initial_prompt=None,
24
+ vad_parameters=None,
25
+ use_vad=True,
26
+ single_model=False,
27
+ send_last_n_segments=10,
28
+ no_speech_thresh=0.45,
29
+ clip_audio=False,
30
+ same_output_threshold=10,
31
+ ):
32
+ """
33
+ Initialize a ServeClient instance.
34
+ The Whisper model is initialized based on the client's language and device availability.
35
+ The transcription thread is started upon initialization. A "SERVER_READY" message is sent
36
+ to the client to indicate that the server is ready.
37
+
38
+ Args:
39
+ websocket (WebSocket): The WebSocket connection for the client.
40
+ task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe".
41
+ device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None.
42
+ language (str, optional): The language for transcription. Defaults to None.
43
+ client_uid (str, optional): A unique identifier for the client. Defaults to None.
44
+ model (str, optional): Huggingface model_id for a valid OpenVINO model.
45
+ initial_prompt (str, optional): Prompt for whisper inference. Defaults to None.
46
+ single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False.
47
+ send_last_n_segments (int, optional): Number of most recent segments to send to the client. Defaults to 10.
48
+ no_speech_thresh (float, optional): Segments with no speech probability above this threshold will be discarded. Defaults to 0.45.
49
+ clip_audio (bool, optional): Whether to clip audio with no valid segments. Defaults to False.
50
+ same_output_threshold (int, optional): Number of repeated outputs before considering it as a valid segment. Defaults to 10.
51
+ """
52
+ super().__init__(
53
+ client_uid,
54
+ websocket,
55
+ send_last_n_segments,
56
+ no_speech_thresh,
57
+ clip_audio,
58
+ same_output_threshold,
59
+ )
60
+ self.language = "en" if language is None else language
61
+ if not self.language.startswith("<|"):
62
+ self.language = f"<|{self.language}|>"
63
+
64
+ self.task = "transcribe" if task is None else task
65
+
66
+ self.clip_audio = True
67
+
68
+ core = Core()
69
+ available_devices = core.available_devices
70
+ if 'GPU' in available_devices:
71
+ selected_device = 'GPU'
72
+ else:
73
+ gpu_devices = [d for d in available_devices if d.startswith('GPU')]
74
+ selected_device = gpu_devices[0] if gpu_devices else 'CPU'
75
+ self.device = selected_device
76
+
77
+
78
+ if single_model:
79
+ if ServeClientOpenVINO.SINGLE_MODEL is None:
80
+ self.create_model(model)
81
+ ServeClientOpenVINO.SINGLE_MODEL = self.transcriber
82
+ else:
83
+ self.transcriber = ServeClientOpenVINO.SINGLE_MODEL
84
+ else:
85
+ self.create_model(model)
86
+
87
+ # threading
88
+ self.trans_thread = threading.Thread(target=self.speech_to_text)
89
+ self.trans_thread.start()
90
+
91
+ self.websocket.send(json.dumps({
92
+ "uid": self.client_uid,
93
+ "message": self.SERVER_READY,
94
+ "backend": "openvino"
95
+ }))
96
+ logging.info(f"Using OpenVINO device: {self.device}")
97
+ logging.info(f"Running OpenVINO backend with language: {self.language} and task: {self.task}")
98
+
99
+ def create_model(self, model_id):
100
+ """
101
+ Instantiates a new model, sets it as the transcriber.
102
+ """
103
+ self.transcriber = WhisperOpenVINO(
104
+ model_id,
105
+ device=self.device,
106
+ language=self.language,
107
+ task=self.task
108
+ )
109
+
110
+ def transcribe_audio(self, input_sample):
111
+ """
112
+ Transcribes the provided audio sample using the configured transcriber instance.
113
+
114
+ If the language has not been set, it updates the session's language based on the transcription
115
+ information.
116
+
117
+ Args:
118
+ input_sample (np.array): The audio chunk to be transcribed. This should be a NumPy
119
+ array representing the audio data.
120
+
121
+ Returns:
122
+ The transcription result from the transcriber. The exact format of this result
123
+ depends on the implementation of the `transcriber.transcribe` method but typically
124
+ includes the transcribed text.
125
+ """
126
+ if ServeClientOpenVINO.SINGLE_MODEL:
127
+ ServeClientOpenVINO.SINGLE_MODEL_LOCK.acquire()
128
+ result = self.transcriber.transcribe(input_sample)
129
+ if ServeClientOpenVINO.SINGLE_MODEL:
130
+ ServeClientOpenVINO.SINGLE_MODEL_LOCK.release()
131
+ return result
132
+
133
+ def handle_transcription_output(self, result, duration):
134
+ """
135
+ Handle the transcription output, updating the transcript and sending data to the client.
136
+
137
+ Args:
138
+ result (str): The result from whisper inference i.e. the list of segments.
139
+ duration (float): Duration of the transcribed audio chunk.
140
+ """
141
+ segments = []
142
+ if len(result):
143
+ self.t_start = None
144
+ last_segment = self.update_segments(result, duration)
145
+ segments = self.prepare_segments(last_segment)
146
+
147
+ if len(segments):
148
+ self.send_transcription_to_client(segments)
whisper_live/backend/trt_backend.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import threading
4
+ import time
5
+
6
+ from whisper_live.backend.base import ServeClientBase
7
+ from whisper_live.transcriber.transcriber_tensorrt import WhisperTRTLLM
8
+
9
+
10
+ class ServeClientTensorRT(ServeClientBase):
11
+ SINGLE_MODEL = None
12
+ SINGLE_MODEL_LOCK = threading.Lock()
13
+
14
+ def __init__(
15
+ self,
16
+ websocket,
17
+ task="transcribe",
18
+ multilingual=False,
19
+ language=None,
20
+ client_uid=None,
21
+ model=None,
22
+ single_model=False,
23
+ use_py_session=False,
24
+ max_new_tokens=225,
25
+ send_last_n_segments=10,
26
+ no_speech_thresh=0.45,
27
+ clip_audio=False,
28
+ same_output_threshold=10,
29
+ ):
30
+ """
31
+ Initialize a ServeClient instance.
32
+ The Whisper model is initialized based on the client's language and device availability.
33
+ The transcription thread is started upon initialization. A "SERVER_READY" message is sent
34
+ to the client to indicate that the server is ready.
35
+
36
+ Args:
37
+ websocket (WebSocket): The WebSocket connection for the client.
38
+ task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe".
39
+ device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None.
40
+ multilingual (bool, optional): Whether the client supports multilingual transcription. Defaults to False.
41
+ language (str, optional): The language for transcription. Defaults to None.
42
+ client_uid (str, optional): A unique identifier for the client. Defaults to None.
43
+ single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False.
44
+ use_py_session (bool, optional): Use python session or cpp session. Defaults to Cpp Session.
45
+ max_new_tokens (int, optional): Max number of tokens to generate.
46
+ send_last_n_segments (int, optional): Number of most recent segments to send to the client. Defaults to 10.
47
+ no_speech_thresh (float, optional): Segments with no speech probability above this threshold will be discarded. Defaults to 0.45.
48
+ clip_audio (bool, optional): Whether to clip audio with no valid segments. Defaults to False.
49
+ same_output_threshold (int, optional): Number of repeated outputs before considering it as a valid segment. Defaults to 10.
50
+ """
51
+ super().__init__(
52
+ client_uid,
53
+ websocket,
54
+ send_last_n_segments,
55
+ no_speech_thresh,
56
+ clip_audio,
57
+ same_output_threshold,
58
+ )
59
+
60
+ self.language = language if multilingual else "en"
61
+ self.task = task
62
+ self.eos = False
63
+ self.max_new_tokens = max_new_tokens
64
+
65
+ if single_model:
66
+ if ServeClientTensorRT.SINGLE_MODEL is None:
67
+ self.create_model(model, multilingual, use_py_session=use_py_session)
68
+ ServeClientTensorRT.SINGLE_MODEL = self.transcriber
69
+ else:
70
+ self.transcriber = ServeClientTensorRT.SINGLE_MODEL
71
+ else:
72
+ self.create_model(model, multilingual, use_py_session=use_py_session)
73
+
74
+ # threading
75
+ self.trans_thread = threading.Thread(target=self.speech_to_text)
76
+ self.trans_thread.start()
77
+
78
+ self.websocket.send(json.dumps({
79
+ "uid": self.client_uid,
80
+ "message": self.SERVER_READY,
81
+ "backend": "tensorrt"
82
+ }))
83
+
84
+ def create_model(self, model, multilingual, warmup=True, use_py_session=False):
85
+ """
86
+ Instantiates a new model, sets it as the transcriber and does warmup if desired.
87
+ """
88
+ self.transcriber = WhisperTRTLLM(
89
+ model,
90
+ assets_dir="assets",
91
+ device="cuda",
92
+ is_multilingual=multilingual,
93
+ language=self.language,
94
+ task=self.task,
95
+ use_py_session=use_py_session,
96
+ max_output_len=self.max_new_tokens,
97
+ )
98
+ if warmup:
99
+ self.warmup()
100
+
101
+ def warmup(self, warmup_steps=10):
102
+ """
103
+ Warmup TensorRT since first few inferences are slow.
104
+
105
+ Args:
106
+ warmup_steps (int): Number of steps to warm up the model for.
107
+ """
108
+ logging.info("[INFO:] Warming up TensorRT engine..")
109
+ mel, _ = self.transcriber.log_mel_spectrogram("assets/jfk.flac")
110
+ for i in range(warmup_steps):
111
+ self.transcriber.transcribe(mel)
112
+
113
+ def set_eos(self, eos):
114
+ """
115
+ Sets the End of Speech (EOS) flag.
116
+
117
+ Args:
118
+ eos (bool): The value to set for the EOS flag.
119
+ """
120
+ self.lock.acquire()
121
+ self.eos = eos
122
+ self.lock.release()
123
+
124
+ def handle_transcription_output(self, last_segment, duration):
125
+ """
126
+ Handle the transcription output, updating the transcript and sending data to the client.
127
+
128
+ Args:
129
+ last_segment (str): The last segment from the whisper output which is considered to be incomplete because
130
+ of the possibility of word being truncated.
131
+ duration (float): Duration of the transcribed audio chunk.
132
+ """
133
+ segments = self.prepare_segments({"text": last_segment})
134
+ self.send_transcription_to_client(segments)
135
+ if self.eos:
136
+ self.update_timestamp_offset(last_segment, duration)
137
+
138
+ def transcribe_audio(self, input_bytes):
139
+ """
140
+ Transcribe the audio chunk and send the results to the client.
141
+
142
+ Args:
143
+ input_bytes (np.array): The audio chunk to transcribe.
144
+ """
145
+ if ServeClientTensorRT.SINGLE_MODEL:
146
+ ServeClientTensorRT.SINGLE_MODEL_LOCK.acquire()
147
+ logging.info(f"[WhisperTensorRT:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
148
+ mel, duration = self.transcriber.log_mel_spectrogram(input_bytes)
149
+ last_segment = self.transcriber.transcribe(
150
+ mel,
151
+ text_prefix=f"<|startoftranscript|><|{self.language}|><|{self.task}|><|notimestamps|>",
152
+ )
153
+ if ServeClientTensorRT.SINGLE_MODEL:
154
+ ServeClientTensorRT.SINGLE_MODEL_LOCK.release()
155
+ if last_segment:
156
+ self.handle_transcription_output(last_segment, duration)
157
+
158
+ def update_timestamp_offset(self, last_segment, duration):
159
+ """
160
+ Update timestamp offset and transcript.
161
+
162
+ Args:
163
+ last_segment (str): Last transcribed audio from the whisper model.
164
+ duration (float): Duration of the last audio chunk.
165
+ """
166
+ if not len(self.transcript):
167
+ self.transcript.append({"text": last_segment + " "})
168
+ elif self.transcript[-1]["text"].strip() != last_segment:
169
+ self.transcript.append({"text": last_segment + " "})
170
+
171
+ with self.lock:
172
+ self.timestamp_offset += duration
173
+
174
+ def speech_to_text(self):
175
+ """
176
+ Process an audio stream in an infinite loop, continuously transcribing the speech.
177
+
178
+ This method continuously receives audio frames, performs real-time transcription, and sends
179
+ transcribed segments to the client via a WebSocket connection.
180
+
181
+ If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction.
182
+ It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments
183
+ are sent to the client in real-time, and a history of segments is maintained to provide context.
184
+
185
+ Raises:
186
+ Exception: If there is an issue with audio processing or WebSocket communication.
187
+
188
+ """
189
+ while True:
190
+ if self.exit:
191
+ logging.info("Exiting speech to text thread")
192
+ break
193
+
194
+ if self.frames_np is None:
195
+ time.sleep(0.02) # wait for any audio to arrive
196
+ continue
197
+
198
+ self.clip_audio_if_no_valid_segment()
199
+
200
+ input_bytes, duration = self.get_audio_chunk_for_processing()
201
+ if duration < 0.4:
202
+ continue
203
+
204
+ try:
205
+ input_sample = input_bytes.copy()
206
+ logging.info(f"[WhisperTensorRT:] Processing audio with duration: {duration}")
207
+ self.transcribe_audio(input_sample)
208
+
209
+ except Exception as e:
210
+ logging.error(f"[ERROR]: {e}")
whisper_live/client.py ADDED
@@ -0,0 +1,782 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import wave
4
+
5
+ import logging
6
+ import numpy as np
7
+ import pyaudio
8
+ import threading
9
+ import json
10
+ import websocket
11
+ import uuid
12
+ import time
13
+ import av
14
+ import whisper_live.utils as utils
15
+
16
+
17
+ class Client:
18
+ """
19
+ Handles communication with a server using WebSocket.
20
+ """
21
+ INSTANCES = {}
22
+ END_OF_AUDIO = "END_OF_AUDIO"
23
+
24
+ def __init__(
25
+ self,
26
+ host=None,
27
+ port=None,
28
+ lang=None,
29
+ translate=False,
30
+ model="small",
31
+ srt_file_path="output.srt",
32
+ use_vad=True,
33
+ use_wss=False,
34
+ log_transcription=True,
35
+ max_clients=4,
36
+ max_connection_time=600,
37
+ send_last_n_segments=10,
38
+ no_speech_thresh=0.45,
39
+ clip_audio=False,
40
+ same_output_threshold=10,
41
+ transcription_callback=None,
42
+ ):
43
+ """
44
+ Initializes a Client instance for audio recording and streaming to a server.
45
+
46
+ If host and port are not provided, the WebSocket connection will not be established.
47
+ When translate is True, the task will be set to "translate" instead of "transcribe".
48
+ he audio recording starts immediately upon initialization.
49
+
50
+ Args:
51
+ host (str): The hostname or IP address of the server.
52
+ port (int): The port number for the WebSocket server.
53
+ lang (str, optional): The selected language for transcription. Default is None.
54
+ translate (bool, optional): Specifies if the task is translation. Default is False.
55
+ model (str, optional): The whisper model to use (e.g., "small", "medium", "large"). Default is "small".
56
+ srt_file_path (str, optional): The file path to save the output SRT file. Default is "output.srt".
57
+ use_vad (bool, optional): Whether to enable voice activity detection. Default is True.
58
+ log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
59
+ max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
60
+ max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
61
+ send_last_n_segments (int, optional): Number of most recent segments to send to the client. Defaults to 10.
62
+ no_speech_thresh (float, optional): Segments with no speech probability above this threshold will be discarded. Defaults to 0.45.
63
+ clip_audio (bool, optional): Whether to clip audio with no valid segments. Defaults to False.
64
+ same_output_threshold (int, optional): Number of repeated outputs before considering it as a valid segment. Defaults to 10.
65
+ transcription_callback (callable, optional): A callback function to handle transcription results. Default is None.
66
+ """
67
+ self.recording = False
68
+ self.task = "transcribe"
69
+ self.uid = str(uuid.uuid4())
70
+ self.waiting = False
71
+ self.last_response_received = None
72
+ self.disconnect_if_no_response_for = 15
73
+ self.language = lang
74
+ self.model = model
75
+ self.server_error = False
76
+ self.srt_file_path = srt_file_path
77
+ self.use_vad = use_vad
78
+ self.use_wss = use_wss
79
+ self.last_segment = None
80
+ self.last_received_segment = None
81
+ self.log_transcription = log_transcription
82
+ self.max_clients = max_clients
83
+ self.max_connection_time = max_connection_time
84
+ self.send_last_n_segments = send_last_n_segments
85
+ self.no_speech_thresh = no_speech_thresh
86
+ self.clip_audio = clip_audio
87
+ self.same_output_threshold = same_output_threshold
88
+ self.transcription_callback = transcription_callback
89
+
90
+ if translate:
91
+ self.task = "translate"
92
+
93
+ self.audio_bytes = None
94
+
95
+ if host is not None and port is not None:
96
+ socket_protocol = 'wss' if self.use_wss else "ws"
97
+ socket_url = f"{socket_protocol}://{host}:{port}"
98
+ self.client_socket = websocket.WebSocketApp(
99
+ socket_url,
100
+ on_open=lambda ws: self.on_open(ws),
101
+ on_message=lambda ws, message: self.on_message(ws, message),
102
+ on_error=lambda ws, error: self.on_error(ws, error),
103
+ on_close=lambda ws, close_status_code, close_msg: self.on_close(
104
+ ws, close_status_code, close_msg
105
+ ),
106
+ )
107
+ else:
108
+ print("[ERROR]: No host or port specified.")
109
+ return
110
+
111
+ Client.INSTANCES[self.uid] = self
112
+
113
+ # start websocket client in a thread
114
+ self.ws_thread = threading.Thread(target=self.client_socket.run_forever)
115
+ self.ws_thread.daemon = True
116
+ self.ws_thread.start()
117
+
118
+ self.transcript = []
119
+ print("[INFO]: * recording")
120
+
121
+ def handle_status_messages(self, message_data):
122
+ """Handles server status messages."""
123
+ status = message_data["status"]
124
+ if status == "WAIT":
125
+ self.waiting = True
126
+ print(f"[INFO]: Server is full. Estimated wait time {round(message_data['message'])} minutes.")
127
+ elif status == "ERROR":
128
+ print(f"Message from Server: {message_data['message']}")
129
+ self.server_error = True
130
+ elif status == "WARNING":
131
+ print(f"Message from Server: {message_data['message']}")
132
+
133
+ def process_segments(self, segments):
134
+ """Processes transcript segments."""
135
+ text = []
136
+ for i, seg in enumerate(segments):
137
+ if not text or text[-1] != seg["text"]:
138
+ text.append(seg["text"])
139
+ if i == len(segments) - 1 and not seg.get("completed", False):
140
+ self.last_segment = seg
141
+ elif (self.server_backend == "faster_whisper" and seg.get("completed", False) and
142
+ (not self.transcript or
143
+ float(seg['start']) >= float(self.transcript[-1]['end']))):
144
+ self.transcript.append(seg)
145
+ # update last received segment and last valid response time
146
+ if self.last_received_segment is None or self.last_received_segment != segments[-1]["text"]:
147
+ self.last_response_received = time.time()
148
+ self.last_received_segment = segments[-1]["text"]
149
+
150
+ # call the transcription callback if provided
151
+ if self.transcription_callback and callable(self.transcription_callback):
152
+ try:
153
+ self.transcription_callback(" ".join(text), segments) # string, list
154
+ except Exception as e:
155
+ print(f"[WARN] transcription_callback raised: {e}")
156
+ return
157
+
158
+ if self.log_transcription:
159
+ # Truncate to last 3 entries for brevity.
160
+ text = text[-3:]
161
+ utils.clear_screen()
162
+ utils.print_transcript(text)
163
+
164
+ def on_message(self, ws, message):
165
+ """
166
+ Callback function called when a message is received from the server.
167
+
168
+ It updates various attributes of the client based on the received message, including
169
+ recording status, language detection, and server messages. If a disconnect message
170
+ is received, it sets the recording status to False.
171
+
172
+ Args:
173
+ ws (websocket.WebSocketApp): The WebSocket client instance.
174
+ message (str): The received message from the server.
175
+
176
+ """
177
+ message = json.loads(message)
178
+
179
+ if self.uid != message.get("uid"):
180
+ print("[ERROR]: invalid client uid")
181
+ return
182
+
183
+ if "status" in message.keys():
184
+ self.handle_status_messages(message)
185
+ return
186
+
187
+ if "message" in message.keys() and message["message"] == "DISCONNECT":
188
+ print("[INFO]: Server disconnected due to overtime.")
189
+ self.recording = False
190
+
191
+ if "message" in message.keys() and message["message"] == "SERVER_READY":
192
+ self.last_response_received = time.time()
193
+ self.recording = True
194
+ self.server_backend = message["backend"]
195
+ print(f"[INFO]: Server Running with backend {self.server_backend}")
196
+ return
197
+
198
+ if "language" in message.keys():
199
+ self.language = message.get("language")
200
+ lang_prob = message.get("language_prob")
201
+ print(
202
+ f"[INFO]: Server detected language {self.language} with probability {lang_prob}"
203
+ )
204
+ return
205
+
206
+ if "segments" in message.keys():
207
+ self.process_segments(message["segments"])
208
+
209
+ def on_error(self, ws, error):
210
+ print(f"[ERROR] WebSocket Error: {error}")
211
+ self.server_error = True
212
+ self.error_message = error
213
+
214
+ def on_close(self, ws, close_status_code, close_msg):
215
+ print(f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}")
216
+ self.recording = False
217
+ self.waiting = False
218
+
219
+ def on_open(self, ws):
220
+ """
221
+ Callback function called when the WebSocket connection is successfully opened.
222
+
223
+ Sends an initial configuration message to the server, including client UID,
224
+ language selection, and task type.
225
+
226
+ Args:
227
+ ws (websocket.WebSocketApp): The WebSocket client instance.
228
+
229
+ """
230
+ print("[INFO]: Opened connection")
231
+ ws.send(
232
+ json.dumps(
233
+ {
234
+ "uid": self.uid,
235
+ "language": self.language,
236
+ "task": self.task,
237
+ "model": self.model,
238
+ "use_vad": self.use_vad,
239
+ "max_clients": self.max_clients,
240
+ "max_connection_time": self.max_connection_time,
241
+ "send_last_n_segments": self.send_last_n_segments,
242
+ "no_speech_thresh": self.no_speech_thresh,
243
+ "clip_audio": self.clip_audio,
244
+ "same_output_threshold": self.same_output_threshold,
245
+ }
246
+ )
247
+ )
248
+
249
+ def send_packet_to_server(self, message):
250
+ """
251
+ Send an audio packet to the server using WebSocket.
252
+
253
+ Args:
254
+ message (bytes): The audio data packet in bytes to be sent to the server.
255
+
256
+ """
257
+ try:
258
+ self.client_socket.send(message, websocket.ABNF.OPCODE_BINARY)
259
+ except Exception as e:
260
+ print(e)
261
+
262
+ def close_websocket(self):
263
+ """
264
+ Close the WebSocket connection and join the WebSocket thread.
265
+
266
+ First attempts to close the WebSocket connection using `self.client_socket.close()`. After
267
+ closing the connection, it joins the WebSocket thread to ensure proper termination.
268
+
269
+ """
270
+ try:
271
+ self.client_socket.close()
272
+ except Exception as e:
273
+ print("[ERROR]: Error closing WebSocket:", e)
274
+
275
+ try:
276
+ self.ws_thread.join()
277
+ except Exception as e:
278
+ print("[ERROR:] Error joining WebSocket thread:", e)
279
+
280
+ def get_client_socket(self):
281
+ """
282
+ Get the WebSocket client socket instance.
283
+
284
+ Returns:
285
+ WebSocketApp: The WebSocket client socket instance currently in use by the client.
286
+ """
287
+ return self.client_socket
288
+
289
+ def write_srt_file(self, output_path="output.srt"):
290
+ """
291
+ Writes out the transcript in .srt format.
292
+
293
+ Args:
294
+ message (output_path, optional): The path to the target file. Default is "output.srt".
295
+
296
+ """
297
+ if self.server_backend == "faster_whisper":
298
+ if not self.transcript and self.last_segment is not None:
299
+ self.transcript.append(self.last_segment)
300
+ elif self.last_segment and self.transcript[-1]["text"] != self.last_segment["text"]:
301
+ self.transcript.append(self.last_segment)
302
+ utils.create_srt_file(self.transcript, output_path)
303
+
304
+ def wait_before_disconnect(self):
305
+ """Waits a bit before disconnecting in order to process pending responses."""
306
+ assert self.last_response_received
307
+ while time.time() - self.last_response_received < self.disconnect_if_no_response_for:
308
+ continue
309
+
310
+
311
+ class TranscriptionTeeClient:
312
+ """
313
+ Client for handling audio recording, streaming, and transcription tasks via one or more
314
+ WebSocket connections.
315
+
316
+ Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
317
+ to send audio data for transcription to one or more servers, and receive transcribed text segments.
318
+ Args:
319
+ clients (list): one or more previously initialized Client instances
320
+
321
+ Attributes:
322
+ clients (list): the underlying Client instances responsible for handling WebSocket connections.
323
+ """
324
+ def __init__(self, clients, save_output_recording=False, output_recording_filename="./output_recording.wav", mute_audio_playback=False):
325
+ self.clients = clients
326
+ if not self.clients:
327
+ raise Exception("At least one client is required.")
328
+ self.chunk = 4096
329
+ self.format = pyaudio.paInt16
330
+ self.channels = 1
331
+ self.rate = 16000
332
+ self.record_seconds = 60000
333
+ self.save_output_recording = save_output_recording
334
+ self.output_recording_filename = output_recording_filename
335
+ self.mute_audio_playback = mute_audio_playback
336
+ self.frames = b""
337
+ self.p = pyaudio.PyAudio()
338
+ try:
339
+ self.stream = self.p.open(
340
+ format=self.format,
341
+ channels=self.channels,
342
+ rate=self.rate,
343
+ input=True,
344
+ frames_per_buffer=self.chunk,
345
+ )
346
+ except OSError as error:
347
+ print(f"[WARN]: Unable to access microphone. {error}")
348
+ self.stream = None
349
+
350
+ def __call__(self, audio=None, rtsp_url=None, hls_url=None, save_file=None):
351
+ """
352
+ Start the transcription process.
353
+
354
+ Initiates the transcription process by connecting to the server via a WebSocket. It waits for the server
355
+ to be ready to receive audio data and then sends audio for transcription. If an audio file is provided, it
356
+ will be played and streamed to the server; otherwise, it will perform live recording.
357
+
358
+ Args:
359
+ audio (str, optional): Path to an audio file for transcription. Default is None, which triggers live recording.
360
+
361
+ """
362
+ assert sum(
363
+ source is not None for source in [audio, rtsp_url, hls_url]
364
+ ) <= 1, 'You must provide only one selected source'
365
+
366
+ print("[INFO]: Waiting for server ready ...")
367
+ for client in self.clients:
368
+ while not client.recording:
369
+ if client.waiting or client.server_error:
370
+ self.close_all_clients()
371
+ return
372
+
373
+ print("[INFO]: Server Ready!")
374
+ if hls_url is not None:
375
+ self.process_hls_stream(hls_url, save_file)
376
+ elif audio is not None:
377
+ resampled_file = utils.resample(audio)
378
+ self.play_file(resampled_file)
379
+ elif rtsp_url is not None:
380
+ self.process_rtsp_stream(rtsp_url)
381
+ else:
382
+ self.record()
383
+
384
+ def close_all_clients(self):
385
+ """Closes all client websockets."""
386
+ for client in self.clients:
387
+ client.close_websocket()
388
+
389
+ def write_all_clients_srt(self):
390
+ """Writes out .srt files for all clients."""
391
+ for client in self.clients:
392
+ client.write_srt_file(client.srt_file_path)
393
+
394
+ def multicast_packet(self, packet, unconditional=False):
395
+ """
396
+ Sends an identical packet via all clients.
397
+
398
+ Args:
399
+ packet (bytes): The audio data packet in bytes to be sent.
400
+ unconditional (bool, optional): If true, send regardless of whether clients are recording. Default is False.
401
+ """
402
+ for client in self.clients:
403
+ if (unconditional or client.recording):
404
+ client.send_packet_to_server(packet)
405
+
406
+ def play_file(self, filename):
407
+ """
408
+ Play an audio file and send it to the server for processing.
409
+
410
+ Reads an audio file, plays it through the audio output, and simultaneously sends
411
+ the audio data to the server for processing. It uses PyAudio to create an audio
412
+ stream for playback. The audio data is read from the file in chunks, converted to
413
+ floating-point format, and sent to the server using WebSocket communication.
414
+ This method is typically used when you want to process pre-recorded audio and send it
415
+ to the server in real-time.
416
+
417
+ Args:
418
+ filename (str): The path to the audio file to be played and sent to the server.
419
+ """
420
+
421
+ # read audio and create pyaudio stream
422
+ with wave.open(filename, "rb") as wavfile:
423
+ self.stream = self.p.open(
424
+ format=self.p.get_format_from_width(wavfile.getsampwidth()),
425
+ channels=wavfile.getnchannels(),
426
+ rate=wavfile.getframerate(),
427
+ input=True,
428
+ output=True,
429
+ frames_per_buffer=self.chunk,
430
+ )
431
+ chunk_duration = self.chunk / float(wavfile.getframerate())
432
+ try:
433
+ while any(client.recording for client in self.clients):
434
+ data = wavfile.readframes(self.chunk)
435
+ if data == b"":
436
+ break
437
+
438
+ audio_array = self.bytes_to_float_array(data)
439
+ self.multicast_packet(audio_array.tobytes())
440
+ if self.mute_audio_playback:
441
+ time.sleep(chunk_duration)
442
+ else:
443
+ self.stream.write(data)
444
+
445
+ wavfile.close()
446
+
447
+ for client in self.clients:
448
+ client.wait_before_disconnect()
449
+ self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
450
+ self.write_all_clients_srt()
451
+ self.stream.close()
452
+ self.close_all_clients()
453
+
454
+ except KeyboardInterrupt:
455
+ wavfile.close()
456
+ self.stream.stop_stream()
457
+ self.stream.close()
458
+ self.p.terminate()
459
+ self.close_all_clients()
460
+ self.write_all_clients_srt()
461
+ print("[INFO]: Keyboard interrupt.")
462
+
463
+ def process_rtsp_stream(self, rtsp_url):
464
+ """
465
+ Connect to an RTSP source, process the audio stream, and send it for transcription.
466
+
467
+ Args:
468
+ rtsp_url (str): The URL of the RTSP stream source.
469
+ """
470
+ print("[INFO]: Connecting to RTSP stream...")
471
+ try:
472
+ container = av.open(rtsp_url, format="rtsp", options={"rtsp_transport": "tcp"})
473
+ self.process_av_stream(container, stream_type="RTSP")
474
+ except Exception as e:
475
+ print(f"[ERROR]: Failed to process RTSP stream: {e}")
476
+ finally:
477
+ for client in self.clients:
478
+ client.wait_before_disconnect()
479
+ self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
480
+ self.close_all_clients()
481
+ self.write_all_clients_srt()
482
+ print("[INFO]: RTSP stream processing finished.")
483
+
484
+ def process_hls_stream(self, hls_url, save_file=None):
485
+ """
486
+ Connect to an HLS source, process the audio stream, and send it for transcription.
487
+
488
+ Args:
489
+ hls_url (str): The URL of the HLS stream source.
490
+ save_file (str, optional): Local path to save the network stream.
491
+ """
492
+ print("[INFO]: Connecting to HLS stream...")
493
+ try:
494
+ container = av.open(hls_url, format="hls")
495
+ self.process_av_stream(container, stream_type="HLS", save_file=save_file)
496
+ except Exception as e:
497
+ print(f"[ERROR]: Failed to process HLS stream: {e}")
498
+ finally:
499
+ for client in self.clients:
500
+ client.wait_before_disconnect()
501
+ self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
502
+ self.close_all_clients()
503
+ self.write_all_clients_srt()
504
+ print("[INFO]: HLS stream processing finished.")
505
+
506
+ def process_av_stream(self, container, stream_type, save_file=None):
507
+ """
508
+ Process an AV container stream and send audio packets to the server.
509
+
510
+ Args:
511
+ container (av.container.InputContainer): The input container to process.
512
+ stream_type (str): The type of stream being processed ("RTSP" or "HLS").
513
+ save_file (str, optional): Local path to save the stream. Default is None.
514
+ """
515
+ audio_stream = next((s for s in container.streams if s.type == "audio"), None)
516
+ if not audio_stream:
517
+ print(f"[ERROR]: No audio stream found in {stream_type} source.")
518
+ return
519
+
520
+ output_container = None
521
+ if save_file:
522
+ output_container = av.open(save_file, mode="w")
523
+ output_audio_stream = output_container.add_stream(codec_name="pcm_s16le", rate=self.rate)
524
+
525
+ try:
526
+ for packet in container.demux(audio_stream):
527
+ for frame in packet.decode():
528
+ audio_data = frame.to_ndarray().tobytes()
529
+ self.multicast_packet(audio_data)
530
+
531
+ if save_file:
532
+ output_container.mux(frame)
533
+ except Exception as e:
534
+ print(f"[ERROR]: Error during {stream_type} stream processing: {e}")
535
+ finally:
536
+ # Wait for server to send any leftover transcription.
537
+ time.sleep(5)
538
+ self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
539
+ if output_container:
540
+ output_container.close()
541
+ container.close()
542
+
543
+ def save_chunk(self, n_audio_file):
544
+ """
545
+ Saves the current audio frames to a WAV file in a separate thread.
546
+
547
+ Args:
548
+ n_audio_file (int): The index of the audio file which determines the filename.
549
+ This helps in maintaining the order and uniqueness of each chunk.
550
+ """
551
+ t = threading.Thread(
552
+ target=self.write_audio_frames_to_file,
553
+ args=(self.frames[:], f"chunks/{n_audio_file}.wav",),
554
+ )
555
+ t.start()
556
+
557
+ def finalize_recording(self, n_audio_file):
558
+ """
559
+ Finalizes the recording process by saving any remaining audio frames,
560
+ closing the audio stream, and terminating the process.
561
+
562
+ Args:
563
+ n_audio_file (int): The file index to be used if there are remaining audio frames to be saved.
564
+ This index is incremented before use if the last chunk is saved.
565
+ """
566
+ if self.save_output_recording and len(self.frames):
567
+ self.write_audio_frames_to_file(
568
+ self.frames[:], f"chunks/{n_audio_file}.wav"
569
+ )
570
+ n_audio_file += 1
571
+ self.stream.stop_stream()
572
+ self.stream.close()
573
+ self.p.terminate()
574
+ self.close_all_clients()
575
+ if self.save_output_recording:
576
+ self.write_output_recording(n_audio_file)
577
+ self.write_all_clients_srt()
578
+
579
+ def record(self):
580
+ """
581
+ Record audio data from the input stream and save it to a WAV file.
582
+
583
+ Continuously records audio data from the input stream, sends it to the server via a WebSocket
584
+ connection, and simultaneously saves it to multiple WAV files in chunks. It stops recording when
585
+ the `RECORD_SECONDS` duration is reached or when the `RECORDING` flag is set to `False`.
586
+
587
+ Audio data is saved in chunks to the "chunks" directory. Each chunk is saved as a separate WAV file.
588
+ The recording will continue until the specified duration is reached or until the `RECORDING` flag is set to `False`.
589
+ The recording process can be interrupted by sending a KeyboardInterrupt (e.g., pressing Ctrl+C). After recording,
590
+ the method combines all the saved audio chunks into the specified `out_file`.
591
+ """
592
+ n_audio_file = 0
593
+ if self.save_output_recording:
594
+ if os.path.exists("chunks"):
595
+ shutil.rmtree("chunks")
596
+ os.makedirs("chunks")
597
+ try:
598
+ for _ in range(0, int(self.rate / self.chunk * self.record_seconds)):
599
+ if not any(client.recording for client in self.clients):
600
+ break
601
+ data = self.stream.read(self.chunk, exception_on_overflow=False)
602
+ self.frames += data
603
+
604
+ audio_array = self.bytes_to_float_array(data)
605
+
606
+ self.multicast_packet(audio_array.tobytes())
607
+
608
+ # save frames if more than a minute
609
+ if len(self.frames) > 60 * self.rate:
610
+ if self.save_output_recording:
611
+ self.save_chunk(n_audio_file)
612
+ n_audio_file += 1
613
+ self.frames = b""
614
+ self.write_all_clients_srt()
615
+
616
+ except KeyboardInterrupt:
617
+ self.finalize_recording(n_audio_file)
618
+
619
+ def write_audio_frames_to_file(self, frames, file_name):
620
+ """
621
+ Write audio frames to a WAV file.
622
+
623
+ The WAV file is created or overwritten with the specified name. The audio frames should be
624
+ in the correct format and match the specified channel, sample width, and sample rate.
625
+
626
+ Args:
627
+ frames (bytes): The audio frames to be written to the file.
628
+ file_name (str): The name of the WAV file to which the frames will be written.
629
+
630
+ """
631
+ with wave.open(file_name, "wb") as wavfile:
632
+ wavfile: wave.Wave_write
633
+ wavfile.setnchannels(self.channels)
634
+ wavfile.setsampwidth(2)
635
+ wavfile.setframerate(self.rate)
636
+ wavfile.writeframes(frames)
637
+
638
+ def write_output_recording(self, n_audio_file):
639
+ """
640
+ Combine and save recorded audio chunks into a single WAV file.
641
+
642
+ The individual audio chunk files are expected to be located in the "chunks" directory. Reads each chunk
643
+ file, appends its audio data to the final recording, and then deletes the chunk file. After combining
644
+ and saving, the final recording is stored in the specified `out_file`.
645
+
646
+
647
+ Args:
648
+ n_audio_file (int): The number of audio chunk files to combine.
649
+ out_file (str): The name of the output WAV file to save the final recording.
650
+
651
+ """
652
+ input_files = [
653
+ f"chunks/{i}.wav"
654
+ for i in range(n_audio_file)
655
+ if os.path.exists(f"chunks/{i}.wav")
656
+ ]
657
+ with wave.open(self.output_recording_filename, "wb") as wavfile:
658
+ wavfile: wave.Wave_write
659
+ wavfile.setnchannels(self.channels)
660
+ wavfile.setsampwidth(2)
661
+ wavfile.setframerate(self.rate)
662
+ for in_file in input_files:
663
+ with wave.open(in_file, "rb") as wav_in:
664
+ while True:
665
+ data = wav_in.readframes(self.chunk)
666
+ if data == b"":
667
+ break
668
+ wavfile.writeframes(data)
669
+ # remove this file
670
+ os.remove(in_file)
671
+ wavfile.close()
672
+ # clean up temporary directory to store chunks
673
+ if os.path.exists("chunks"):
674
+ shutil.rmtree("chunks")
675
+
676
+ @staticmethod
677
+ def bytes_to_float_array(audio_bytes):
678
+ """
679
+ Convert audio data from bytes to a NumPy float array.
680
+
681
+ It assumes that the audio data is in 16-bit PCM format. The audio data is normalized to
682
+ have values between -1 and 1.
683
+
684
+ Args:
685
+ audio_bytes (bytes): Audio data in bytes.
686
+
687
+ Returns:
688
+ np.ndarray: A NumPy array containing the audio data as float values normalized between -1 and 1.
689
+ """
690
+ raw_data = np.frombuffer(buffer=audio_bytes, dtype=np.int16)
691
+ return raw_data.astype(np.float32) / 32768.0
692
+
693
+
694
+ class TranscriptionClient(TranscriptionTeeClient):
695
+ """
696
+ Client for handling audio transcription tasks via a single WebSocket connection.
697
+
698
+ Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
699
+ to send audio data for transcription to a server and receive transcribed text segments.
700
+
701
+ Args:
702
+ host (str): The hostname or IP address of the server.
703
+ port (int): The port number to connect to on the server.
704
+ lang (str, optional): The primary language for transcription. Default is None, which defaults to English ('en').
705
+ translate (bool, optional): If True, the task will be translation instead of transcription. Default is False.
706
+ model (str, optional): The whisper model to use (e.g., "small", "base"). Default is "small".
707
+ use_vad (bool, optional): Whether to enable voice activity detection. Default is True.
708
+ save_output_recording (bool, optional): Whether to save the microphone recording. Default is False.
709
+ output_recording_filename (str, optional): Path to save the output recording WAV file. Default is "./output_recording.wav".
710
+ output_transcription_path (str, optional): File path to save the output transcription (SRT file). Default is "./output.srt".
711
+ log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
712
+ max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
713
+ max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
714
+ mute_audio_playback (bool, optional): If True, mutes audio playback during file playback. Default is False.
715
+ send_last_n_segments (int, optional): Number of most recent segments to send to the client. Defaults to 10.
716
+ no_speech_thresh (float, optional): Segments with no speech probability above this threshold will be discarded. Defaults to 0.45.
717
+ clip_audio (bool, optional): Whether to clip audio with no valid segments. Defaults to False.
718
+ same_output_threshold (int, optional): Number of repeated outputs before considering it as a valid segment. Defaults to 10.
719
+ transcription_callback (callable, optional): A callback function to handle transcription results. Default is None.
720
+
721
+ Attributes:
722
+ client (Client): An instance of the underlying Client class responsible for handling the WebSocket connection.
723
+
724
+ Example:
725
+ To create a TranscriptionClient and start transcription on microphone audio:
726
+ ```python
727
+ transcription_client = TranscriptionClient(host="localhost", port=9090)
728
+ transcription_client()
729
+ ```
730
+ """
731
+ def __init__(
732
+ self,
733
+ host,
734
+ port,
735
+ lang=None,
736
+ translate=False,
737
+ model="small",
738
+ use_vad=True,
739
+ use_wss=False,
740
+ save_output_recording=False,
741
+ output_recording_filename="./output_recording.wav",
742
+ output_transcription_path="./output.srt",
743
+ log_transcription=True,
744
+ max_clients=4,
745
+ max_connection_time=600,
746
+ mute_audio_playback=False,
747
+ send_last_n_segments=10,
748
+ no_speech_thresh=0.45,
749
+ clip_audio=False,
750
+ same_output_threshold=10,
751
+ transcription_callback=None,
752
+ ):
753
+ self.client = Client(
754
+ host,
755
+ port,
756
+ lang,
757
+ translate,
758
+ model,
759
+ srt_file_path=output_transcription_path,
760
+ use_vad=use_vad,
761
+ use_wss=use_wss,
762
+ log_transcription=log_transcription,
763
+ max_clients=max_clients,
764
+ max_connection_time=max_connection_time,
765
+ send_last_n_segments=send_last_n_segments,
766
+ no_speech_thresh=no_speech_thresh,
767
+ clip_audio=clip_audio,
768
+ same_output_threshold=same_output_threshold,
769
+ transcription_callback=transcription_callback,
770
+ )
771
+
772
+ if save_output_recording and not output_recording_filename.endswith(".wav"):
773
+ raise ValueError(f"Please provide a valid `output_recording_filename`: {output_recording_filename}")
774
+ if not output_transcription_path.endswith(".srt"):
775
+ raise ValueError(f"Please provide a valid `output_transcription_path`: {output_transcription_path}. The file extension should be `.srt`.")
776
+ TranscriptionTeeClient.__init__(
777
+ self,
778
+ [self.client],
779
+ save_output_recording=save_output_recording,
780
+ output_recording_filename=output_recording_filename,
781
+ mute_audio_playback=mute_audio_playback
782
+ )
whisper_live/server.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import threading
4
+ import json
5
+ import functools
6
+ import logging
7
+ from enum import Enum
8
+ from typing import List, Optional
9
+
10
+ import numpy as np
11
+ from websockets.sync.server import serve
12
+ from websockets.exceptions import ConnectionClosed
13
+ from whisper_live.vad import VoiceActivityDetector
14
+ from whisper_live.backend.base import ServeClientBase
15
+
16
+ logging.basicConfig(level=logging.INFO)
17
+
18
+
19
+ class ClientManager:
20
+ def __init__(self, max_clients=4, max_connection_time=600):
21
+ """
22
+ Initializes the ClientManager with specified limits on client connections and connection durations.
23
+
24
+ Args:
25
+ max_clients (int, optional): The maximum number of simultaneous client connections allowed. Defaults to 4.
26
+ max_connection_time (int, optional): The maximum duration (in seconds) a client can stay connected. Defaults
27
+ to 600 seconds (10 minutes).
28
+ """
29
+ self.clients = {}
30
+ self.start_times = {}
31
+ self.max_clients = max_clients
32
+ self.max_connection_time = max_connection_time
33
+
34
+ def add_client(self, websocket, client):
35
+ """
36
+ Adds a client and their connection start time to the tracking dictionaries.
37
+
38
+ Args:
39
+ websocket: The websocket associated with the client to add.
40
+ client: The client object to be added and tracked.
41
+ """
42
+ self.clients[websocket] = client
43
+ self.start_times[websocket] = time.time()
44
+
45
+ def get_client(self, websocket):
46
+ """
47
+ Retrieves a client associated with the given websocket.
48
+
49
+ Args:
50
+ websocket: The websocket associated with the client to retrieve.
51
+
52
+ Returns:
53
+ The client object if found, False otherwise.
54
+ """
55
+ if websocket in self.clients:
56
+ return self.clients[websocket]
57
+ return False
58
+
59
+ def remove_client(self, websocket):
60
+ """
61
+ Removes a client and their connection start time from the tracking dictionaries. Performs cleanup on the
62
+ client if necessary.
63
+
64
+ Args:
65
+ websocket: The websocket associated with the client to be removed.
66
+ """
67
+ client = self.clients.pop(websocket, None)
68
+ if client:
69
+ client.cleanup()
70
+ self.start_times.pop(websocket, None)
71
+
72
+ def get_wait_time(self):
73
+ """
74
+ Calculates the estimated wait time for new clients based on the remaining connection times of current clients.
75
+
76
+ Returns:
77
+ The estimated wait time in minutes for new clients to connect. Returns 0 if there are available slots.
78
+ """
79
+ wait_time = None
80
+ for start_time in self.start_times.values():
81
+ current_client_time_remaining = self.max_connection_time - (time.time() - start_time)
82
+ if wait_time is None or current_client_time_remaining < wait_time:
83
+ wait_time = current_client_time_remaining
84
+ return wait_time / 60 if wait_time is not None else 0
85
+
86
+ def is_server_full(self, websocket, options):
87
+ """
88
+ Checks if the server is at its maximum client capacity and sends a wait message to the client if necessary.
89
+
90
+ Args:
91
+ websocket: The websocket of the client attempting to connect.
92
+ options: A dictionary of options that may include the client's unique identifier.
93
+
94
+ Returns:
95
+ True if the server is full, False otherwise.
96
+ """
97
+ if len(self.clients) >= self.max_clients:
98
+ wait_time = self.get_wait_time()
99
+ response = {"uid": options["uid"], "status": "WAIT", "message": wait_time}
100
+ websocket.send(json.dumps(response))
101
+ return True
102
+ return False
103
+
104
+ def is_client_timeout(self, websocket):
105
+ """
106
+ Checks if a client has exceeded the maximum allowed connection time and disconnects them if so, issuing a warning.
107
+
108
+ Args:
109
+ websocket: The websocket associated with the client to check.
110
+
111
+ Returns:
112
+ True if the client's connection time has exceeded the maximum limit, False otherwise.
113
+ """
114
+ elapsed_time = time.time() - self.start_times[websocket]
115
+ if elapsed_time >= self.max_connection_time:
116
+ self.clients[websocket].disconnect()
117
+ logging.warning(f"Client with uid '{self.clients[websocket].client_uid}' disconnected due to overtime.")
118
+ return True
119
+ return False
120
+
121
+
122
+ class BackendType(Enum):
123
+ FASTER_WHISPER = "faster_whisper"
124
+ TENSORRT = "tensorrt"
125
+ OPENVINO = "openvino"
126
+
127
+ @staticmethod
128
+ def valid_types() -> List[str]:
129
+ return [backend_type.value for backend_type in BackendType]
130
+
131
+ @staticmethod
132
+ def is_valid(backend: str) -> bool:
133
+ return backend in BackendType.valid_types()
134
+
135
+ def is_faster_whisper(self) -> bool:
136
+ return self == BackendType.FASTER_WHISPER
137
+
138
+ def is_tensorrt(self) -> bool:
139
+ return self == BackendType.TENSORRT
140
+
141
+ def is_openvino(self) -> bool:
142
+ return self == BackendType.OPENVINO
143
+
144
+
145
+ class TranscriptionServer:
146
+ RATE = 16000
147
+
148
+ def __init__(self):
149
+ self.client_manager = None
150
+ self.no_voice_activity_chunks = 0
151
+ self.use_vad = True
152
+ self.single_model = False
153
+
154
+ def initialize_client(
155
+ self, websocket, options, faster_whisper_custom_model_path,
156
+ whisper_tensorrt_path, trt_multilingual, trt_py_session=False,
157
+ ):
158
+ client: Optional[ServeClientBase] = None
159
+
160
+ if self.backend.is_tensorrt():
161
+ try:
162
+ from whisper_live.backend.trt_backend import ServeClientTensorRT
163
+ client = ServeClientTensorRT(
164
+ websocket,
165
+ multilingual=trt_multilingual,
166
+ language=options["language"],
167
+ task=options["task"],
168
+ client_uid=options["uid"],
169
+ model=whisper_tensorrt_path,
170
+ single_model=self.single_model,
171
+ use_py_session=trt_py_session,
172
+ send_last_n_segments=options.get("send_last_n_segments", 10),
173
+ no_speech_thresh=options.get("no_speech_thresh", 0.45),
174
+ clip_audio=options.get("clip_audio", False),
175
+ same_output_threshold=options.get("same_output_threshold", 10),
176
+ )
177
+ logging.info("Running TensorRT backend.")
178
+ except Exception as e:
179
+ logging.error(f"TensorRT-LLM not supported: {e}")
180
+ self.client_uid = options["uid"]
181
+ websocket.send(json.dumps({
182
+ "uid": self.client_uid,
183
+ "status": "WARNING",
184
+ "message": "TensorRT-LLM not supported on Server yet. "
185
+ "Reverting to available backend: 'faster_whisper'"
186
+ }))
187
+ self.backend = BackendType.FASTER_WHISPER
188
+
189
+ if self.backend.is_openvino():
190
+ try:
191
+ from whisper_live.backend.openvino_backend import ServeClientOpenVINO
192
+ client = ServeClientOpenVINO(
193
+ websocket,
194
+ language=options["language"],
195
+ task=options["task"],
196
+ client_uid=options["uid"],
197
+ model=options["model"],
198
+ single_model=self.single_model,
199
+ send_last_n_segments=options.get("send_last_n_segments", 10),
200
+ no_speech_thresh=options.get("no_speech_thresh", 0.45),
201
+ clip_audio=options.get("clip_audio", False),
202
+ same_output_threshold=options.get("same_output_threshold", 10),
203
+ )
204
+ logging.info("Running OpenVINO backend.")
205
+ except Exception as e:
206
+ logging.error(f"OpenVINO not supported: {e}")
207
+ self.backend = BackendType.FASTER_WHISPER
208
+ self.client_uid = options["uid"]
209
+ websocket.send(json.dumps({
210
+ "uid": self.client_uid,
211
+ "status": "WARNING",
212
+ "message": "OpenVINO not supported on Server yet. "
213
+ "Reverting to available backend: 'faster_whisper'"
214
+ }))
215
+
216
+ try:
217
+ if self.backend.is_faster_whisper():
218
+ from whisper_live.backend.faster_whisper_backend import ServeClientFasterWhisper
219
+ if faster_whisper_custom_model_path is not None and os.path.exists(faster_whisper_custom_model_path):
220
+ logging.info(f"Using custom model {faster_whisper_custom_model_path}")
221
+ options["model"] = faster_whisper_custom_model_path
222
+ client = ServeClientFasterWhisper(
223
+ websocket,
224
+ language=options["language"],
225
+ task=options["task"],
226
+ client_uid=options["uid"],
227
+ model=options["model"],
228
+ initial_prompt=options.get("initial_prompt"),
229
+ vad_parameters=options.get("vad_parameters"),
230
+ use_vad=self.use_vad,
231
+ single_model=self.single_model,
232
+ send_last_n_segments=options.get("send_last_n_segments", 10),
233
+ no_speech_thresh=options.get("no_speech_thresh", 0.45),
234
+ clip_audio=options.get("clip_audio", False),
235
+ same_output_threshold=options.get("same_output_threshold", 10),
236
+ )
237
+
238
+ logging.info("Running faster_whisper backend.")
239
+ except Exception as e:
240
+ logging.error(e)
241
+ return
242
+
243
+ if client is None:
244
+ raise ValueError(f"Backend type {self.backend.value} not recognised or not handled.")
245
+
246
+ self.client_manager.add_client(websocket, client)
247
+
248
+ def get_audio_from_websocket(self, websocket):
249
+ """
250
+ Receives audio buffer from websocket and creates a numpy array out of it.
251
+
252
+ Args:
253
+ websocket: The websocket to receive audio from.
254
+
255
+ Returns:
256
+ A numpy array containing the audio.
257
+ """
258
+ frame_data = websocket.recv()
259
+ if frame_data == b"END_OF_AUDIO":
260
+ return False
261
+ return np.frombuffer(frame_data, dtype=np.float32)
262
+
263
+ def handle_new_connection(self, websocket, faster_whisper_custom_model_path,
264
+ whisper_tensorrt_path, trt_multilingual, trt_py_session=False):
265
+ try:
266
+ logging.info("New client connected")
267
+ options = websocket.recv()
268
+ options = json.loads(options)
269
+
270
+ if self.client_manager is None:
271
+ max_clients = options.get('max_clients', 4)
272
+ max_connection_time = options.get('max_connection_time', 600)
273
+ self.client_manager = ClientManager(max_clients, max_connection_time)
274
+
275
+ self.use_vad = options.get('use_vad')
276
+ if self.client_manager.is_server_full(websocket, options):
277
+ websocket.close()
278
+ return False # Indicates that the connection should not continue
279
+
280
+ if self.backend.is_tensorrt():
281
+ self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)
282
+ self.initialize_client(websocket, options, faster_whisper_custom_model_path,
283
+ whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session)
284
+ return True
285
+ except json.JSONDecodeError:
286
+ logging.error("Failed to decode JSON from client")
287
+ return False
288
+ except ConnectionClosed:
289
+ logging.info("Connection closed by client")
290
+ return False
291
+ except Exception as e:
292
+ logging.error(f"Error during new connection initialization: {str(e)}")
293
+ return False
294
+
295
+ def process_audio_frames(self, websocket):
296
+ frame_np = self.get_audio_from_websocket(websocket)
297
+ client = self.client_manager.get_client(websocket)
298
+ if frame_np is False:
299
+ if self.backend.is_tensorrt():
300
+ client.set_eos(True)
301
+ return False
302
+
303
+ if self.backend.is_tensorrt():
304
+ voice_active = self.voice_activity(websocket, frame_np)
305
+ if voice_active:
306
+ self.no_voice_activity_chunks = 0
307
+ client.set_eos(False)
308
+ if self.use_vad and not voice_active:
309
+ return True
310
+
311
+ client.add_frames(frame_np)
312
+ return True
313
+
314
+ def recv_audio(self,
315
+ websocket,
316
+ backend: BackendType = BackendType.FASTER_WHISPER,
317
+ faster_whisper_custom_model_path=None,
318
+ whisper_tensorrt_path=None,
319
+ trt_multilingual=False,
320
+ trt_py_session=False):
321
+ """
322
+ Receive audio chunks from a client in an infinite loop.
323
+
324
+ Continuously receives audio frames from a connected client
325
+ over a WebSocket connection. It processes the audio frames using a
326
+ voice activity detection (VAD) model to determine if they contain speech
327
+ or not. If the audio frame contains speech, it is added to the client's
328
+ audio data for ASR.
329
+ If the maximum number of clients is reached, the method sends a
330
+ "WAIT" status to the client, indicating that they should wait
331
+ until a slot is available.
332
+ If a client's connection exceeds the maximum allowed time, it will
333
+ be disconnected, and the client's resources will be cleaned up.
334
+
335
+ Args:
336
+ websocket (WebSocket): The WebSocket connection for the client.
337
+ backend (str): The backend to run the server with.
338
+ faster_whisper_custom_model_path (str): path to custom faster whisper model.
339
+ whisper_tensorrt_path (str): Required for tensorrt backend.
340
+ trt_multilingual(bool): Only used for tensorrt, True if multilingual model.
341
+
342
+ Raises:
343
+ Exception: If there is an error during the audio frame processing.
344
+ """
345
+ self.backend = backend
346
+ if not self.handle_new_connection(websocket, faster_whisper_custom_model_path,
347
+ whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session):
348
+ return
349
+
350
+ try:
351
+ while not self.client_manager.is_client_timeout(websocket):
352
+ if not self.process_audio_frames(websocket):
353
+ break
354
+ except ConnectionClosed:
355
+ logging.info("Connection closed by client")
356
+ except Exception as e:
357
+ logging.error(f"Unexpected error: {str(e)}")
358
+ finally:
359
+ if self.client_manager.get_client(websocket):
360
+ self.cleanup(websocket)
361
+ websocket.close()
362
+ del websocket
363
+
364
+ def run(self,
365
+ host,
366
+ port=9090,
367
+ backend="tensorrt",
368
+ faster_whisper_custom_model_path=None,
369
+ whisper_tensorrt_path=None,
370
+ trt_multilingual=False,
371
+ trt_py_session=False,
372
+ single_model=False):
373
+ """
374
+ Run the transcription server.
375
+
376
+ Args:
377
+ host (str): The host address to bind the server.
378
+ port (int): The port number to bind the server.
379
+ """
380
+ if faster_whisper_custom_model_path is not None and not os.path.exists(faster_whisper_custom_model_path):
381
+ raise ValueError(f"Custom faster_whisper model '{faster_whisper_custom_model_path}' is not a valid path.")
382
+ if whisper_tensorrt_path is not None and not os.path.exists(whisper_tensorrt_path):
383
+ raise ValueError(f"TensorRT model '{whisper_tensorrt_path}' is not a valid path.")
384
+ if single_model:
385
+ if faster_whisper_custom_model_path or whisper_tensorrt_path:
386
+ logging.info("Custom model option was provided. Switching to single model mode.")
387
+ self.single_model = True
388
+ # TODO: load model initially
389
+ else:
390
+ logging.info("Single model mode currently only works with custom models.")
391
+ if not BackendType.is_valid(backend):
392
+ raise ValueError(f"{backend} is not a valid backend type. Choose backend from {BackendType.valid_types()}")
393
+ with serve(
394
+ functools.partial(
395
+ self.recv_audio,
396
+ backend=BackendType(backend),
397
+ faster_whisper_custom_model_path=faster_whisper_custom_model_path,
398
+ whisper_tensorrt_path=whisper_tensorrt_path,
399
+ trt_multilingual=trt_multilingual,
400
+ trt_py_session=trt_py_session,
401
+ ),
402
+ host,
403
+ port
404
+ ) as server:
405
+ server.serve_forever()
406
+
407
+ def voice_activity(self, websocket, frame_np):
408
+ """
409
+ Evaluates the voice activity in a given audio frame and manages the state of voice activity detection.
410
+
411
+ This method uses the configured voice activity detection (VAD) model to assess whether the given audio frame
412
+ contains speech. If the VAD model detects no voice activity for more than three consecutive frames,
413
+ it sets an end-of-speech (EOS) flag for the associated client. This method aims to efficiently manage
414
+ speech detection to improve subsequent processing steps.
415
+
416
+ Args:
417
+ websocket: The websocket associated with the current client. Used to retrieve the client object
418
+ from the client manager for state management.
419
+ frame_np (numpy.ndarray): The audio frame to be analyzed. This should be a NumPy array containing
420
+ the audio data for the current frame.
421
+
422
+ Returns:
423
+ bool: True if voice activity is detected in the current frame, False otherwise. When returning False
424
+ after detecting no voice activity for more than three consecutive frames, it also triggers the
425
+ end-of-speech (EOS) flag for the client.
426
+ """
427
+ if not self.vad_detector(frame_np):
428
+ self.no_voice_activity_chunks += 1
429
+ if self.no_voice_activity_chunks > 3:
430
+ client = self.client_manager.get_client(websocket)
431
+ if not client.eos:
432
+ client.set_eos(True)
433
+ time.sleep(0.1) # Sleep 100m; wait some voice activity.
434
+ return False
435
+ return True
436
+
437
+ def cleanup(self, websocket):
438
+ """
439
+ Cleans up resources associated with a given client's websocket.
440
+
441
+ Args:
442
+ websocket: The websocket associated with the client to be cleaned up.
443
+ """
444
+ if self.client_manager.get_client(websocket):
445
+ self.client_manager.remove_client(websocket)
446
+
whisper_live/transcriber/__init__.py ADDED
File without changes
whisper_live/transcriber/tensorrt_utils.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import logging
16
+ import os
17
+ from collections import defaultdict
18
+ from functools import lru_cache
19
+ from pathlib import Path
20
+ from subprocess import CalledProcessError, run
21
+ from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
22
+
23
+ import kaldialign
24
+ import numpy as np
25
+ import soundfile
26
+ import av
27
+ import wave
28
+ import torch
29
+ import torch.nn.functional as F
30
+ from whisper_live.utils import resample
31
+
32
+
33
+ Pathlike = Union[str, Path]
34
+
35
+ SAMPLE_RATE = 16000
36
+ N_FFT = 400
37
+ HOP_LENGTH = 160
38
+ CHUNK_LENGTH = 30
39
+ N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
40
+
41
+
42
+ def load_audio(file: str, sr: int = 16000):
43
+ """
44
+ Open an audio file, resample it, and read as a mono waveform.
45
+
46
+ Parameters
47
+ ----------
48
+ file: str
49
+ The audio file to open.
50
+
51
+ sr: int
52
+ The sample rate to resample the audio if necessary.
53
+
54
+ Returns
55
+ -------
56
+ A NumPy array containing the audio waveform, in float32 dtype.
57
+ """
58
+ resampled_file = resample(file, sr)
59
+
60
+ with wave.open(resampled_file, "rb") as wav_file:
61
+ num_frames = wav_file.getnframes()
62
+ raw_data = wav_file.readframes(num_frames)
63
+
64
+ audio_data = np.frombuffer(raw_data, dtype=np.int16)
65
+
66
+ audio_data = audio_data.astype(np.float32) / 32768.0
67
+
68
+ return audio_data
69
+
70
+
71
+ def load_audio_wav_format(wav_path):
72
+ # make sure audio in .wav format
73
+ assert wav_path.endswith(
74
+ '.wav'), f"Only support .wav format, but got {wav_path}"
75
+ waveform, sample_rate = soundfile.read(wav_path)
76
+ assert sample_rate == 16000, f"Only support 16k sample rate, but got {sample_rate}"
77
+ return waveform, sample_rate
78
+
79
+
80
+ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
81
+ """
82
+ Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
83
+ """
84
+ if torch.is_tensor(array):
85
+ if array.shape[axis] > length:
86
+ array = array.index_select(dim=axis,
87
+ index=torch.arange(length,
88
+ device=array.device))
89
+
90
+ if array.shape[axis] < length:
91
+ pad_widths = [(0, 0)] * array.ndim
92
+ pad_widths[axis] = (0, length - array.shape[axis])
93
+ array = F.pad(array,
94
+ [pad for sizes in pad_widths[::-1] for pad in sizes])
95
+ else:
96
+ if array.shape[axis] > length:
97
+ array = array.take(indices=range(length), axis=axis)
98
+
99
+ if array.shape[axis] < length:
100
+ pad_widths = [(0, 0)] * array.ndim
101
+ pad_widths[axis] = (0, length - array.shape[axis])
102
+ array = np.pad(array, pad_widths)
103
+
104
+ return array
105
+
106
+
107
+ @lru_cache(maxsize=None)
108
+ def mel_filters(device,
109
+ n_mels: int,
110
+ mel_filters_dir: str = None) -> torch.Tensor:
111
+ """
112
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
113
+ Allows decoupling librosa dependency; saved using:
114
+
115
+ np.savez_compressed(
116
+ "mel_filters.npz",
117
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
118
+ )
119
+ """
120
+ assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
121
+ if mel_filters_dir is None:
122
+ mel_filters_path = os.path.join(os.path.dirname(__file__), "assets",
123
+ "mel_filters.npz")
124
+ else:
125
+ mel_filters_path = os.path.join(mel_filters_dir, "mel_filters.npz")
126
+ with np.load(mel_filters_path) as f:
127
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
128
+
129
+
130
+ def log_mel_spectrogram(
131
+ audio: Union[str, np.ndarray, torch.Tensor],
132
+ n_mels: int,
133
+ padding: int = 0,
134
+ device: Optional[Union[str, torch.device]] = None,
135
+ return_duration: bool = False,
136
+ mel_filters_dir: str = None,
137
+ ):
138
+ """
139
+ Compute the log-Mel spectrogram of
140
+
141
+ Parameters
142
+ ----------
143
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
144
+ The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
145
+
146
+ n_mels: int
147
+ The number of Mel-frequency filters, only 80 and 128 are supported
148
+
149
+ padding: int
150
+ Number of zero samples to pad to the right
151
+
152
+ device: Optional[Union[str, torch.device]]
153
+ If given, the audio tensor is moved to this device before STFT
154
+
155
+ Returns
156
+ -------
157
+ torch.Tensor, shape = (80 or 128, n_frames)
158
+ A Tensor that contains the Mel spectrogram
159
+ """
160
+ if not torch.is_tensor(audio):
161
+ if isinstance(audio, str):
162
+ if audio.endswith('.wav'):
163
+ audio, _ = load_audio_wav_format(audio)
164
+ else:
165
+ audio = load_audio(audio)
166
+ assert isinstance(audio,
167
+ np.ndarray), f"Unsupported audio type: {type(audio)}"
168
+ duration = audio.shape[-1] / SAMPLE_RATE
169
+ audio = pad_or_trim(audio, N_SAMPLES)
170
+ audio = audio.astype(np.float32)
171
+ audio = torch.from_numpy(audio)
172
+
173
+ if device is not None:
174
+ audio = audio.to(device)
175
+ if padding > 0:
176
+ audio = F.pad(audio, (0, padding))
177
+ window = torch.hann_window(N_FFT).to(audio.device)
178
+ stft = torch.stft(audio,
179
+ N_FFT,
180
+ HOP_LENGTH,
181
+ window=window,
182
+ return_complex=True)
183
+ magnitudes = stft[..., :-1].abs()**2
184
+
185
+ filters = mel_filters(audio.device, n_mels, mel_filters_dir)
186
+ mel_spec = filters @ magnitudes
187
+
188
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
189
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
190
+ log_spec = (log_spec + 4.0) / 4.0
191
+ if return_duration:
192
+ return log_spec, duration
193
+ else:
194
+ return log_spec
195
+
196
+
197
+ def store_transcripts(filename: Pathlike, texts: Iterable[Tuple[str, str,
198
+ str]]) -> None:
199
+ """Save predicted results and reference transcripts to a file.
200
+ https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
201
+ Args:
202
+ filename:
203
+ File to save the results to.
204
+ texts:
205
+ An iterable of tuples. The first element is the cur_id, the second is
206
+ the reference transcript and the third element is the predicted result.
207
+ Returns:
208
+ Return None.
209
+ """
210
+ with open(filename, "w") as f:
211
+ for cut_id, ref, hyp in texts:
212
+ print(f"{cut_id}:\tref={ref}", file=f)
213
+ print(f"{cut_id}:\thyp={hyp}", file=f)
214
+
215
+
216
+ def write_error_stats( # noqa: C901
217
+ f: TextIO,
218
+ test_set_name: str,
219
+ results: List[Tuple[str, str]],
220
+ enable_log: bool = True,
221
+ ) -> float:
222
+ """Write statistics based on predicted results and reference transcripts.
223
+ https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
224
+ It will write the following to the given file:
225
+
226
+ - WER
227
+ - number of insertions, deletions, substitutions, corrects and total
228
+ reference words. For example::
229
+
230
+ Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
231
+ reference words (2337 correct)
232
+
233
+ - The difference between the reference transcript and predicted result.
234
+ An instance is given below::
235
+
236
+ THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
237
+
238
+ The above example shows that the reference word is `EDISON`,
239
+ but it is predicted to `ADDISON` (a substitution error).
240
+
241
+ Another example is::
242
+
243
+ FOR THE FIRST DAY (SIR->*) I THINK
244
+
245
+ The reference word `SIR` is missing in the predicted
246
+ results (a deletion error).
247
+ results:
248
+ An iterable of tuples. The first element is the cur_id, the second is
249
+ the reference transcript and the third element is the predicted result.
250
+ enable_log:
251
+ If True, also print detailed WER to the console.
252
+ Otherwise, it is written only to the given file.
253
+ Returns:
254
+ Return None.
255
+ """
256
+ subs: Dict[Tuple[str, str], int] = defaultdict(int)
257
+ ins: Dict[str, int] = defaultdict(int)
258
+ dels: Dict[str, int] = defaultdict(int)
259
+
260
+ # `words` stores counts per word, as follows:
261
+ # corr, ref_sub, hyp_sub, ins, dels
262
+ words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
263
+ num_corr = 0
264
+ ERR = "*"
265
+ for cut_id, ref, hyp in results:
266
+ ali = kaldialign.align(ref, hyp, ERR)
267
+ for ref_word, hyp_word in ali:
268
+ if ref_word == ERR:
269
+ ins[hyp_word] += 1
270
+ words[hyp_word][3] += 1
271
+ elif hyp_word == ERR:
272
+ dels[ref_word] += 1
273
+ words[ref_word][4] += 1
274
+ elif hyp_word != ref_word:
275
+ subs[(ref_word, hyp_word)] += 1
276
+ words[ref_word][1] += 1
277
+ words[hyp_word][2] += 1
278
+ else:
279
+ words[ref_word][0] += 1
280
+ num_corr += 1
281
+ ref_len = sum([len(r) for _, r, _ in results])
282
+ sub_errs = sum(subs.values())
283
+ ins_errs = sum(ins.values())
284
+ del_errs = sum(dels.values())
285
+ tot_errs = sub_errs + ins_errs + del_errs
286
+ tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
287
+
288
+ if enable_log:
289
+ logging.info(f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
290
+ f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
291
+ f"{del_errs} del, {sub_errs} sub ]")
292
+
293
+ print(f"%WER = {tot_err_rate}", file=f)
294
+ print(
295
+ f"Errors: {ins_errs} insertions, {del_errs} deletions, "
296
+ f"{sub_errs} substitutions, over {ref_len} reference "
297
+ f"words ({num_corr} correct)",
298
+ file=f,
299
+ )
300
+ print(
301
+ "Search below for sections starting with PER-UTT DETAILS:, "
302
+ "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
303
+ file=f,
304
+ )
305
+
306
+ print("", file=f)
307
+ print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
308
+ for cut_id, ref, hyp in results:
309
+ ali = kaldialign.align(ref, hyp, ERR)
310
+ combine_successive_errors = True
311
+ if combine_successive_errors:
312
+ ali = [[[x], [y]] for x, y in ali]
313
+ for i in range(len(ali) - 1):
314
+ if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
315
+ ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
316
+ ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
317
+ ali[i] = [[], []]
318
+ ali = [[
319
+ list(filter(lambda a: a != ERR, x)),
320
+ list(filter(lambda a: a != ERR, y)),
321
+ ] for x, y in ali]
322
+ ali = list(filter(lambda x: x != [[], []], ali))
323
+ ali = [[
324
+ ERR if x == [] else " ".join(x),
325
+ ERR if y == [] else " ".join(y),
326
+ ] for x, y in ali]
327
+
328
+ print(
329
+ f"{cut_id}:\t" + " ".join((ref_word if ref_word == hyp_word else
330
+ f"({ref_word}->{hyp_word})"
331
+ for ref_word, hyp_word in ali)),
332
+ file=f,
333
+ )
334
+
335
+ print("", file=f)
336
+ print("SUBSTITUTIONS: count ref -> hyp", file=f)
337
+
338
+ for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()],
339
+ reverse=True):
340
+ print(f"{count} {ref} -> {hyp}", file=f)
341
+
342
+ print("", file=f)
343
+ print("DELETIONS: count ref", file=f)
344
+ for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
345
+ print(f"{count} {ref}", file=f)
346
+
347
+ print("", file=f)
348
+ print("INSERTIONS: count hyp", file=f)
349
+ for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
350
+ print(f"{count} {hyp}", file=f)
351
+
352
+ print("", file=f)
353
+ print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp",
354
+ file=f)
355
+ for _, word, counts in sorted([(sum(v[1:]), k, v)
356
+ for k, v in words.items()],
357
+ reverse=True):
358
+ (corr, ref_sub, hyp_sub, ins, dels) = counts
359
+ tot_errs = ref_sub + hyp_sub + ins + dels
360
+ ref_count = corr + ref_sub + dels
361
+ hyp_count = corr + hyp_sub + ins
362
+
363
+ print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
364
+ return float(tot_err_rate)
whisper_live/transcriber/transcriber_faster_whisper.py ADDED
@@ -0,0 +1,1889 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # original https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py
2
+
3
+ import itertools
4
+ import json
5
+ import logging
6
+ import os
7
+ import zlib
8
+
9
+ from dataclasses import asdict, dataclass
10
+ from inspect import signature
11
+ from math import ceil
12
+ from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
13
+ from warnings import warn
14
+
15
+ import ctranslate2
16
+ import numpy as np
17
+ import tokenizers
18
+
19
+ from tqdm import tqdm
20
+
21
+ from faster_whisper.audio import decode_audio, pad_or_trim
22
+ from faster_whisper.feature_extractor import FeatureExtractor
23
+ from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
24
+ from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
25
+ from faster_whisper.vad import (
26
+ SpeechTimestampsMap,
27
+ VadOptions,
28
+ collect_chunks,
29
+ get_speech_timestamps,
30
+ merge_segments,
31
+ )
32
+
33
+
34
+ @dataclass
35
+ class Word:
36
+ start: float
37
+ end: float
38
+ word: str
39
+ probability: float
40
+
41
+ def _asdict(self):
42
+ warn(
43
+ "Word._asdict() method is deprecated, use dataclasses.asdict(Word) instead",
44
+ DeprecationWarning,
45
+ 2,
46
+ )
47
+ return asdict(self)
48
+
49
+
50
+ @dataclass
51
+ class Segment:
52
+ id: int
53
+ seek: int
54
+ start: float
55
+ end: float
56
+ text: str
57
+ tokens: List[int]
58
+ avg_logprob: float
59
+ compression_ratio: float
60
+ no_speech_prob: float
61
+ words: Optional[List[Word]]
62
+ temperature: Optional[float]
63
+
64
+ def _asdict(self):
65
+ warn(
66
+ "Segment._asdict() method is deprecated, use dataclasses.asdict(Segment) instead",
67
+ DeprecationWarning,
68
+ 2,
69
+ )
70
+ return asdict(self)
71
+
72
+
73
+ @dataclass
74
+ class TranscriptionOptions:
75
+ beam_size: int
76
+ best_of: int
77
+ patience: float
78
+ length_penalty: float
79
+ repetition_penalty: float
80
+ no_repeat_ngram_size: int
81
+ log_prob_threshold: Optional[float]
82
+ no_speech_threshold: Optional[float]
83
+ compression_ratio_threshold: Optional[float]
84
+ condition_on_previous_text: bool
85
+ prompt_reset_on_temperature: float
86
+ temperatures: List[float]
87
+ initial_prompt: Optional[Union[str, Iterable[int]]]
88
+ prefix: Optional[str]
89
+ suppress_blank: bool
90
+ suppress_tokens: Optional[List[int]]
91
+ without_timestamps: bool
92
+ max_initial_timestamp: float
93
+ word_timestamps: bool
94
+ prepend_punctuations: str
95
+ append_punctuations: str
96
+ multilingual: bool
97
+ max_new_tokens: Optional[int]
98
+ clip_timestamps: Union[str, List[float]]
99
+ hallucination_silence_threshold: Optional[float]
100
+ hotwords: Optional[str]
101
+
102
+
103
+ @dataclass
104
+ class TranscriptionInfo:
105
+ language: str
106
+ language_probability: float
107
+ duration: float
108
+ duration_after_vad: float
109
+ all_language_probs: Optional[List[Tuple[str, float]]]
110
+ transcription_options: TranscriptionOptions
111
+ vad_options: VadOptions
112
+
113
+
114
+ class BatchedInferencePipeline:
115
+ def __init__(
116
+ self,
117
+ model,
118
+ ):
119
+ self.model: WhisperModel = model
120
+ self.last_speech_timestamp = 0.0
121
+
122
+ def forward(self, features, tokenizer, chunks_metadata, options):
123
+ encoder_output, outputs = self.generate_segment_batched(
124
+ features, tokenizer, options
125
+ )
126
+
127
+ segmented_outputs = []
128
+ segment_sizes = []
129
+ for chunk_metadata, output in zip(chunks_metadata, outputs):
130
+ duration = chunk_metadata["end_time"] - chunk_metadata["start_time"]
131
+ segment_size = int(ceil(duration) * self.model.frames_per_second)
132
+ segment_sizes.append(segment_size)
133
+ (
134
+ subsegments,
135
+ seek,
136
+ single_timestamp_ending,
137
+ ) = self.model._split_segments_by_timestamps(
138
+ tokenizer=tokenizer,
139
+ tokens=output["tokens"],
140
+ time_offset=chunk_metadata["start_time"],
141
+ segment_size=segment_size,
142
+ segment_duration=duration,
143
+ seek=0,
144
+ )
145
+ segmented_outputs.append(
146
+ [
147
+ dict(
148
+ text=tokenizer.decode(subsegment["tokens"]),
149
+ avg_logprob=output["avg_logprob"],
150
+ no_speech_prob=output["no_speech_prob"],
151
+ tokens=subsegment["tokens"],
152
+ start=subsegment["start"],
153
+ end=subsegment["end"],
154
+ compression_ratio=get_compression_ratio(
155
+ tokenizer.decode(subsegment["tokens"])
156
+ ),
157
+ seek=int(
158
+ chunk_metadata["start_time"] * self.model.frames_per_second
159
+ ),
160
+ )
161
+ for subsegment in subsegments
162
+ ]
163
+ )
164
+ if options.word_timestamps:
165
+ self.last_speech_timestamp = self.model.add_word_timestamps(
166
+ segmented_outputs,
167
+ tokenizer,
168
+ encoder_output,
169
+ segment_sizes,
170
+ options.prepend_punctuations,
171
+ options.append_punctuations,
172
+ self.last_speech_timestamp,
173
+ )
174
+
175
+ return segmented_outputs
176
+
177
+ def generate_segment_batched(
178
+ self,
179
+ features: np.ndarray,
180
+ tokenizer: Tokenizer,
181
+ options: TranscriptionOptions,
182
+ ):
183
+ batch_size = features.shape[0]
184
+
185
+ prompt = self.model.get_prompt(
186
+ tokenizer,
187
+ previous_tokens=(
188
+ tokenizer.encode(options.initial_prompt)
189
+ if options.initial_prompt is not None
190
+ else []
191
+ ),
192
+ without_timestamps=options.without_timestamps,
193
+ hotwords=options.hotwords,
194
+ )
195
+
196
+ if options.max_new_tokens is not None:
197
+ max_length = len(prompt) + options.max_new_tokens
198
+ else:
199
+ max_length = self.model.max_length
200
+
201
+ if max_length > self.model.max_length:
202
+ raise ValueError(
203
+ f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
204
+ f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
205
+ f"and `max_new_tokens` is: {max_length}. This exceeds the "
206
+ f"`max_length` of the Whisper model: {self.model.max_length}. "
207
+ "You should either reduce the length of your prompt, or "
208
+ "reduce the value of `max_new_tokens`, "
209
+ f"so that their combined length is less that {self.model.max_length}."
210
+ )
211
+
212
+ encoder_output = self.model.encode(features)
213
+ prompts = [prompt.copy() for _ in range(batch_size)]
214
+
215
+ if options.multilingual:
216
+ language_tokens = [
217
+ tokenizer.tokenizer.token_to_id(segment_langs[0][0])
218
+ for segment_langs in self.model.model.detect_language(encoder_output)
219
+ ]
220
+ language_token_index = prompt.index(tokenizer.language)
221
+
222
+ for i, language_token in enumerate(language_tokens):
223
+ prompts[i][language_token_index] = language_token
224
+
225
+ results = self.model.model.generate(
226
+ encoder_output,
227
+ prompts,
228
+ beam_size=options.beam_size,
229
+ patience=options.patience,
230
+ length_penalty=options.length_penalty,
231
+ max_length=max_length,
232
+ suppress_blank=options.suppress_blank,
233
+ suppress_tokens=options.suppress_tokens,
234
+ return_scores=True,
235
+ return_no_speech_prob=True,
236
+ sampling_temperature=options.temperatures[0],
237
+ repetition_penalty=options.repetition_penalty,
238
+ no_repeat_ngram_size=options.no_repeat_ngram_size,
239
+ )
240
+
241
+ output = []
242
+ for result in results:
243
+ # return scores
244
+ seq_len = len(result.sequences_ids[0])
245
+ cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
246
+
247
+ output.append(
248
+ dict(
249
+ avg_logprob=cum_logprob / (seq_len + 1),
250
+ no_speech_prob=result.no_speech_prob,
251
+ tokens=result.sequences_ids[0],
252
+ )
253
+ )
254
+
255
+ return encoder_output, output
256
+
257
+ def transcribe(
258
+ self,
259
+ audio: Union[str, BinaryIO, np.ndarray],
260
+ language: Optional[str] = None,
261
+ task: str = "transcribe",
262
+ log_progress: bool = False,
263
+ beam_size: int = 5,
264
+ best_of: int = 5,
265
+ patience: float = 1,
266
+ length_penalty: float = 1,
267
+ repetition_penalty: float = 1,
268
+ no_repeat_ngram_size: int = 0,
269
+ temperature: Union[float, List[float], Tuple[float, ...]] = [
270
+ 0.0,
271
+ 0.2,
272
+ 0.4,
273
+ 0.6,
274
+ 0.8,
275
+ 1.0,
276
+ ],
277
+ compression_ratio_threshold: Optional[float] = 2.4,
278
+ log_prob_threshold: Optional[float] = -1.0,
279
+ no_speech_threshold: Optional[float] = 0.6,
280
+ condition_on_previous_text: bool = True,
281
+ prompt_reset_on_temperature: float = 0.5,
282
+ initial_prompt: Optional[Union[str, Iterable[int]]] = None,
283
+ prefix: Optional[str] = None,
284
+ suppress_blank: bool = True,
285
+ suppress_tokens: Optional[List[int]] = [-1],
286
+ without_timestamps: bool = True,
287
+ max_initial_timestamp: float = 1.0,
288
+ word_timestamps: bool = False,
289
+ prepend_punctuations: str = "\"'“¿([{-",
290
+ append_punctuations: str = "\"'.。,,!!??::”)]}、",
291
+ multilingual: bool = False,
292
+ vad_filter: bool = True,
293
+ vad_parameters: Optional[Union[dict, VadOptions]] = None,
294
+ max_new_tokens: Optional[int] = None,
295
+ chunk_length: Optional[int] = None,
296
+ clip_timestamps: Optional[List[dict]] = None,
297
+ hallucination_silence_threshold: Optional[float] = None,
298
+ batch_size: int = 8,
299
+ hotwords: Optional[str] = None,
300
+ language_detection_threshold: Optional[float] = 0.5,
301
+ language_detection_segments: int = 1,
302
+ ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
303
+ """transcribe audio in chunks in batched fashion and return with language info.
304
+
305
+ Arguments:
306
+ audio: Path to the input file (or a file-like object), or the audio waveform.
307
+ language: The language spoken in the audio. It should be a language code such
308
+ as "en" or "fr". If not set, the language will be detected in the first 30 seconds
309
+ of audio.
310
+ task: Task to execute (transcribe or translate).
311
+ log_progress: whether to show progress bar or not.
312
+ beam_size: Beam size to use for decoding.
313
+ best_of: Number of candidates when sampling with non-zero temperature.
314
+ patience: Beam search patience factor.
315
+ length_penalty: Exponential length penalty constant.
316
+ repetition_penalty: Penalty applied to the score of previously generated tokens
317
+ (set > 1 to penalize).
318
+ no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
319
+ temperature: Temperature for sampling. If a list or tuple is passed,
320
+ only the first value is used.
321
+ initial_prompt: Optional text string or iterable of token ids to provide as a
322
+ prompt for the each window.
323
+ suppress_blank: Suppress blank outputs at the beginning of the sampling.
324
+ suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
325
+ of symbols as defined in `tokenizer.non_speech_tokens()`.
326
+ without_timestamps: Only sample text tokens.
327
+ word_timestamps: Extract word-level timestamps using the cross-attention pattern
328
+ and dynamic time warping, and include the timestamps for each word in each segment.
329
+ Set as False.
330
+ prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
331
+ with the next word
332
+ append_punctuations: If word_timestamps is True, merge these punctuation symbols
333
+ with the previous word
334
+ multilingual: Perform language detection on every segment.
335
+ vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
336
+ without speech. This step is using the Silero VAD model
337
+ https://github.com/snakers4/silero-vad.
338
+ vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
339
+ parameters and default values in the class `VadOptions`).
340
+ max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
341
+ the maximum will be set by the default max_length.
342
+ chunk_length: The length of audio segments. If it is not None, it will overwrite the
343
+ default chunk_length of the FeatureExtractor.
344
+ clip_timestamps: Optionally provide list of dictionaries each containing "start" and
345
+ "end" keys that specify the start and end of the voiced region within
346
+ `chunk_length` boundary. vad_filter will be ignored if clip_timestamps is used.
347
+ batch_size: the maximum number of parallel requests to model for decoding.
348
+ hotwords:
349
+ Hotwords/hint phrases to the model. Has no effect if prefix is not None.
350
+ language_detection_threshold: If the maximum probability of the language tokens is
351
+ higher than this value, the language is detected.
352
+ language_detection_segments: Number of segments to consider for the language detection.
353
+
354
+ Unused Arguments
355
+ compression_ratio_threshold: If the gzip compression ratio is above this value,
356
+ treat as failed.
357
+ log_prob_threshold: If the average log probability over sampled tokens is
358
+ below this value, treat as failed.
359
+ no_speech_threshold: If the no_speech probability is higher than this value AND
360
+ the average log probability over sampled tokens is below `log_prob_threshold`,
361
+ consider the segment as silent.
362
+ condition_on_previous_text: If True, the previous output of the model is provided
363
+ as a prompt for the next window; disabling may make the text inconsistent across
364
+ windows, but the model becomes less prone to getting stuck in a failure loop,
365
+ such as repetition looping or timestamps going out of sync. Set as False
366
+ prompt_reset_on_temperature: Resets prompt if temperature is above this value.
367
+ Arg has effect only if condition_on_previous_text is True. Set at 0.5
368
+ prefix: Optional text to provide as a prefix at the beginning of each window.
369
+ max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
370
+ hallucination_silence_threshold: Optional[float]
371
+ When word_timestamps is True, skip silent periods longer than this threshold
372
+ (in seconds) when a possible hallucination is detected. set as None.
373
+ Returns:
374
+ A tuple with:
375
+
376
+ - a generator over transcribed segments
377
+ - an instance of TranscriptionInfo
378
+ """
379
+
380
+ sampling_rate = self.model.feature_extractor.sampling_rate
381
+
382
+ if multilingual and not self.model.model.is_multilingual:
383
+ self.model.logger.warning(
384
+ "The current model is English-only but the multilingual parameter is set to"
385
+ "True; setting to False instead."
386
+ )
387
+ multilingual = False
388
+
389
+ if not isinstance(audio, np.ndarray):
390
+ audio = decode_audio(audio, sampling_rate=sampling_rate)
391
+ duration = audio.shape[0] / sampling_rate
392
+
393
+ chunk_length = chunk_length or self.model.feature_extractor.chunk_length
394
+ # if no segment split is provided, use vad_model and generate segments
395
+ if not clip_timestamps:
396
+ if vad_filter:
397
+ if vad_parameters is None:
398
+ vad_parameters = VadOptions(
399
+ max_speech_duration_s=chunk_length,
400
+ min_silence_duration_ms=160,
401
+ )
402
+ elif isinstance(vad_parameters, dict):
403
+ if "max_speech_duration_s" in vad_parameters.keys():
404
+ vad_parameters.pop("max_speech_duration_s")
405
+
406
+ vad_parameters = VadOptions(
407
+ **vad_parameters, max_speech_duration_s=chunk_length
408
+ )
409
+
410
+ active_segments = get_speech_timestamps(audio, vad_parameters)
411
+ clip_timestamps = merge_segments(active_segments, vad_parameters)
412
+ # run the audio if it is less than 30 sec even without clip_timestamps
413
+ elif duration < chunk_length:
414
+ clip_timestamps = [{"start": 0, "end": audio.shape[0]}]
415
+ else:
416
+ raise RuntimeError(
417
+ "No clip timestamps found. "
418
+ "Set 'vad_filter' to True or provide 'clip_timestamps'."
419
+ )
420
+
421
+ duration_after_vad = (
422
+ sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
423
+ / sampling_rate
424
+ )
425
+
426
+ audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
427
+ features = (
428
+ [self.model.feature_extractor(chunk)[..., :-1] for chunk in audio_chunks]
429
+ if duration_after_vad
430
+ else []
431
+ )
432
+
433
+ all_language_probs = None
434
+ # detecting the language if not provided
435
+ if language is None:
436
+ if not self.model.model.is_multilingual:
437
+ language = "en"
438
+ language_probability = 1
439
+ else:
440
+ (
441
+ language,
442
+ language_probability,
443
+ all_language_probs,
444
+ ) = self.model.detect_language(
445
+ features=np.concatenate(
446
+ features
447
+ + [
448
+ np.full((self.model.model.n_mels, 1), -1.5, dtype="float32")
449
+ ],
450
+ axis=1,
451
+ ), # add a dummy feature to account for empty audio
452
+ language_detection_segments=language_detection_segments,
453
+ language_detection_threshold=language_detection_threshold,
454
+ )
455
+
456
+ self.model.logger.info(
457
+ "Detected language '%s' with probability %.2f",
458
+ language,
459
+ language_probability,
460
+ )
461
+ else:
462
+ if not self.model.model.is_multilingual and language != "en":
463
+ self.model.logger.warning(
464
+ "The current model is English-only but the language parameter is set to '%s'; "
465
+ "using 'en' instead." % language
466
+ )
467
+ language = "en"
468
+
469
+ language_probability = 1
470
+
471
+ tokenizer = Tokenizer(
472
+ self.model.hf_tokenizer,
473
+ self.model.model.is_multilingual,
474
+ task=task,
475
+ language=language,
476
+ )
477
+
478
+ features = (
479
+ np.stack([pad_or_trim(feature) for feature in features]) if features else []
480
+ )
481
+
482
+ options = TranscriptionOptions(
483
+ beam_size=beam_size,
484
+ best_of=best_of,
485
+ patience=patience,
486
+ length_penalty=length_penalty,
487
+ repetition_penalty=repetition_penalty,
488
+ no_repeat_ngram_size=no_repeat_ngram_size,
489
+ log_prob_threshold=log_prob_threshold,
490
+ no_speech_threshold=no_speech_threshold,
491
+ compression_ratio_threshold=compression_ratio_threshold,
492
+ temperatures=(
493
+ temperature[:1]
494
+ if isinstance(temperature, (list, tuple))
495
+ else [temperature]
496
+ ),
497
+ initial_prompt=initial_prompt,
498
+ prefix=prefix,
499
+ suppress_blank=suppress_blank,
500
+ suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
501
+ prepend_punctuations=prepend_punctuations,
502
+ append_punctuations=append_punctuations,
503
+ max_new_tokens=max_new_tokens,
504
+ hotwords=hotwords,
505
+ word_timestamps=word_timestamps,
506
+ hallucination_silence_threshold=None,
507
+ condition_on_previous_text=False,
508
+ clip_timestamps=clip_timestamps,
509
+ prompt_reset_on_temperature=0.5,
510
+ multilingual=multilingual,
511
+ without_timestamps=without_timestamps,
512
+ max_initial_timestamp=0.0,
513
+ )
514
+
515
+ info = TranscriptionInfo(
516
+ language=language,
517
+ language_probability=language_probability,
518
+ duration=duration,
519
+ duration_after_vad=duration_after_vad,
520
+ transcription_options=options,
521
+ vad_options=vad_parameters,
522
+ all_language_probs=all_language_probs,
523
+ )
524
+
525
+ segments = self._batched_segments_generator(
526
+ features,
527
+ tokenizer,
528
+ chunks_metadata,
529
+ batch_size,
530
+ options,
531
+ log_progress,
532
+ )
533
+
534
+ return segments, info
535
+
536
+ def _batched_segments_generator(
537
+ self, features, tokenizer, chunks_metadata, batch_size, options, log_progress
538
+ ):
539
+ pbar = tqdm(total=len(features), disable=not log_progress, position=0)
540
+ seg_idx = 0
541
+ for i in range(0, len(features), batch_size):
542
+ results = self.forward(
543
+ features[i : i + batch_size],
544
+ tokenizer,
545
+ chunks_metadata[i : i + batch_size],
546
+ options,
547
+ )
548
+
549
+ for result in results:
550
+ for segment in result:
551
+ seg_idx += 1
552
+ yield Segment(
553
+ seek=segment["seek"],
554
+ id=seg_idx,
555
+ text=segment["text"],
556
+ start=round(segment["start"], 3),
557
+ end=round(segment["end"], 3),
558
+ words=(
559
+ None
560
+ if not options.word_timestamps
561
+ else [Word(**word) for word in segment["words"]]
562
+ ),
563
+ tokens=segment["tokens"],
564
+ avg_logprob=segment["avg_logprob"],
565
+ no_speech_prob=segment["no_speech_prob"],
566
+ compression_ratio=segment["compression_ratio"],
567
+ temperature=options.temperatures[0],
568
+ )
569
+
570
+ pbar.update(1)
571
+
572
+ pbar.close()
573
+ self.last_speech_timestamp = 0.0
574
+
575
+
576
+ class WhisperModel:
577
+ def __init__(
578
+ self,
579
+ model_size_or_path: str,
580
+ device: str = "auto",
581
+ device_index: Union[int, List[int]] = 0,
582
+ compute_type: str = "default",
583
+ cpu_threads: int = 0,
584
+ num_workers: int = 1,
585
+ download_root: Optional[str] = None,
586
+ local_files_only: bool = False,
587
+ files: dict = None,
588
+ **model_kwargs,
589
+ ):
590
+ """Initializes the Whisper model.
591
+
592
+ Args:
593
+ model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
594
+ small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1,
595
+ large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo),
596
+ a path to a converted model directory, or a CTranslate2-converted Whisper model ID from
597
+ the HF Hub. When a size or a model ID is configured, the converted model is downloaded
598
+ from the Hugging Face Hub.
599
+ device: Device to use for computation ("cpu", "cuda", "auto").
600
+ device_index: Device ID to use.
601
+ The model can also be loaded on multiple GPUs by passing a list of IDs
602
+ (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
603
+ when transcribe() is called from multiple Python threads (see also num_workers).
604
+ compute_type: Type to use for computation.
605
+ See https://opennmt.net/CTranslate2/quantization.html.
606
+ cpu_threads: Number of threads to use when running on CPU (4 by default).
607
+ A non zero value overrides the OMP_NUM_THREADS environment variable.
608
+ num_workers: When transcribe() is called from multiple Python threads,
609
+ having multiple workers enables true parallelism when running the model
610
+ (concurrent calls to self.model.generate() will run in parallel).
611
+ This can improve the global throughput at the cost of increased memory usage.
612
+ download_root: Directory where the models should be saved. If not set, the models
613
+ are saved in the standard Hugging Face cache directory.
614
+ local_files_only: If True, avoid downloading the file and return the path to the
615
+ local cached file if it exists.
616
+ files: Load model files from the memory. This argument is a dictionary mapping file names
617
+ to file contents as file-like or bytes objects. If this is set, model_path acts as an
618
+ identifier for this model.
619
+ """
620
+ self.logger = get_logger()
621
+
622
+ tokenizer_bytes, preprocessor_bytes = None, None
623
+ if files:
624
+ model_path = model_size_or_path
625
+ tokenizer_bytes = files.pop("tokenizer.json", None)
626
+ preprocessor_bytes = files.pop("preprocessor_config.json", None)
627
+ elif os.path.isdir(model_size_or_path):
628
+ model_path = model_size_or_path
629
+ else:
630
+ model_path = download_model(
631
+ model_size_or_path,
632
+ local_files_only=local_files_only,
633
+ cache_dir=download_root,
634
+ )
635
+
636
+ self.model = ctranslate2.models.Whisper(
637
+ model_path,
638
+ device=device,
639
+ device_index=device_index,
640
+ compute_type=compute_type,
641
+ intra_threads=cpu_threads,
642
+ inter_threads=num_workers,
643
+ files=files,
644
+ **model_kwargs,
645
+ )
646
+
647
+ tokenizer_file = os.path.join(model_path, "tokenizer.json")
648
+ if tokenizer_bytes:
649
+ self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes)
650
+ elif os.path.isfile(tokenizer_file):
651
+ self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
652
+ else:
653
+ self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
654
+ "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
655
+ )
656
+ self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes)
657
+ self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
658
+ self.input_stride = 2
659
+ self.num_samples_per_token = (
660
+ self.feature_extractor.hop_length * self.input_stride
661
+ )
662
+ self.frames_per_second = (
663
+ self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
664
+ )
665
+ self.tokens_per_second = (
666
+ self.feature_extractor.sampling_rate // self.num_samples_per_token
667
+ )
668
+ self.time_precision = 0.02
669
+ self.max_length = 448
670
+
671
+ @property
672
+ def supported_languages(self) -> List[str]:
673
+ """The languages supported by the model."""
674
+ return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
675
+
676
+ def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict:
677
+ config = {}
678
+ try:
679
+ config_path = os.path.join(model_path, "preprocessor_config.json")
680
+ if preprocessor_bytes:
681
+ config = json.loads(preprocessor_bytes)
682
+ elif os.path.isfile(config_path):
683
+ with open(config_path, "r", encoding="utf-8") as file:
684
+ config = json.load(file)
685
+ else:
686
+ return config
687
+ valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
688
+ return {k: v for k, v in config.items() if k in valid_keys}
689
+ except json.JSONDecodeError as e:
690
+ self.logger.warning("Could not load preprocessor config: %s", e)
691
+
692
+ return config
693
+
694
+ def transcribe(
695
+ self,
696
+ audio: Union[str, BinaryIO, np.ndarray],
697
+ language: Optional[str] = None,
698
+ task: str = "transcribe",
699
+ log_progress: bool = False,
700
+ beam_size: int = 5,
701
+ best_of: int = 5,
702
+ patience: float = 1,
703
+ length_penalty: float = 1,
704
+ repetition_penalty: float = 1,
705
+ no_repeat_ngram_size: int = 0,
706
+ temperature: Union[float, List[float], Tuple[float, ...]] = [
707
+ 0.0,
708
+ 0.2,
709
+ 0.4,
710
+ 0.6,
711
+ 0.8,
712
+ 1.0,
713
+ ],
714
+ compression_ratio_threshold: Optional[float] = 2.4,
715
+ log_prob_threshold: Optional[float] = -1.0,
716
+ no_speech_threshold: Optional[float] = 0.6,
717
+ condition_on_previous_text: bool = True,
718
+ prompt_reset_on_temperature: float = 0.5,
719
+ initial_prompt: Optional[Union[str, Iterable[int]]] = None,
720
+ prefix: Optional[str] = None,
721
+ suppress_blank: bool = True,
722
+ suppress_tokens: Optional[List[int]] = [-1],
723
+ without_timestamps: bool = False,
724
+ max_initial_timestamp: float = 1.0,
725
+ word_timestamps: bool = False,
726
+ prepend_punctuations: str = "\"'“¿([{-",
727
+ append_punctuations: str = "\"'.。,,!!??::”)]}、",
728
+ multilingual: bool = False,
729
+ vad_filter: bool = False,
730
+ vad_parameters: Optional[Union[dict, VadOptions]] = None,
731
+ max_new_tokens: Optional[int] = None,
732
+ chunk_length: Optional[int] = None,
733
+ clip_timestamps: Union[str, List[float]] = "0",
734
+ hallucination_silence_threshold: Optional[float] = None,
735
+ hotwords: Optional[str] = None,
736
+ language_detection_threshold: Optional[float] = 0.5,
737
+ language_detection_segments: int = 1,
738
+ ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
739
+ """Transcribes an input file.
740
+
741
+ Arguments:
742
+ audio: Path to the input file (or a file-like object), or the audio waveform.
743
+ language: The language spoken in the audio. It should be a language code such
744
+ as "en" or "fr". If not set, the language will be detected in the first 30 seconds
745
+ of audio.
746
+ task: Task to execute (transcribe or translate).
747
+ log_progress: whether to show progress bar or not.
748
+ beam_size: Beam size to use for decoding.
749
+ best_of: Number of candidates when sampling with non-zero temperature.
750
+ patience: Beam search patience factor.
751
+ length_penalty: Exponential length penalty constant.
752
+ repetition_penalty: Penalty applied to the score of previously generated tokens
753
+ (set > 1 to penalize).
754
+ no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
755
+ temperature: Temperature for sampling. It can be a tuple of temperatures,
756
+ which will be successively used upon failures according to either
757
+ `compression_ratio_threshold` or `log_prob_threshold`.
758
+ compression_ratio_threshold: If the gzip compression ratio is above this value,
759
+ treat as failed.
760
+ log_prob_threshold: If the average log probability over sampled tokens is
761
+ below this value, treat as failed.
762
+ no_speech_threshold: If the no_speech probability is higher than this value AND
763
+ the average log probability over sampled tokens is below `log_prob_threshold`,
764
+ consider the segment as silent.
765
+ condition_on_previous_text: If True, the previous output of the model is provided
766
+ as a prompt for the next window; disabling may make the text inconsistent across
767
+ windows, but the model becomes less prone to getting stuck in a failure loop,
768
+ such as repetition looping or timestamps going out of sync.
769
+ prompt_reset_on_temperature: Resets prompt if temperature is above this value.
770
+ Arg has effect only if condition_on_previous_text is True.
771
+ initial_prompt: Optional text string or iterable of token ids to provide as a
772
+ prompt for the first window.
773
+ prefix: Optional text to provide as a prefix for the first window.
774
+ suppress_blank: Suppress blank outputs at the beginning of the sampling.
775
+ suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
776
+ of symbols as defined in `tokenizer.non_speech_tokens()`.
777
+ without_timestamps: Only sample text tokens.
778
+ max_initial_timestamp: The initial timestamp cannot be later than this.
779
+ word_timestamps: Extract word-level timestamps using the cross-attention pattern
780
+ and dynamic time warping, and include the timestamps for each word in each segment.
781
+ prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
782
+ with the next word
783
+ append_punctuations: If word_timestamps is True, merge these punctuation symbols
784
+ with the previous word
785
+ multilingual: Perform language detection on every segment.
786
+ vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
787
+ without speech. This step is using the Silero VAD model
788
+ https://github.com/snakers4/silero-vad.
789
+ vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
790
+ parameters and default values in the class `VadOptions`).
791
+ max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
792
+ the maximum will be set by the default max_length.
793
+ chunk_length: The length of audio segments. If it is not None, it will overwrite the
794
+ default chunk_length of the FeatureExtractor.
795
+ clip_timestamps:
796
+ Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
797
+ process. The last end timestamp defaults to the end of the file.
798
+ vad_filter will be ignored if clip_timestamps is used.
799
+ hallucination_silence_threshold:
800
+ When word_timestamps is True, skip silent periods longer than this threshold
801
+ (in seconds) when a possible hallucination is detected
802
+ hotwords:
803
+ Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
804
+ language_detection_threshold: If the maximum probability of the language tokens is higher
805
+ than this value, the language is detected.
806
+ language_detection_segments: Number of segments to consider for the language detection.
807
+ Returns:
808
+ A tuple with:
809
+
810
+ - a generator over transcribed segments
811
+ - an instance of TranscriptionInfo
812
+ """
813
+ sampling_rate = self.feature_extractor.sampling_rate
814
+
815
+ if multilingual and not self.model.is_multilingual:
816
+ self.logger.warning(
817
+ "The current model is English-only but the multilingual parameter is set to"
818
+ "True; setting to False instead."
819
+ )
820
+ multilingual = False
821
+
822
+ if not isinstance(audio, np.ndarray):
823
+ audio = decode_audio(audio, sampling_rate=sampling_rate)
824
+
825
+ duration = audio.shape[0] / sampling_rate
826
+ duration_after_vad = duration
827
+
828
+ self.logger.info(
829
+ "Processing audio with duration %s", format_timestamp(duration)
830
+ )
831
+
832
+ if vad_filter and clip_timestamps == "0":
833
+ if vad_parameters is None:
834
+ vad_parameters = VadOptions()
835
+ elif isinstance(vad_parameters, dict):
836
+ vad_parameters = VadOptions(**vad_parameters)
837
+ speech_chunks = get_speech_timestamps(audio, vad_parameters)
838
+ audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
839
+ audio = np.concatenate(audio_chunks, axis=0)
840
+ duration_after_vad = audio.shape[0] / sampling_rate
841
+
842
+ self.logger.info(
843
+ "VAD filter removed %s of audio",
844
+ format_timestamp(duration - duration_after_vad),
845
+ )
846
+
847
+ if self.logger.isEnabledFor(logging.DEBUG):
848
+ self.logger.debug(
849
+ "VAD filter kept the following audio segments: %s",
850
+ ", ".join(
851
+ "[%s -> %s]"
852
+ % (
853
+ format_timestamp(chunk["start"] / sampling_rate),
854
+ format_timestamp(chunk["end"] / sampling_rate),
855
+ )
856
+ for chunk in speech_chunks
857
+ ),
858
+ )
859
+
860
+ else:
861
+ speech_chunks = None
862
+ if audio.shape[0] == 0:
863
+ return None, None
864
+ features = self.feature_extractor(audio, chunk_length=chunk_length)
865
+
866
+ encoder_output = None
867
+ all_language_probs = None
868
+
869
+ # detecting the language if not provided
870
+ if language is None:
871
+ if not self.model.is_multilingual:
872
+ language = "en"
873
+ language_probability = 1
874
+ else:
875
+ start_timestamp = (
876
+ float(clip_timestamps.split(",")[0])
877
+ if isinstance(clip_timestamps, str)
878
+ else clip_timestamps[0]
879
+ )
880
+ content_frames = features.shape[-1] - 1
881
+ seek = (
882
+ int(start_timestamp * self.frames_per_second)
883
+ if start_timestamp * self.frames_per_second < content_frames
884
+ else 0
885
+ )
886
+ (
887
+ language,
888
+ language_probability,
889
+ all_language_probs,
890
+ ) = self.detect_language(
891
+ features=features[..., seek:],
892
+ language_detection_segments=language_detection_segments,
893
+ language_detection_threshold=language_detection_threshold,
894
+ )
895
+
896
+ self.logger.info(
897
+ "Detected language '%s' with probability %.2f",
898
+ language,
899
+ language_probability,
900
+ )
901
+ else:
902
+ if not self.model.is_multilingual and language != "en":
903
+ self.logger.warning(
904
+ "The current model is English-only but the language parameter is set to '%s'; "
905
+ "using 'en' instead." % language
906
+ )
907
+ language = "en"
908
+
909
+ language_probability = 1
910
+
911
+ tokenizer = Tokenizer(
912
+ self.hf_tokenizer,
913
+ self.model.is_multilingual,
914
+ task=task,
915
+ language=language,
916
+ )
917
+
918
+ options = TranscriptionOptions(
919
+ beam_size=beam_size,
920
+ best_of=best_of,
921
+ patience=patience,
922
+ length_penalty=length_penalty,
923
+ repetition_penalty=repetition_penalty,
924
+ no_repeat_ngram_size=no_repeat_ngram_size,
925
+ log_prob_threshold=log_prob_threshold,
926
+ no_speech_threshold=no_speech_threshold,
927
+ compression_ratio_threshold=compression_ratio_threshold,
928
+ condition_on_previous_text=condition_on_previous_text,
929
+ prompt_reset_on_temperature=prompt_reset_on_temperature,
930
+ temperatures=(
931
+ temperature if isinstance(temperature, (list, tuple)) else [temperature]
932
+ ),
933
+ initial_prompt=initial_prompt,
934
+ prefix=prefix,
935
+ suppress_blank=suppress_blank,
936
+ suppress_tokens=(
937
+ get_suppressed_tokens(tokenizer, suppress_tokens)
938
+ if suppress_tokens
939
+ else suppress_tokens
940
+ ),
941
+ without_timestamps=without_timestamps,
942
+ max_initial_timestamp=max_initial_timestamp,
943
+ word_timestamps=word_timestamps,
944
+ prepend_punctuations=prepend_punctuations,
945
+ append_punctuations=append_punctuations,
946
+ multilingual=multilingual,
947
+ max_new_tokens=max_new_tokens,
948
+ clip_timestamps=clip_timestamps,
949
+ hallucination_silence_threshold=hallucination_silence_threshold,
950
+ hotwords=hotwords,
951
+ )
952
+
953
+ segments = self.generate_segments(
954
+ features, tokenizer, options, log_progress, encoder_output
955
+ )
956
+
957
+ if speech_chunks:
958
+ segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
959
+
960
+ info = TranscriptionInfo(
961
+ language=language,
962
+ language_probability=language_probability,
963
+ duration=duration,
964
+ duration_after_vad=duration_after_vad,
965
+ transcription_options=options,
966
+ vad_options=vad_parameters,
967
+ all_language_probs=all_language_probs,
968
+ )
969
+
970
+ return segments, info
971
+
972
+ def _split_segments_by_timestamps(
973
+ self,
974
+ tokenizer: Tokenizer,
975
+ tokens: List[int],
976
+ time_offset: float,
977
+ segment_size: int,
978
+ segment_duration: float,
979
+ seek: int,
980
+ ) -> List[List[int]]:
981
+ current_segments = []
982
+ single_timestamp_ending = (
983
+ len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
984
+ )
985
+
986
+ consecutive_timestamps = [
987
+ i
988
+ for i in range(len(tokens))
989
+ if i > 0
990
+ and tokens[i] >= tokenizer.timestamp_begin
991
+ and tokens[i - 1] >= tokenizer.timestamp_begin
992
+ ]
993
+
994
+ if len(consecutive_timestamps) > 0:
995
+ slices = list(consecutive_timestamps)
996
+ if single_timestamp_ending:
997
+ slices.append(len(tokens))
998
+
999
+ last_slice = 0
1000
+ for current_slice in slices:
1001
+ sliced_tokens = tokens[last_slice:current_slice]
1002
+ start_timestamp_position = sliced_tokens[0] - tokenizer.timestamp_begin
1003
+ end_timestamp_position = sliced_tokens[-1] - tokenizer.timestamp_begin
1004
+ start_time = (
1005
+ time_offset + start_timestamp_position * self.time_precision
1006
+ )
1007
+ end_time = time_offset + end_timestamp_position * self.time_precision
1008
+
1009
+ current_segments.append(
1010
+ dict(
1011
+ seek=seek,
1012
+ start=start_time,
1013
+ end=end_time,
1014
+ tokens=sliced_tokens,
1015
+ )
1016
+ )
1017
+ last_slice = current_slice
1018
+
1019
+ if single_timestamp_ending:
1020
+ # single timestamp at the end means no speech after the last timestamp.
1021
+ seek += segment_size
1022
+ else:
1023
+ # otherwise, ignore the unfinished segment and seek to the last timestamp
1024
+ last_timestamp_position = (
1025
+ tokens[last_slice - 1] - tokenizer.timestamp_begin
1026
+ )
1027
+ seek += last_timestamp_position * self.input_stride
1028
+
1029
+ else:
1030
+ duration = segment_duration
1031
+ timestamps = [
1032
+ token for token in tokens if token >= tokenizer.timestamp_begin
1033
+ ]
1034
+ if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
1035
+ last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
1036
+ duration = last_timestamp_position * self.time_precision
1037
+
1038
+ current_segments.append(
1039
+ dict(
1040
+ seek=seek,
1041
+ start=time_offset,
1042
+ end=time_offset + duration,
1043
+ tokens=tokens,
1044
+ )
1045
+ )
1046
+
1047
+ seek += segment_size
1048
+
1049
+ return current_segments, seek, single_timestamp_ending
1050
+
1051
+ def generate_segments(
1052
+ self,
1053
+ features: np.ndarray,
1054
+ tokenizer: Tokenizer,
1055
+ options: TranscriptionOptions,
1056
+ log_progress,
1057
+ encoder_output: Optional[ctranslate2.StorageView] = None,
1058
+ ) -> Iterable[Segment]:
1059
+ content_frames = features.shape[-1] - 1
1060
+ content_duration = float(content_frames * self.feature_extractor.time_per_frame)
1061
+
1062
+ if isinstance(options.clip_timestamps, str):
1063
+ options.clip_timestamps = [
1064
+ float(ts)
1065
+ for ts in (
1066
+ options.clip_timestamps.split(",")
1067
+ if options.clip_timestamps
1068
+ else []
1069
+ )
1070
+ ]
1071
+
1072
+ seek_points: List[int] = [
1073
+ round(ts * self.frames_per_second) for ts in options.clip_timestamps
1074
+ ]
1075
+ if len(seek_points) == 0:
1076
+ seek_points.append(0)
1077
+ if len(seek_points) % 2 == 1:
1078
+ seek_points.append(content_frames)
1079
+ seek_clips: List[Tuple[int, int]] = list(
1080
+ zip(seek_points[::2], seek_points[1::2])
1081
+ )
1082
+
1083
+ punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
1084
+
1085
+ idx = 0
1086
+ clip_idx = 0
1087
+ seek = seek_clips[clip_idx][0]
1088
+ all_tokens = []
1089
+ prompt_reset_since = 0
1090
+
1091
+ if options.initial_prompt is not None:
1092
+ if isinstance(options.initial_prompt, str):
1093
+ initial_prompt = " " + options.initial_prompt.strip()
1094
+ initial_prompt_tokens = tokenizer.encode(initial_prompt)
1095
+ all_tokens.extend(initial_prompt_tokens)
1096
+ else:
1097
+ all_tokens.extend(options.initial_prompt)
1098
+
1099
+ pbar = tqdm(total=content_duration, unit="seconds", disable=not log_progress)
1100
+ last_speech_timestamp = 0.0
1101
+ all_segments = []
1102
+ # NOTE: This loop is obscurely flattened to make the diff readable.
1103
+ # A later commit should turn this into a simpler nested loop.
1104
+ # for seek_clip_start, seek_clip_end in seek_clips:
1105
+ # while seek < seek_clip_end
1106
+ while clip_idx < len(seek_clips):
1107
+ seek_clip_start, seek_clip_end = seek_clips[clip_idx]
1108
+ if seek_clip_end > content_frames:
1109
+ seek_clip_end = content_frames
1110
+ if seek < seek_clip_start:
1111
+ seek = seek_clip_start
1112
+ if seek >= seek_clip_end:
1113
+ clip_idx += 1
1114
+ if clip_idx < len(seek_clips):
1115
+ seek = seek_clips[clip_idx][0]
1116
+ continue
1117
+ time_offset = seek * self.feature_extractor.time_per_frame
1118
+ window_end_time = float(
1119
+ (seek + self.feature_extractor.nb_max_frames)
1120
+ * self.feature_extractor.time_per_frame
1121
+ )
1122
+ segment_size = min(
1123
+ self.feature_extractor.nb_max_frames,
1124
+ content_frames - seek,
1125
+ seek_clip_end - seek,
1126
+ )
1127
+ segment = features[:, seek : seek + segment_size]
1128
+ segment_duration = segment_size * self.feature_extractor.time_per_frame
1129
+ segment = pad_or_trim(segment)
1130
+
1131
+ if self.logger.isEnabledFor(logging.DEBUG):
1132
+ self.logger.debug(
1133
+ "Processing segment at %s", format_timestamp(time_offset)
1134
+ )
1135
+
1136
+ previous_tokens = all_tokens[prompt_reset_since:]
1137
+
1138
+ if seek > 0 or encoder_output is None:
1139
+ encoder_output = self.encode(segment)
1140
+
1141
+ if options.multilingual:
1142
+ results = self.model.detect_language(encoder_output)
1143
+ language_token, language_probability = results[0][0]
1144
+ language = language_token[2:-2]
1145
+
1146
+ tokenizer.language = tokenizer.tokenizer.token_to_id(language_token)
1147
+ tokenizer.language_code = language
1148
+
1149
+ prompt = self.get_prompt(
1150
+ tokenizer,
1151
+ previous_tokens,
1152
+ without_timestamps=options.without_timestamps,
1153
+ prefix=options.prefix if seek == 0 else None,
1154
+ hotwords=options.hotwords,
1155
+ )
1156
+
1157
+ (
1158
+ result,
1159
+ avg_logprob,
1160
+ temperature,
1161
+ compression_ratio,
1162
+ ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)
1163
+
1164
+ if options.no_speech_threshold is not None:
1165
+ # no voice activity check
1166
+ should_skip = result.no_speech_prob > options.no_speech_threshold
1167
+
1168
+ if (
1169
+ options.log_prob_threshold is not None
1170
+ and avg_logprob > options.log_prob_threshold
1171
+ ):
1172
+ # don't skip if the logprob is high enough, despite the no_speech_prob
1173
+ should_skip = False
1174
+
1175
+ if should_skip:
1176
+ self.logger.debug(
1177
+ "No speech threshold is met (%f > %f)",
1178
+ result.no_speech_prob,
1179
+ options.no_speech_threshold,
1180
+ )
1181
+
1182
+ # fast-forward to the next segment boundary
1183
+ seek += segment_size
1184
+ continue
1185
+
1186
+ tokens = result.sequences_ids[0]
1187
+
1188
+ previous_seek = seek
1189
+
1190
+ # anomalous words are very long/short/improbable
1191
+ def word_anomaly_score(word: dict) -> float:
1192
+ probability = word.get("probability", 0.0)
1193
+ duration = word["end"] - word["start"]
1194
+ score = 0.0
1195
+ if probability < 0.15:
1196
+ score += 1.0
1197
+ if duration < 0.133:
1198
+ score += (0.133 - duration) * 15
1199
+ if duration > 2.0:
1200
+ score += duration - 2.0
1201
+ return score
1202
+
1203
+ def is_segment_anomaly(segment: Optional[dict]) -> bool:
1204
+ if segment is None or not segment["words"]:
1205
+ return False
1206
+ words = [w for w in segment["words"] if w["word"] not in punctuation]
1207
+ words = words[:8]
1208
+ score = sum(word_anomaly_score(w) for w in words)
1209
+ return score >= 3 or score + 0.01 >= len(words)
1210
+
1211
+ def next_words_segment(segments: List[dict]) -> Optional[dict]:
1212
+ return next((s for s in segments if s["words"]), None)
1213
+
1214
+ (
1215
+ current_segments,
1216
+ seek,
1217
+ single_timestamp_ending,
1218
+ ) = self._split_segments_by_timestamps(
1219
+ tokenizer=tokenizer,
1220
+ tokens=tokens,
1221
+ time_offset=time_offset,
1222
+ segment_size=segment_size,
1223
+ segment_duration=segment_duration,
1224
+ seek=seek,
1225
+ )
1226
+
1227
+ if options.word_timestamps:
1228
+ self.add_word_timestamps(
1229
+ [current_segments],
1230
+ tokenizer,
1231
+ encoder_output,
1232
+ segment_size,
1233
+ options.prepend_punctuations,
1234
+ options.append_punctuations,
1235
+ last_speech_timestamp=last_speech_timestamp,
1236
+ )
1237
+ if not single_timestamp_ending:
1238
+ last_word_end = get_end(current_segments)
1239
+ if last_word_end is not None and last_word_end > time_offset:
1240
+ seek = round(last_word_end * self.frames_per_second)
1241
+
1242
+ # skip silence before possible hallucinations
1243
+ if options.hallucination_silence_threshold is not None:
1244
+ threshold = options.hallucination_silence_threshold
1245
+
1246
+ # if first segment might be a hallucination, skip leading silence
1247
+ first_segment = next_words_segment(current_segments)
1248
+ if first_segment is not None and is_segment_anomaly(first_segment):
1249
+ gap = first_segment["start"] - time_offset
1250
+ if gap > threshold:
1251
+ seek = previous_seek + round(gap * self.frames_per_second)
1252
+ continue
1253
+
1254
+ # skip silence before any possible hallucination that is surrounded
1255
+ # by silence or more hallucinations
1256
+ hal_last_end = last_speech_timestamp
1257
+ for si in range(len(current_segments)):
1258
+ segment = current_segments[si]
1259
+ if not segment["words"]:
1260
+ continue
1261
+ if is_segment_anomaly(segment):
1262
+ next_segment = next_words_segment(
1263
+ current_segments[si + 1 :]
1264
+ )
1265
+ if next_segment is not None:
1266
+ hal_next_start = next_segment["words"][0]["start"]
1267
+ else:
1268
+ hal_next_start = time_offset + segment_duration
1269
+ silence_before = (
1270
+ segment["start"] - hal_last_end > threshold
1271
+ or segment["start"] < threshold
1272
+ or segment["start"] - time_offset < 2.0
1273
+ )
1274
+ silence_after = (
1275
+ hal_next_start - segment["end"] > threshold
1276
+ or is_segment_anomaly(next_segment)
1277
+ or window_end_time - segment["end"] < 2.0
1278
+ )
1279
+ if silence_before and silence_after:
1280
+ seek = round(
1281
+ max(time_offset + 1, segment["start"])
1282
+ * self.frames_per_second
1283
+ )
1284
+ if content_duration - segment["end"] < threshold:
1285
+ seek = content_frames
1286
+ current_segments[si:] = []
1287
+ break
1288
+ hal_last_end = segment["end"]
1289
+
1290
+ last_word_end = get_end(current_segments)
1291
+ if last_word_end is not None:
1292
+ last_speech_timestamp = last_word_end
1293
+ for segment in current_segments:
1294
+ tokens = segment["tokens"]
1295
+ text = tokenizer.decode(tokens)
1296
+
1297
+ if segment["start"] == segment["end"] or not text.strip():
1298
+ continue
1299
+
1300
+ all_tokens.extend(tokens)
1301
+ idx += 1
1302
+
1303
+ all_segments.append(Segment(
1304
+ id=idx,
1305
+ seek=previous_seek,
1306
+ start=segment["start"],
1307
+ end=segment["end"],
1308
+ text=text,
1309
+ tokens=tokens,
1310
+ temperature=temperature,
1311
+ avg_logprob=avg_logprob,
1312
+ compression_ratio=compression_ratio,
1313
+ no_speech_prob=result.no_speech_prob,
1314
+ words=(
1315
+ [Word(**word) for word in segment["words"]]
1316
+ if options.word_timestamps
1317
+ else None
1318
+ ),
1319
+ ))
1320
+
1321
+ if (
1322
+ not options.condition_on_previous_text
1323
+ or temperature > options.prompt_reset_on_temperature
1324
+ ):
1325
+ if options.condition_on_previous_text:
1326
+ self.logger.debug(
1327
+ "Reset prompt. prompt_reset_on_temperature threshold is met %f > %f",
1328
+ temperature,
1329
+ options.prompt_reset_on_temperature,
1330
+ )
1331
+
1332
+ prompt_reset_since = len(all_tokens)
1333
+
1334
+ pbar.update(
1335
+ (min(content_frames, seek) - previous_seek)
1336
+ * self.feature_extractor.time_per_frame,
1337
+ )
1338
+ pbar.close()
1339
+ return all_segments
1340
+
1341
+ def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
1342
+ # When the model is running on multiple GPUs, the encoder output should be moved
1343
+ # to the CPU since we don't know which GPU will handle the next job.
1344
+ to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
1345
+
1346
+ if features.ndim == 2:
1347
+ features = np.expand_dims(features, 0)
1348
+ features = get_ctranslate2_storage(features)
1349
+
1350
+ return self.model.encode(features, to_cpu=to_cpu)
1351
+
1352
+ def generate_with_fallback(
1353
+ self,
1354
+ encoder_output: ctranslate2.StorageView,
1355
+ prompt: List[int],
1356
+ tokenizer: Tokenizer,
1357
+ options: TranscriptionOptions,
1358
+ ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
1359
+ decode_result = None
1360
+ all_results = []
1361
+ below_cr_threshold_results = []
1362
+
1363
+ max_initial_timestamp_index = int(
1364
+ round(options.max_initial_timestamp / self.time_precision)
1365
+ )
1366
+ if options.max_new_tokens is not None:
1367
+ max_length = len(prompt) + options.max_new_tokens
1368
+ else:
1369
+ max_length = self.max_length
1370
+
1371
+ if max_length > self.max_length:
1372
+ raise ValueError(
1373
+ f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
1374
+ f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
1375
+ f"and `max_new_tokens` is: {max_length}. This exceeds the "
1376
+ f"`max_length` of the Whisper model: {self.max_length}. "
1377
+ "You should either reduce the length of your prompt, or "
1378
+ "reduce the value of `max_new_tokens`, "
1379
+ f"so that their combined length is less that {self.max_length}."
1380
+ )
1381
+
1382
+ for temperature in options.temperatures:
1383
+ if temperature > 0:
1384
+ kwargs = {
1385
+ "beam_size": 1,
1386
+ "num_hypotheses": options.best_of,
1387
+ "sampling_topk": 0,
1388
+ "sampling_temperature": temperature,
1389
+ }
1390
+ else:
1391
+ kwargs = {
1392
+ "beam_size": options.beam_size,
1393
+ "patience": options.patience,
1394
+ }
1395
+
1396
+ result = self.model.generate(
1397
+ encoder_output,
1398
+ [prompt],
1399
+ length_penalty=options.length_penalty,
1400
+ repetition_penalty=options.repetition_penalty,
1401
+ no_repeat_ngram_size=options.no_repeat_ngram_size,
1402
+ max_length=max_length,
1403
+ return_scores=True,
1404
+ return_no_speech_prob=True,
1405
+ suppress_blank=options.suppress_blank,
1406
+ suppress_tokens=options.suppress_tokens,
1407
+ max_initial_timestamp_index=max_initial_timestamp_index,
1408
+ **kwargs,
1409
+ )[0]
1410
+
1411
+ tokens = result.sequences_ids[0]
1412
+
1413
+ # Recover the average log prob from the returned score.
1414
+ seq_len = len(tokens)
1415
+ cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
1416
+ avg_logprob = cum_logprob / (seq_len + 1)
1417
+
1418
+ text = tokenizer.decode(tokens).strip()
1419
+ compression_ratio = get_compression_ratio(text)
1420
+
1421
+ decode_result = (
1422
+ result,
1423
+ avg_logprob,
1424
+ temperature,
1425
+ compression_ratio,
1426
+ )
1427
+ all_results.append(decode_result)
1428
+
1429
+ needs_fallback = False
1430
+
1431
+ if options.compression_ratio_threshold is not None:
1432
+ if compression_ratio > options.compression_ratio_threshold:
1433
+ needs_fallback = True # too repetitive
1434
+
1435
+ self.logger.debug(
1436
+ "Compression ratio threshold is not met with temperature %.1f (%f > %f)",
1437
+ temperature,
1438
+ compression_ratio,
1439
+ options.compression_ratio_threshold,
1440
+ )
1441
+ else:
1442
+ below_cr_threshold_results.append(decode_result)
1443
+
1444
+ if (
1445
+ options.log_prob_threshold is not None
1446
+ and avg_logprob < options.log_prob_threshold
1447
+ ):
1448
+ needs_fallback = True # average log probability is too low
1449
+
1450
+ self.logger.debug(
1451
+ "Log probability threshold is not met with temperature %.1f (%f < %f)",
1452
+ temperature,
1453
+ avg_logprob,
1454
+ options.log_prob_threshold,
1455
+ )
1456
+
1457
+ if (
1458
+ options.no_speech_threshold is not None
1459
+ and result.no_speech_prob > options.no_speech_threshold
1460
+ and options.log_prob_threshold is not None
1461
+ and avg_logprob < options.log_prob_threshold
1462
+ ):
1463
+ needs_fallback = False # silence
1464
+
1465
+ if not needs_fallback:
1466
+ break
1467
+ else:
1468
+ # all failed, select the result with the highest average log probability
1469
+ decode_result = max(
1470
+ below_cr_threshold_results or all_results, key=lambda x: x[1]
1471
+ )
1472
+ # to pass final temperature for prompt_reset_on_temperature
1473
+ decode_result = (
1474
+ decode_result[0],
1475
+ decode_result[1],
1476
+ temperature,
1477
+ decode_result[3],
1478
+ )
1479
+
1480
+ return decode_result
1481
+
1482
+ def get_prompt(
1483
+ self,
1484
+ tokenizer: Tokenizer,
1485
+ previous_tokens: List[int],
1486
+ without_timestamps: bool = False,
1487
+ prefix: Optional[str] = None,
1488
+ hotwords: Optional[str] = None,
1489
+ ) -> List[int]:
1490
+ prompt = []
1491
+
1492
+ if previous_tokens or (hotwords and not prefix):
1493
+ prompt.append(tokenizer.sot_prev)
1494
+ if hotwords and not prefix:
1495
+ hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
1496
+ if len(hotwords_tokens) >= self.max_length // 2:
1497
+ hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
1498
+ prompt.extend(hotwords_tokens)
1499
+ if previous_tokens:
1500
+ prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
1501
+
1502
+ prompt.extend(tokenizer.sot_sequence)
1503
+
1504
+ if without_timestamps:
1505
+ prompt.append(tokenizer.no_timestamps)
1506
+
1507
+ if prefix:
1508
+ prefix_tokens = tokenizer.encode(" " + prefix.strip())
1509
+ if len(prefix_tokens) >= self.max_length // 2:
1510
+ prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
1511
+ if not without_timestamps:
1512
+ prompt.append(tokenizer.timestamp_begin)
1513
+ prompt.extend(prefix_tokens)
1514
+
1515
+ return prompt
1516
+
1517
+ def add_word_timestamps(
1518
+ self,
1519
+ segments: List[dict],
1520
+ tokenizer: Tokenizer,
1521
+ encoder_output: ctranslate2.StorageView,
1522
+ num_frames: int,
1523
+ prepend_punctuations: str,
1524
+ append_punctuations: str,
1525
+ last_speech_timestamp: float,
1526
+ ) -> float:
1527
+ if len(segments) == 0:
1528
+ return
1529
+
1530
+ text_tokens = []
1531
+ text_tokens_per_segment = []
1532
+ for segment in segments:
1533
+ segment_tokens = [
1534
+ [token for token in subsegment["tokens"] if token < tokenizer.eot]
1535
+ for subsegment in segment
1536
+ ]
1537
+ text_tokens.append(list(itertools.chain.from_iterable(segment_tokens)))
1538
+ text_tokens_per_segment.append(segment_tokens)
1539
+
1540
+ alignments = self.find_alignment(
1541
+ tokenizer, text_tokens, encoder_output, num_frames
1542
+ )
1543
+ median_max_durations = []
1544
+ for alignment in alignments:
1545
+ word_durations = np.array(
1546
+ [word["end"] - word["start"] for word in alignment]
1547
+ )
1548
+ word_durations = word_durations[word_durations.nonzero()]
1549
+ median_duration = (
1550
+ np.median(word_durations) if len(word_durations) > 0 else 0.0
1551
+ )
1552
+ median_duration = min(0.7, float(median_duration))
1553
+ max_duration = median_duration * 2
1554
+
1555
+ # hack: truncate long words at sentence boundaries.
1556
+ # a better segmentation algorithm based on VAD should be able to replace this.
1557
+ if len(word_durations) > 0:
1558
+ sentence_end_marks = ".。!!??"
1559
+ # ensure words at sentence boundaries
1560
+ # are not longer than twice the median word duration.
1561
+ for i in range(1, len(alignment)):
1562
+ if alignment[i]["end"] - alignment[i]["start"] > max_duration:
1563
+ if alignment[i]["word"] in sentence_end_marks:
1564
+ alignment[i]["end"] = alignment[i]["start"] + max_duration
1565
+ elif alignment[i - 1]["word"] in sentence_end_marks:
1566
+ alignment[i]["start"] = alignment[i]["end"] - max_duration
1567
+
1568
+ merge_punctuations(alignment, prepend_punctuations, append_punctuations)
1569
+ median_max_durations.append((median_duration, max_duration))
1570
+
1571
+ for segment_idx, segment in enumerate(segments):
1572
+ word_index = 0
1573
+ time_offset = segment[0]["seek"] / self.frames_per_second
1574
+ median_duration, max_duration = median_max_durations[segment_idx]
1575
+ for subsegment_idx, subsegment in enumerate(segment):
1576
+ saved_tokens = 0
1577
+ words = []
1578
+
1579
+ while word_index < len(alignments[segment_idx]) and saved_tokens < len(
1580
+ text_tokens_per_segment[segment_idx][subsegment_idx]
1581
+ ):
1582
+ timing = alignments[segment_idx][word_index]
1583
+
1584
+ if timing["word"]:
1585
+ words.append(
1586
+ dict(
1587
+ word=timing["word"],
1588
+ start=round(time_offset + timing["start"], 2),
1589
+ end=round(time_offset + timing["end"], 2),
1590
+ probability=timing["probability"],
1591
+ )
1592
+ )
1593
+
1594
+ saved_tokens += len(timing["tokens"])
1595
+ word_index += 1
1596
+
1597
+ # hack: truncate long words at segment boundaries.
1598
+ # a better segmentation algorithm based on VAD should be able to replace this.
1599
+ if len(words) > 0:
1600
+ # ensure the first and second word after a pause is not longer than
1601
+ # twice the median word duration.
1602
+ if words[0][
1603
+ "end"
1604
+ ] - last_speech_timestamp > median_duration * 4 and (
1605
+ words[0]["end"] - words[0]["start"] > max_duration
1606
+ or (
1607
+ len(words) > 1
1608
+ and words[1]["end"] - words[0]["start"] > max_duration * 2
1609
+ )
1610
+ ):
1611
+ if (
1612
+ len(words) > 1
1613
+ and words[1]["end"] - words[1]["start"] > max_duration
1614
+ ):
1615
+ boundary = max(
1616
+ words[1]["end"] / 2, words[1]["end"] - max_duration
1617
+ )
1618
+ words[0]["end"] = words[1]["start"] = boundary
1619
+ words[0]["start"] = max(0, words[0]["end"] - max_duration)
1620
+
1621
+ # prefer the segment-level start timestamp if the first word is too long.
1622
+ if (
1623
+ subsegment["start"] < words[0]["end"]
1624
+ and subsegment["start"] - 0.5 > words[0]["start"]
1625
+ ):
1626
+ words[0]["start"] = max(
1627
+ 0,
1628
+ min(words[0]["end"] - median_duration, subsegment["start"]),
1629
+ )
1630
+ else:
1631
+ subsegment["start"] = words[0]["start"]
1632
+
1633
+ # prefer the segment-level end timestamp if the last word is too long.
1634
+ if (
1635
+ subsegment["end"] > words[-1]["start"]
1636
+ and subsegment["end"] + 0.5 < words[-1]["end"]
1637
+ ):
1638
+ words[-1]["end"] = max(
1639
+ words[-1]["start"] + median_duration, subsegment["end"]
1640
+ )
1641
+ else:
1642
+ subsegment["end"] = words[-1]["end"]
1643
+
1644
+ last_speech_timestamp = subsegment["end"]
1645
+ segments[segment_idx][subsegment_idx]["words"] = words
1646
+ return last_speech_timestamp
1647
+
1648
+ def find_alignment(
1649
+ self,
1650
+ tokenizer: Tokenizer,
1651
+ text_tokens: List[int],
1652
+ encoder_output: ctranslate2.StorageView,
1653
+ num_frames: int,
1654
+ median_filter_width: int = 7,
1655
+ ) -> List[dict]:
1656
+ if len(text_tokens) == 0:
1657
+ return []
1658
+
1659
+ results = self.model.align(
1660
+ encoder_output,
1661
+ tokenizer.sot_sequence,
1662
+ text_tokens,
1663
+ num_frames,
1664
+ median_filter_width=median_filter_width,
1665
+ )
1666
+ return_list = []
1667
+ for result, text_token in zip(results, text_tokens):
1668
+ text_token_probs = result.text_token_probs
1669
+ alignments = result.alignments
1670
+ text_indices = np.array([pair[0] for pair in alignments])
1671
+ time_indices = np.array([pair[1] for pair in alignments])
1672
+
1673
+ words, word_tokens = tokenizer.split_to_word_tokens(
1674
+ text_token + [tokenizer.eot]
1675
+ )
1676
+ if len(word_tokens) <= 1:
1677
+ # return on eot only
1678
+ # >>> np.pad([], (1, 0))
1679
+ # array([0.])
1680
+ # This results in crashes when we lookup jump_times with float, like
1681
+ # IndexError: arrays used as indices must be of integer (or boolean) type
1682
+ return_list.append([])
1683
+ continue
1684
+ word_boundaries = np.pad(
1685
+ np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)
1686
+ )
1687
+ if len(word_boundaries) <= 1:
1688
+ return_list.append([])
1689
+ continue
1690
+
1691
+ jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(
1692
+ bool
1693
+ )
1694
+ jump_times = time_indices[jumps] / self.tokens_per_second
1695
+ start_times = jump_times[word_boundaries[:-1]]
1696
+ end_times = jump_times[word_boundaries[1:]]
1697
+ word_probabilities = [
1698
+ np.mean(text_token_probs[i:j])
1699
+ for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
1700
+ ]
1701
+
1702
+ return_list.append(
1703
+ [
1704
+ dict(
1705
+ word=word,
1706
+ tokens=tokens,
1707
+ start=start,
1708
+ end=end,
1709
+ probability=probability,
1710
+ )
1711
+ for word, tokens, start, end, probability in zip(
1712
+ words, word_tokens, start_times, end_times, word_probabilities
1713
+ )
1714
+ ]
1715
+ )
1716
+ return return_list
1717
+
1718
+ def detect_language(
1719
+ self,
1720
+ audio: Optional[np.ndarray] = None,
1721
+ features: Optional[np.ndarray] = None,
1722
+ vad_filter: bool = False,
1723
+ vad_parameters: Union[dict, VadOptions] = None,
1724
+ language_detection_segments: int = 1,
1725
+ language_detection_threshold: float = 0.5,
1726
+ ) -> Tuple[str, float, List[Tuple[str, float]]]:
1727
+ """
1728
+ Use Whisper to detect the language of the input audio or features.
1729
+
1730
+ Arguments:
1731
+ audio: Input audio signal, must be a 1D float array sampled at 16khz.
1732
+ features: Input Mel spectrogram features, must be a float array with
1733
+ shape (n_mels, n_frames), if `audio` is provided, the features will be ignored.
1734
+ Either `audio` or `features` must be provided.
1735
+ vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
1736
+ without speech. This step is using the Silero VAD model.
1737
+ vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
1738
+ parameters and default values in the class `VadOptions`).
1739
+ language_detection_threshold: If the maximum probability of the language tokens is
1740
+ higher than this value, the language is detected.
1741
+ language_detection_segments: Number of segments to consider for the language detection.
1742
+
1743
+ Returns:
1744
+ language: Detected language.
1745
+ languege_probability: Probability of the detected language.
1746
+ all_language_probs: List of tuples with all language names and probabilities.
1747
+ """
1748
+ assert (
1749
+ audio is not None or features is not None
1750
+ ), "Either `audio` or `features` must be provided."
1751
+
1752
+ if audio is not None:
1753
+ if vad_filter:
1754
+ speech_chunks = get_speech_timestamps(audio, vad_parameters)
1755
+ audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
1756
+ audio = np.concatenate(audio_chunks, axis=0)
1757
+
1758
+ audio = audio[
1759
+ : language_detection_segments * self.feature_extractor.n_samples
1760
+ ]
1761
+ features = self.feature_extractor(audio)
1762
+
1763
+ features = features[
1764
+ ..., : language_detection_segments * self.feature_extractor.nb_max_frames
1765
+ ]
1766
+
1767
+ detected_language_info = {}
1768
+ for i in range(0, features.shape[-1], self.feature_extractor.nb_max_frames):
1769
+ encoder_output = self.encode(
1770
+ pad_or_trim(features[..., i : i + self.feature_extractor.nb_max_frames])
1771
+ )
1772
+ # results is a list of tuple[str, float] with language names and probabilities.
1773
+ results = self.model.detect_language(encoder_output)[0]
1774
+
1775
+ # Parse language names to strip out markers
1776
+ all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
1777
+ # Get top language token and probability
1778
+ language, language_probability = all_language_probs[0]
1779
+ if language_probability > language_detection_threshold:
1780
+ break
1781
+ detected_language_info.setdefault(language, []).append(language_probability)
1782
+ else:
1783
+ # If no language detected for all segments, the majority vote of the highest
1784
+ # projected languages for all segments is used to determine the language.
1785
+ language = max(
1786
+ detected_language_info,
1787
+ key=lambda lang: len(detected_language_info[lang]),
1788
+ )
1789
+ language_probability = max(detected_language_info[language])
1790
+
1791
+ return language, language_probability, all_language_probs
1792
+
1793
+
1794
+ def restore_speech_timestamps(
1795
+ segments: Iterable[Segment],
1796
+ speech_chunks: List[dict],
1797
+ sampling_rate: int,
1798
+ ) -> Iterable[Segment]:
1799
+ ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
1800
+
1801
+ for segment in segments:
1802
+ if segment.words:
1803
+ words = []
1804
+ for word in segment.words:
1805
+ # Ensure the word start and end times are resolved to the same chunk.
1806
+ middle = (word.start + word.end) / 2
1807
+ chunk_index = ts_map.get_chunk_index(middle)
1808
+ word.start = ts_map.get_original_time(word.start, chunk_index)
1809
+ word.end = ts_map.get_original_time(word.end, chunk_index)
1810
+ words.append(word)
1811
+
1812
+ segment.start = words[0].start
1813
+ segment.end = words[-1].end
1814
+ segment.words = words
1815
+
1816
+ else:
1817
+ segment.start = ts_map.get_original_time(segment.start)
1818
+ segment.end = ts_map.get_original_time(segment.end)
1819
+ return segments
1820
+
1821
+
1822
+ def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
1823
+ segment = np.ascontiguousarray(segment)
1824
+ segment = ctranslate2.StorageView.from_array(segment)
1825
+ return segment
1826
+
1827
+
1828
+ def get_compression_ratio(text: str) -> float:
1829
+ text_bytes = text.encode("utf-8")
1830
+ return len(text_bytes) / len(zlib.compress(text_bytes))
1831
+
1832
+
1833
+ def get_suppressed_tokens(
1834
+ tokenizer: Tokenizer,
1835
+ suppress_tokens: Tuple[int],
1836
+ ) -> Optional[List[int]]:
1837
+ if -1 in suppress_tokens:
1838
+ suppress_tokens = [t for t in suppress_tokens if t >= 0]
1839
+ suppress_tokens.extend(tokenizer.non_speech_tokens)
1840
+ elif suppress_tokens is None or len(suppress_tokens) == 0:
1841
+ suppress_tokens = [] # interpret empty string as an empty list
1842
+ else:
1843
+ assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
1844
+
1845
+ suppress_tokens.extend(
1846
+ [
1847
+ tokenizer.transcribe,
1848
+ tokenizer.translate,
1849
+ tokenizer.sot,
1850
+ tokenizer.sot_prev,
1851
+ tokenizer.sot_lm,
1852
+ ]
1853
+ )
1854
+
1855
+ return tuple(sorted(set(suppress_tokens)))
1856
+
1857
+
1858
+ def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
1859
+ # merge prepended punctuations
1860
+ i = len(alignment) - 2
1861
+ j = len(alignment) - 1
1862
+ while i >= 0:
1863
+ previous = alignment[i]
1864
+ following = alignment[j]
1865
+ if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
1866
+ # prepend it to the following word
1867
+ following["word"] = previous["word"] + following["word"]
1868
+ following["tokens"] = previous["tokens"] + following["tokens"]
1869
+ previous["word"] = ""
1870
+ previous["tokens"] = []
1871
+ else:
1872
+ j = i
1873
+ i -= 1
1874
+
1875
+ # merge appended punctuations
1876
+ i = 0
1877
+ j = 1
1878
+ while j < len(alignment):
1879
+ previous = alignment[i]
1880
+ following = alignment[j]
1881
+ if not previous["word"].endswith(" ") and following["word"] in appended:
1882
+ # append it to the previous word
1883
+ previous["word"] = previous["word"] + following["word"]
1884
+ previous["tokens"] = previous["tokens"] + following["tokens"]
1885
+ following["word"] = ""
1886
+ following["tokens"] = []
1887
+ else:
1888
+ i = j
1889
+ j += 1
whisper_live/transcriber/transcriber_openvino.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import os
3
+
4
+ import openvino_genai as ov_genai
5
+ import huggingface_hub as hf_hub
6
+
7
+
8
+ class WhisperOpenVINO(object):
9
+ def __init__(self, model_id="OpenVINO/whisper-tiny-fp16-ov", device="CPU", language="en", task="transcribe"):
10
+ model_path = model_id.split('/')[-1]
11
+ cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "openvino_whisper_models")
12
+ os.makedirs(cache_dir, exist_ok=True)
13
+ model_path = os.path.join(cache_dir, model_path)
14
+ if not os.path.exists(model_path):
15
+ hf_hub.snapshot_download(model_id, local_dir=model_path)
16
+ self.model = ov_genai.WhisperPipeline(str(model_path), device=device)
17
+ self.language = language
18
+ self.task = task
19
+
20
+ def transcribe(self, input_audio):
21
+ outputs = self.model.generate(input_audio, return_timestamps=True, language=self.language, task=self.task)
22
+ outputs = [seg for seg in outputs.chunks]
23
+ return outputs
whisper_live/transcriber/transcriber_tensorrt.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import math
4
+ from collections import OrderedDict
5
+ from pathlib import Path
6
+ from typing import Union
7
+
8
+ import torch
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+ from whisper.tokenizer import get_tokenizer
12
+ from whisper_live.transcriber.tensorrt_utils import (
13
+ mel_filters,
14
+ load_audio_wav_format,
15
+ pad_or_trim,
16
+ load_audio
17
+ )
18
+
19
+ import tensorrt_llm
20
+ import tensorrt_llm.logger as logger
21
+ from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
22
+ trt_dtype_to_torch)
23
+ from tensorrt_llm.bindings import GptJsonConfig, KVCacheType
24
+ from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelConfig, SamplingConfig
25
+ from tensorrt_llm.runtime.session import Session, TensorInfo
26
+ if PYTHON_BINDINGS:
27
+ from tensorrt_llm.runtime import ModelRunnerCpp
28
+
29
+ SAMPLE_RATE = 16000
30
+ N_FFT = 400
31
+ HOP_LENGTH = 160
32
+ CHUNK_LENGTH = 30
33
+ N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
34
+
35
+ def read_config(component, engine_dir):
36
+ config_path = engine_dir / component / 'config.json'
37
+ with open(config_path, 'r') as f:
38
+ config = json.load(f)
39
+ model_config = OrderedDict()
40
+ model_config.update(config['pretrained_config'])
41
+ model_config.update(config['build_config'])
42
+ return model_config
43
+
44
+
45
+ def remove_tensor_padding(input_tensor,
46
+ input_tensor_lengths=None,
47
+ pad_value=None):
48
+ if pad_value:
49
+ assert input_tensor_lengths is None, "input_tensor_lengths should be None when pad_value is provided"
50
+ # Text tensor case: batch, seq_len
51
+ assert torch.all(
52
+ input_tensor[:, 0] != pad_value
53
+ ), "First token in each sequence should not be pad_value"
54
+ assert input_tensor_lengths is None
55
+
56
+ # Create a mask for all non-pad tokens
57
+ mask = input_tensor != pad_value
58
+
59
+ # Apply the mask to input_tensor to remove pad tokens
60
+ output_tensor = input_tensor[mask].view(1, -1)
61
+
62
+ else:
63
+ # Audio tensor case: batch, seq_len, feature_len
64
+ # position_ids case: batch, seq_len
65
+ assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
66
+
67
+ # Initialize a list to collect valid sequences
68
+ valid_sequences = []
69
+
70
+ for i in range(input_tensor.shape[0]):
71
+ valid_length = input_tensor_lengths[i]
72
+ valid_sequences.append(input_tensor[i, :valid_length])
73
+
74
+ # Concatenate all valid sequences along the batch dimension
75
+ output_tensor = torch.cat(valid_sequences, dim=0)
76
+ return output_tensor
77
+
78
+
79
+ class WhisperEncoding:
80
+
81
+ def __init__(self, engine_dir):
82
+ self.session = self.get_session(engine_dir)
83
+ config = read_config('encoder', engine_dir)
84
+ self.n_mels = config['n_mels']
85
+ self.dtype = config['dtype']
86
+ self.num_languages = config['num_languages']
87
+ self.encoder_config = config
88
+
89
+ def get_session(self, engine_dir):
90
+ serialize_path = engine_dir / 'encoder' / 'rank0.engine'
91
+ with open(serialize_path, 'rb') as f:
92
+ session = Session.from_serialized_engine(f.read())
93
+ return session
94
+
95
+ def get_audio_features(self,
96
+ mel,
97
+ mel_input_lengths,
98
+ encoder_downsampling_factor=2):
99
+ if isinstance(mel, list):
100
+ longest_mel = max([f.shape[-1] for f in mel])
101
+ mel = [
102
+ torch.nn.functional.pad(f, (0, longest_mel - f.shape[-1]),
103
+ mode='constant') for f in mel
104
+ ]
105
+ mel = torch.cat(mel, dim=0).type(
106
+ str_dtype_to_torch("float16")).contiguous()
107
+ bsz, seq_len = mel.shape[0], mel.shape[2]
108
+ position_ids = torch.arange(
109
+ math.ceil(seq_len / encoder_downsampling_factor),
110
+ dtype=torch.int32,
111
+ device=mel.device).expand(bsz, -1).contiguous()
112
+ if self.encoder_config['plugin_config']['remove_input_padding']:
113
+ # mel B,D,T -> B,T,D -> BxT, D
114
+ mel = mel.transpose(1, 2)
115
+ mel = remove_tensor_padding(mel, mel_input_lengths)
116
+ position_ids = remove_tensor_padding(
117
+ position_ids, mel_input_lengths // encoder_downsampling_factor)
118
+ inputs = OrderedDict()
119
+ inputs['input_features'] = mel
120
+ inputs['input_lengths'] = mel_input_lengths
121
+ inputs['position_ids'] = position_ids
122
+
123
+ output_list = [
124
+ TensorInfo('input_features', str_dtype_to_trt(self.dtype),
125
+ mel.shape),
126
+ TensorInfo('input_lengths', str_dtype_to_trt('int32'),
127
+ mel_input_lengths.shape),
128
+ TensorInfo('position_ids', str_dtype_to_trt('int32'),
129
+ inputs['position_ids'].shape)
130
+ ]
131
+
132
+ output_info = (self.session).infer_shapes(output_list)
133
+
134
+ logger.debug(f'output info {output_info}')
135
+ outputs = {
136
+ t.name: torch.empty(tuple(t.shape),
137
+ dtype=trt_dtype_to_torch(t.dtype),
138
+ device='cuda')
139
+ for t in output_info
140
+ }
141
+ stream = torch.cuda.current_stream()
142
+ ok = self.session.run(inputs=inputs,
143
+ outputs=outputs,
144
+ stream=stream.cuda_stream)
145
+ assert ok, 'Engine execution failed'
146
+ stream.synchronize()
147
+ encoder_output = outputs['encoder_output']
148
+ encoder_output_lengths = mel_input_lengths // encoder_downsampling_factor
149
+ return encoder_output, encoder_output_lengths
150
+
151
+
152
+ class WhisperDecoding:
153
+
154
+ def __init__(self, engine_dir, runtime_mapping, debug_mode=False):
155
+
156
+ self.decoder_config = read_config('decoder', engine_dir)
157
+ self.decoder_generation_session = self.get_session(
158
+ engine_dir, runtime_mapping, debug_mode)
159
+
160
+ def get_session(self, engine_dir, runtime_mapping, debug_mode=False):
161
+ serialize_path = engine_dir / 'decoder' / 'rank0.engine'
162
+ with open(serialize_path, "rb") as f:
163
+ decoder_engine_buffer = f.read()
164
+
165
+ decoder_model_config = ModelConfig(
166
+ max_batch_size=self.decoder_config['max_batch_size'],
167
+ max_beam_width=self.decoder_config['max_beam_width'],
168
+ num_heads=self.decoder_config['num_attention_heads'],
169
+ num_kv_heads=self.decoder_config['num_attention_heads'],
170
+ hidden_size=self.decoder_config['hidden_size'],
171
+ vocab_size=self.decoder_config['vocab_size'],
172
+ cross_attention=True,
173
+ num_layers=self.decoder_config['num_hidden_layers'],
174
+ gpt_attention_plugin=self.decoder_config['plugin_config']
175
+ ['gpt_attention_plugin'],
176
+ remove_input_padding=self.decoder_config['plugin_config']
177
+ ['remove_input_padding'],
178
+ kv_cache_type=KVCacheType.PAGED
179
+ if self.decoder_config['plugin_config']['paged_kv_cache'] == True
180
+ else KVCacheType.CONTINUOUS,
181
+ has_position_embedding=self.
182
+ decoder_config['has_position_embedding'],
183
+ dtype=self.decoder_config['dtype'],
184
+ has_token_type_embedding=False,
185
+ )
186
+ decoder_generation_session = tensorrt_llm.runtime.GenerationSession(
187
+ decoder_model_config,
188
+ decoder_engine_buffer,
189
+ runtime_mapping,
190
+ debug_mode=debug_mode)
191
+
192
+ return decoder_generation_session
193
+
194
+ def generate(self,
195
+ decoder_input_ids,
196
+ encoder_outputs,
197
+ encoder_max_input_length,
198
+ encoder_input_lengths,
199
+ eot_id,
200
+ max_new_tokens=40,
201
+ num_beams=1):
202
+ batch_size = decoder_input_ids.shape[0]
203
+ decoder_input_lengths = torch.tensor([
204
+ decoder_input_ids.shape[-1]
205
+ for _ in range(decoder_input_ids.shape[0])
206
+ ],
207
+ dtype=torch.int32,
208
+ device='cuda')
209
+ decoder_max_input_length = torch.max(decoder_input_lengths).item()
210
+
211
+ cross_attention_mask = torch.ones([
212
+ batch_size, decoder_max_input_length + max_new_tokens,
213
+ encoder_max_input_length
214
+ ]).int().cuda()
215
+ # generation config
216
+ sampling_config = SamplingConfig(end_id=eot_id,
217
+ pad_id=eot_id,
218
+ num_beams=num_beams)
219
+ self.decoder_generation_session.setup(
220
+ decoder_input_lengths.size(0),
221
+ decoder_max_input_length,
222
+ max_new_tokens,
223
+ beam_width=num_beams,
224
+ encoder_max_input_length=encoder_max_input_length)
225
+
226
+ torch.cuda.synchronize()
227
+
228
+ decoder_input_ids = decoder_input_ids.type(torch.int32).cuda()
229
+ if self.decoder_config['plugin_config']['remove_input_padding']:
230
+ # 50256 is the index of <pad> for all whisper models' decoder
231
+ WHISPER_PAD_TOKEN_ID = 50256
232
+ decoder_input_ids = remove_tensor_padding(
233
+ decoder_input_ids, pad_value=WHISPER_PAD_TOKEN_ID)
234
+ if encoder_outputs.dim() == 3:
235
+ encoder_output_lens = torch.full((encoder_outputs.shape[0], ),
236
+ encoder_outputs.shape[1],
237
+ dtype=torch.int32,
238
+ device='cuda')
239
+
240
+ encoder_outputs = remove_tensor_padding(encoder_outputs,
241
+ encoder_output_lens)
242
+ output_ids = self.decoder_generation_session.decode(
243
+ decoder_input_ids,
244
+ decoder_input_lengths,
245
+ sampling_config,
246
+ encoder_output=encoder_outputs,
247
+ encoder_input_lengths=encoder_input_lengths,
248
+ cross_attention_mask=cross_attention_mask,
249
+ )
250
+ torch.cuda.synchronize()
251
+
252
+ # get the list of int from output_ids tensor
253
+ output_ids = output_ids.cpu().numpy().tolist()
254
+ return output_ids
255
+
256
+
257
+ class WhisperTRTLLM(object):
258
+
259
+ def __init__(self,
260
+ engine_dir,
261
+ assets_dir=None,
262
+ device=None,
263
+ is_multilingual=False,
264
+ language="en",
265
+ task="transcribe",
266
+ use_py_session=False,
267
+ num_beams=1,
268
+ debug_mode=False,
269
+ max_output_len=96):
270
+ world_size = 1
271
+ runtime_rank = tensorrt_llm.mpi_rank()
272
+ runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
273
+ torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
274
+ engine_dir = Path(engine_dir)
275
+ encoder_config = read_config('encoder', engine_dir)
276
+ decoder_config = read_config('decoder', engine_dir)
277
+ self.n_mels = encoder_config['n_mels']
278
+ self.num_languages = encoder_config['num_languages']
279
+ is_multilingual = (decoder_config['vocab_size'] >= 51865)
280
+
281
+ self.device = device
282
+ self.tokenizer = get_tokenizer(
283
+ is_multilingual,
284
+ num_languages=self.num_languages,
285
+ language=language,
286
+ task=task,
287
+ )
288
+
289
+ if use_py_session:
290
+ self.encoder = WhisperEncoding(engine_dir)
291
+ self.decoder = WhisperDecoding(engine_dir,
292
+ runtime_mapping,
293
+ debug_mode=False)
294
+ else:
295
+ json_config = GptJsonConfig.parse_file(engine_dir / 'decoder' /
296
+ 'config.json')
297
+ assert json_config.model_config.supports_inflight_batching
298
+ runner_kwargs = dict(engine_dir=engine_dir,
299
+ is_enc_dec=True,
300
+ max_batch_size=1,
301
+ max_input_len=3000,
302
+ max_output_len=max_output_len,
303
+ max_beam_width=num_beams,
304
+ debug_mode=debug_mode,
305
+ kv_cache_free_gpu_memory_fraction=0.9,
306
+ cross_kv_cache_fraction=0.5)
307
+ self.model_runner_cpp = ModelRunnerCpp.from_dir(**runner_kwargs)
308
+ self.filters = mel_filters(self.device, self.n_mels, assets_dir)
309
+ self.use_py_session = use_py_session
310
+
311
+ def log_mel_spectrogram(
312
+ self,
313
+ audio: Union[str, np.ndarray, torch.Tensor],
314
+ padding: int = 0,
315
+ return_duration=True
316
+ ):
317
+ """
318
+ Compute the log-Mel spectrogram of
319
+
320
+ Parameters
321
+ ----------
322
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
323
+ The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
324
+
325
+ n_mels: int
326
+ The number of Mel-frequency filters, only 80 and 128 are supported
327
+
328
+ padding: int
329
+ Number of zero samples to pad to the right
330
+
331
+ device: Optional[Union[str, torch.device]]
332
+ If given, the audio tensor is moved to this device before STFT
333
+
334
+ Returns
335
+ -------
336
+ torch.Tensor, shape = (80 or 128, n_frames)
337
+ A Tensor that contains the Mel spectrogram
338
+ """
339
+ if not torch.is_tensor(audio):
340
+ if isinstance(audio, str):
341
+ if audio.endswith('.wav'):
342
+ audio, _ = load_audio_wav_format(audio)
343
+ else:
344
+ audio = load_audio(audio)
345
+ assert isinstance(audio, np.ndarray), f"Unsupported audio type: {type(audio)}"
346
+ duration = audio.shape[-1] / SAMPLE_RATE
347
+ audio = pad_or_trim(audio, N_SAMPLES)
348
+ audio = audio.astype(np.float32)
349
+ audio = torch.from_numpy(audio)
350
+
351
+ if self.device is not None:
352
+ audio = audio.to(self.device)
353
+ if padding > 0:
354
+ audio = F.pad(audio, (0, padding))
355
+ window = torch.hann_window(N_FFT).to(audio.device)
356
+ stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
357
+ magnitudes = stft[..., :-1].abs()**2
358
+
359
+ mel_spec = self.filters @ magnitudes
360
+
361
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
362
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
363
+ log_spec = (log_spec + 4.0) / 4.0
364
+ if return_duration:
365
+ return log_spec, duration
366
+ else:
367
+ return log_spec
368
+
369
+ def process_batch(
370
+ self,
371
+ mel,
372
+ mel_input_lengths,
373
+ text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
374
+ num_beams=1,
375
+ max_new_tokens=96):
376
+ prompt_id = self.tokenizer.encode(
377
+ text_prefix, allowed_special=set(self.tokenizer.special_tokens.keys()))
378
+
379
+ prompt_id = torch.tensor(prompt_id)
380
+ batch_size = mel.shape[0]
381
+ decoder_input_ids = prompt_id.repeat(batch_size, 1)
382
+ if self.use_py_session:
383
+ encoder_output, encoder_output_lengths = self.encoder.get_audio_features(mel, mel_input_lengths)
384
+ encoder_max_input_length = torch.max(encoder_output_lengths).item()
385
+ output_ids = self.decoder.generate(decoder_input_ids,
386
+ encoder_output,
387
+ encoder_max_input_length,
388
+ encoder_output_lengths,
389
+ self.tokenizer.eot,
390
+ max_new_tokens=max_new_tokens,
391
+ num_beams=num_beams)
392
+ else:
393
+ with torch.no_grad():
394
+ if isinstance(mel, list):
395
+ mel = [
396
+ m.transpose(1, 2).type(
397
+ str_dtype_to_torch("float16")).squeeze(0)
398
+ for m in mel
399
+ ]
400
+ else:
401
+ mel = mel.transpose(1, 2)
402
+ outputs = self.model_runner_cpp.generate(
403
+ batch_input_ids=decoder_input_ids,
404
+ encoder_input_features=mel,
405
+ encoder_output_lengths=mel_input_lengths // 2,
406
+ max_new_tokens=max_new_tokens,
407
+ end_id=self.tokenizer.eot,
408
+ pad_id=self.tokenizer.eot,
409
+ num_beams=num_beams,
410
+ output_sequence_lengths=True,
411
+ return_dict=True)
412
+ torch.cuda.synchronize()
413
+ output_ids = outputs['output_ids'].cpu().numpy().tolist()
414
+ texts = []
415
+ for i in range(len(output_ids)):
416
+ text = self.tokenizer.decode(output_ids[i][0]).strip()
417
+ texts.append(text)
418
+ return texts
419
+
420
+ def transcribe(
421
+ self,
422
+ mel,
423
+ text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
424
+ dtype='float16',
425
+ batch_size=1,
426
+ num_beams=1,
427
+ padding_strategy="max",
428
+ max_new_tokens=96,
429
+ ):
430
+ mel = mel.type(str_dtype_to_torch(dtype))
431
+ mel = mel.unsqueeze(0)
432
+ # repeat the mel spectrogram to match the batch size
433
+ mel = mel.repeat(batch_size, 1, 1)
434
+ if padding_strategy == "longest":
435
+ pass
436
+ else:
437
+ mel = torch.nn.functional.pad(mel, (0, 3000 - mel.shape[2]))
438
+ features_input_lengths = torch.full((mel.shape[0], ),
439
+ mel.shape[2],
440
+ dtype=torch.int32,
441
+ device=mel.device)
442
+
443
+ predictions = self.process_batch(
444
+ mel,
445
+ features_input_lengths,
446
+ text_prefix,
447
+ num_beams,
448
+ max_new_tokens=max_new_tokens
449
+ )
450
+ prediction = predictions[0]
451
+
452
+ # remove all special tokens in the prediction
453
+ prediction = re.sub(r'<\|.*?\|>', '', prediction)
454
+ return prediction.strip()
455
+
456
+
457
+ def decode_wav_file(
458
+ model,
459
+ mel,
460
+ text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
461
+ dtype='float16',
462
+ batch_size=1,
463
+ num_beams=1,
464
+ normalizer=None,
465
+ mel_filters_dir=None):
466
+
467
+ mel = mel.type(str_dtype_to_torch(dtype))
468
+ mel = mel.unsqueeze(0)
469
+ # repeat the mel spectrogram to match the batch size
470
+ mel = mel.repeat(batch_size, 1, 1)
471
+ predictions = model.process_batch(mel, text_prefix, num_beams)
472
+ prediction = predictions[0]
473
+
474
+ # remove all special tokens in the prediction
475
+ prediction = re.sub(r'<\|.*?\|>', '', prediction)
476
+ if normalizer:
477
+ prediction = normalizer(prediction)
478
+
479
+ return prediction.strip()
whisper_live/utils.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import textwrap
3
+ import scipy
4
+ import numpy as np
5
+ import av
6
+ from pathlib import Path
7
+
8
+
9
+ def clear_screen():
10
+ """Clears the console screen."""
11
+ os.system("cls" if os.name == "nt" else "clear")
12
+
13
+
14
+ def print_transcript(text):
15
+ """Prints formatted transcript text."""
16
+ wrapper = textwrap.TextWrapper(width=60)
17
+ for line in wrapper.wrap(text="".join(text)):
18
+ print(line)
19
+
20
+
21
+ def format_time(s):
22
+ """Convert seconds (float) to SRT time format."""
23
+ hours = int(s // 3600)
24
+ minutes = int((s % 3600) // 60)
25
+ seconds = int(s % 60)
26
+ milliseconds = int((s - int(s)) * 1000)
27
+ return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"
28
+
29
+
30
+ def create_srt_file(segments, resampled_file):
31
+ with open(resampled_file, 'w', encoding='utf-8') as srt_file:
32
+ segment_number = 1
33
+ for segment in segments:
34
+ start_time = format_time(float(segment['start']))
35
+ end_time = format_time(float(segment['end']))
36
+ text = segment['text']
37
+
38
+ srt_file.write(f"{segment_number}\n")
39
+ srt_file.write(f"{start_time} --> {end_time}\n")
40
+ srt_file.write(f"{text}\n\n")
41
+
42
+ segment_number += 1
43
+
44
+
45
+ def resample(file: str, sr: int = 16000):
46
+ """
47
+ Resample the audio file to 16kHz.
48
+
49
+ Args:
50
+ file (str): The audio file to open
51
+ sr (int): The sample rate to resample the audio if necessary
52
+
53
+ Returns:
54
+ resampled_file (str): The resampled audio file
55
+ """
56
+ container = av.open(file)
57
+ stream = next(s for s in container.streams if s.type == 'audio')
58
+
59
+ resampler = av.AudioResampler(
60
+ format='s16',
61
+ layout='mono',
62
+ rate=sr,
63
+ )
64
+
65
+ resampled_file = Path(file).stem + "_resampled.wav"
66
+ output_container = av.open(resampled_file, mode='w')
67
+ output_stream = output_container.add_stream('pcm_s16le', rate=sr)
68
+ output_stream.layout = 'mono'
69
+
70
+ for frame in container.decode(audio=0):
71
+ frame.pts = None
72
+ resampled_frames = resampler.resample(frame)
73
+ if resampled_frames is not None:
74
+ for resampled_frame in resampled_frames:
75
+ for packet in output_stream.encode(resampled_frame):
76
+ output_container.mux(packet)
77
+
78
+ for packet in output_stream.encode(None):
79
+ output_container.mux(packet)
80
+
81
+ output_container.close()
82
+ return resampled_file
whisper_live/vad.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import torch
4
+ import numpy as np
5
+ import onnxruntime
6
+ import warnings
7
+
8
+
9
+ class VoiceActivityDetection():
10
+
11
+ def __init__(self, force_onnx_cpu=True):
12
+ path = self.download()
13
+
14
+ opts = onnxruntime.SessionOptions()
15
+ opts.log_severity_level = 3
16
+
17
+ opts.inter_op_num_threads = 1
18
+ opts.intra_op_num_threads = 1
19
+
20
+ if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
21
+ self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
22
+ else:
23
+ self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts)
24
+
25
+ self.reset_states()
26
+ if '16k' in path:
27
+ warnings.warn('This model support only 16000 sampling rate!')
28
+ self.sample_rates = [16000]
29
+ else:
30
+ self.sample_rates = [8000, 16000]
31
+
32
+ def _validate_input(self, x, sr: int):
33
+ if x.dim() == 1:
34
+ x = x.unsqueeze(0)
35
+ if x.dim() > 2:
36
+ raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
37
+
38
+ if sr != 16000 and (sr % 16000 == 0):
39
+ step = sr // 16000
40
+ x = x[:,::step]
41
+ sr = 16000
42
+
43
+ if sr not in self.sample_rates:
44
+ raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
45
+ if sr / x.shape[1] > 31.25:
46
+ raise ValueError("Input audio chunk is too short")
47
+
48
+ return x, sr
49
+
50
+ def reset_states(self, batch_size=1):
51
+ self._state = torch.zeros((2, batch_size, 128)).float()
52
+ self._context = torch.zeros(0)
53
+ self._last_sr = 0
54
+ self._last_batch_size = 0
55
+
56
+ def __call__(self, x, sr: int):
57
+
58
+ x, sr = self._validate_input(x, sr)
59
+ num_samples = 512 if sr == 16000 else 256
60
+
61
+ if x.shape[-1] != num_samples:
62
+ raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
63
+
64
+ batch_size = x.shape[0]
65
+ context_size = 64 if sr == 16000 else 32
66
+
67
+ if not self._last_batch_size:
68
+ self.reset_states(batch_size)
69
+ if (self._last_sr) and (self._last_sr != sr):
70
+ self.reset_states(batch_size)
71
+ if (self._last_batch_size) and (self._last_batch_size != batch_size):
72
+ self.reset_states(batch_size)
73
+
74
+ if not len(self._context):
75
+ self._context = torch.zeros(batch_size, context_size)
76
+
77
+ x = torch.cat([self._context, x], dim=1)
78
+ if sr in [8000, 16000]:
79
+ ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
80
+ ort_outs = self.session.run(None, ort_inputs)
81
+ out, state = ort_outs
82
+ self._state = torch.from_numpy(state)
83
+ else:
84
+ raise ValueError()
85
+
86
+ self._context = x[..., -context_size:]
87
+ self._last_sr = sr
88
+ self._last_batch_size = batch_size
89
+
90
+ out = torch.from_numpy(out)
91
+ return out
92
+
93
+ def audio_forward(self, x, sr: int):
94
+ outs = []
95
+ x, sr = self._validate_input(x, sr)
96
+ self.reset_states()
97
+ num_samples = 512 if sr == 16000 else 256
98
+
99
+ if x.shape[1] % num_samples:
100
+ pad_num = num_samples - (x.shape[1] % num_samples)
101
+ x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
102
+
103
+ for i in range(0, x.shape[1], num_samples):
104
+ wavs_batch = x[:, i:i+num_samples]
105
+ out_chunk = self.__call__(wavs_batch, sr)
106
+ outs.append(out_chunk)
107
+
108
+ stacked = torch.cat(outs, dim=1)
109
+ return stacked.cpu()
110
+
111
+ @staticmethod
112
+ def download(model_url="https://github.com/snakers4/silero-vad/raw/v5.0/files/silero_vad.onnx"):
113
+ target_dir = os.path.expanduser("~/.cache/whisper-live/")
114
+
115
+ # Ensure the target directory exists
116
+ os.makedirs(target_dir, exist_ok=True)
117
+
118
+ # Define the target file path
119
+ model_filename = os.path.join(target_dir, "silero_vad.onnx")
120
+
121
+ # Check if the model file already exists
122
+ if not os.path.exists(model_filename):
123
+ # If it doesn't exist, download the model using wget
124
+ try:
125
+ subprocess.run(["wget", "-O", model_filename, model_url], check=True)
126
+ except subprocess.CalledProcessError:
127
+ print("Failed to download the model using wget.")
128
+ return model_filename
129
+
130
+
131
+ class VoiceActivityDetector:
132
+ def __init__(self, threshold=0.5, frame_rate=16000):
133
+ """
134
+ Initializes the VoiceActivityDetector with a voice activity detection model and a threshold.
135
+
136
+ Args:
137
+ threshold (float, optional): The probability threshold for detecting voice activity. Defaults to 0.5.
138
+ """
139
+ self.model = VoiceActivityDetection()
140
+ self.threshold = threshold
141
+ self.frame_rate = frame_rate
142
+
143
+ def __call__(self, audio_frame):
144
+ """
145
+ Determines if the given audio frame contains speech by comparing the detected speech probability against
146
+ the threshold.
147
+
148
+ Args:
149
+ audio_frame (np.ndarray): The audio frame to be analyzed for voice activity. It is expected to be a
150
+ NumPy array of audio samples.
151
+
152
+ Returns:
153
+ bool: True if the speech probability exceeds the threshold, indicating the presence of voice activity;
154
+ False otherwise.
155
+ """
156
+ speech_probs = self.model.audio_forward(torch.from_numpy(audio_frame.copy()), self.frame_rate)[0]
157
+ return torch.any(speech_probs > self.threshold).item()