nuernie
commited on
Commit
·
7222c68
1
Parent(s):
d17f6c7
initial commit
Browse files- .gitignore +127 -0
- Dockerfile +17 -0
- LICENSE +21 -0
- README.md +209 -11
- TensorRT_whisper.md +47 -0
- app.py +23 -0
- docker/Dockerfile.cpu +25 -0
- docker/Dockerfile.gpu +26 -0
- docker/Dockerfile.openvino +19 -0
- docker/Dockerfile.tensorrt +30 -0
- requirements.txt +7 -0
- requirements/client.txt +4 -0
- requirements/server.txt +21 -0
- run_server.py +54 -0
- scripts/build_whisper_tensorrt.sh +120 -0
- scripts/setup.sh +3 -0
- setup.py +67 -0
- tests/__init__.py +0 -0
- tests/test_client.py +162 -0
- tests/test_server.py +148 -0
- tests/test_vad.py +26 -0
- whisper_live/__init__.py +0 -0
- whisper_live/__version__.py +1 -0
- whisper_live/backend/__init__.py +0 -0
- whisper_live/backend/base.py +361 -0
- whisper_live/backend/faster_whisper_backend.py +216 -0
- whisper_live/backend/openvino_backend.py +148 -0
- whisper_live/backend/trt_backend.py +210 -0
- whisper_live/client.py +782 -0
- whisper_live/server.py +446 -0
- whisper_live/transcriber/__init__.py +0 -0
- whisper_live/transcriber/tensorrt_utils.py +364 -0
- whisper_live/transcriber/transcriber_faster_whisper.py +1889 -0
- whisper_live/transcriber/transcriber_openvino.py +23 -0
- whisper_live/transcriber/transcriber_tensorrt.py +479 -0
- whisper_live/utils.py +82 -0
- whisper_live/vad.py +157 -0
.gitignore
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Editors
|
2 |
+
.vscode/
|
3 |
+
.idea/
|
4 |
+
|
5 |
+
# Vagrant
|
6 |
+
.vagrant/
|
7 |
+
|
8 |
+
# Mac/OSX
|
9 |
+
.DS_Store
|
10 |
+
|
11 |
+
# Windows
|
12 |
+
Thumbs.db
|
13 |
+
|
14 |
+
# Source for the following rules: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore
|
15 |
+
# Byte-compiled / optimized / DLL files
|
16 |
+
__pycache__/
|
17 |
+
*.py[cod]
|
18 |
+
*$py.class
|
19 |
+
|
20 |
+
# C extensions
|
21 |
+
*.so
|
22 |
+
|
23 |
+
# Distribution / packaging
|
24 |
+
.Python
|
25 |
+
build/
|
26 |
+
develop-eggs/
|
27 |
+
dist/
|
28 |
+
downloads/
|
29 |
+
eggs/
|
30 |
+
.eggs/
|
31 |
+
lib/
|
32 |
+
lib64/
|
33 |
+
parts/
|
34 |
+
sdist/
|
35 |
+
var/
|
36 |
+
wheels/
|
37 |
+
*.egg-info/
|
38 |
+
.installed.cfg
|
39 |
+
*.egg
|
40 |
+
MANIFEST
|
41 |
+
|
42 |
+
# PyInstaller
|
43 |
+
# Usually these files are written by a python script from a template
|
44 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
45 |
+
*.manifest
|
46 |
+
*.spec
|
47 |
+
|
48 |
+
# Installer logs
|
49 |
+
pip-log.txt
|
50 |
+
pip-delete-this-directory.txt
|
51 |
+
|
52 |
+
# Unit test / coverage reports
|
53 |
+
htmlcov/
|
54 |
+
.tox/
|
55 |
+
.nox/
|
56 |
+
.coverage
|
57 |
+
.coverage.*
|
58 |
+
.cache
|
59 |
+
nosetests.xml
|
60 |
+
coverage.xml
|
61 |
+
*.cover
|
62 |
+
.hypothesis/
|
63 |
+
.pytest_cache/
|
64 |
+
|
65 |
+
# Translations
|
66 |
+
*.mo
|
67 |
+
*.pot
|
68 |
+
|
69 |
+
# Django stuff:
|
70 |
+
*.log
|
71 |
+
local_settings.py
|
72 |
+
db.sqlite3
|
73 |
+
|
74 |
+
# Flask stuff:
|
75 |
+
instance/
|
76 |
+
.webassets-cache
|
77 |
+
|
78 |
+
# Scrapy stuff:
|
79 |
+
.scrapy
|
80 |
+
|
81 |
+
# Sphinx documentation
|
82 |
+
docs/_build/
|
83 |
+
|
84 |
+
# PyBuilder
|
85 |
+
target/
|
86 |
+
|
87 |
+
# Jupyter Notebook
|
88 |
+
.ipynb_checkpoints
|
89 |
+
|
90 |
+
# IPython
|
91 |
+
profile_default/
|
92 |
+
ipython_config.py
|
93 |
+
|
94 |
+
# pyenv
|
95 |
+
.python-version
|
96 |
+
|
97 |
+
# celery beat schedule file
|
98 |
+
celerybeat-schedule
|
99 |
+
|
100 |
+
# SageMath parsed files
|
101 |
+
*.sage.py
|
102 |
+
|
103 |
+
# Environments
|
104 |
+
.env
|
105 |
+
.venv
|
106 |
+
env/
|
107 |
+
venv/
|
108 |
+
ENV/
|
109 |
+
env.bak/
|
110 |
+
venv.bak/
|
111 |
+
|
112 |
+
docs/
|
113 |
+
|
114 |
+
# Spyder project settings
|
115 |
+
.spyderproject
|
116 |
+
.spyproject
|
117 |
+
|
118 |
+
# Rope project settings
|
119 |
+
.ropeproject
|
120 |
+
|
121 |
+
# mkdocs documentation
|
122 |
+
/site
|
123 |
+
|
124 |
+
# mypy
|
125 |
+
.mypy_cache/
|
126 |
+
.dmypy.json
|
127 |
+
dmypy.json
|
Dockerfile
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10-slim
|
2 |
+
|
3 |
+
WORKDIR /app
|
4 |
+
|
5 |
+
RUN apt-get update && \
|
6 |
+
apt-get install -y portaudio19-dev python3-dev gcc && \
|
7 |
+
apt-get clean && \
|
8 |
+
rm -rf /var/lib/apt/lists/*
|
9 |
+
|
10 |
+
COPY requirements.txt .
|
11 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
12 |
+
|
13 |
+
COPY . .
|
14 |
+
|
15 |
+
EXPOSE 7860
|
16 |
+
|
17 |
+
CMD ["python", "app.py"]
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Vineet Suryan, Collabora Ltd.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,11 +1,209 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# WhisperLive
|
2 |
+
|
3 |
+
<h2 align="center">
|
4 |
+
<a href="https://www.youtube.com/watch?v=0PHWCApIcCI"><img
|
5 |
+
src="https://img.youtube.com/vi/0PHWCApIcCI/0.jpg" style="background-color:rgba(0,0,0,0);" height=300 alt="WhisperLive"></a>
|
6 |
+
<br><br>A nearly-live implementation of OpenAI's Whisper.
|
7 |
+
<br><br>
|
8 |
+
</h2>
|
9 |
+
|
10 |
+
This project is a real-time transcription application that uses the OpenAI Whisper model
|
11 |
+
to convert speech input into text output. It can be used to transcribe both live audio
|
12 |
+
input from microphone and pre-recorded audio files.
|
13 |
+
|
14 |
+
- [Installation](#installation)
|
15 |
+
- [Getting Started](#getting-started)
|
16 |
+
- [Running the Server](#running-the-server)
|
17 |
+
- [Running the Client](#running-the-client)
|
18 |
+
- [Browser Extensions](#browser-extensions)
|
19 |
+
- [Whisper Live Server in Docker](#whisper-live-server-in-docker)
|
20 |
+
- [Future Work](#future-work)
|
21 |
+
- [Contact](#contact)
|
22 |
+
- [Citations](#citations)
|
23 |
+
|
24 |
+
## Installation
|
25 |
+
- Install PyAudio
|
26 |
+
```bash
|
27 |
+
bash scripts/setup.sh
|
28 |
+
```
|
29 |
+
|
30 |
+
- Install whisper-live from pip
|
31 |
+
```bash
|
32 |
+
pip install whisper-live
|
33 |
+
```
|
34 |
+
|
35 |
+
### Setting up NVIDIA/TensorRT-LLM for TensorRT backend
|
36 |
+
- Please follow [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) for setup of [NVIDIA/TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) and for building Whisper-TensorRT engine.
|
37 |
+
|
38 |
+
## Getting Started
|
39 |
+
The server supports 3 backends `faster_whisper`, `tensorrt` and `openvino`. If running `tensorrt` backend follow [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md)
|
40 |
+
|
41 |
+
### Running the Server
|
42 |
+
- [Faster Whisper](https://github.com/SYSTRAN/faster-whisper) backend
|
43 |
+
```bash
|
44 |
+
python3 run_server.py --port 9090 \
|
45 |
+
--backend faster_whisper
|
46 |
+
|
47 |
+
# running with custom model
|
48 |
+
python3 run_server.py --port 9090 \
|
49 |
+
--backend faster_whisper \
|
50 |
+
-fw "/path/to/custom/faster/whisper/model"
|
51 |
+
```
|
52 |
+
|
53 |
+
- TensorRT backend. Currently, we recommend to only use the docker setup for TensorRT. Follow [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) which works as expected. Make sure to build your TensorRT Engines before running the server with TensorRT backend.
|
54 |
+
```bash
|
55 |
+
# Run English only model
|
56 |
+
python3 run_server.py -p 9090 \
|
57 |
+
-b tensorrt \
|
58 |
+
-trt /home/TensorRT-LLM/examples/whisper/whisper_small_en
|
59 |
+
|
60 |
+
# Run Multilingual model
|
61 |
+
python3 run_server.py -p 9090 \
|
62 |
+
-b tensorrt \
|
63 |
+
-trt /home/TensorRT-LLM/examples/whisper/whisper_small \
|
64 |
+
-m
|
65 |
+
```
|
66 |
+
|
67 |
+
- WhisperLive now supports the [OpenVINO](https://github.com/openvinotoolkit/openvino) backend for efficient inference on Intel CPUs, iGPU and dGPUs. Currently, we tested the models uploaded to [huggingface by OpenVINO](https://huggingface.co/OpenVINO?search_models=whisper).
|
68 |
+
- > **Docker Recommended:** Running WhisperLive with OpenVINO inside Docker automatically enables GPU support (iGPU/dGPU) without requiring additional host setup.
|
69 |
+
- > **Native (non-Docker) Use:** If you prefer running outside Docker, ensure the Intel drivers and OpenVINO runtime are installed and properly configured on your system. Refer to the documentation for [installing OpenVINO](https://docs.openvino.ai/2025/get-started/install-openvino.html?PACKAGE=OPENVINO_BASE&VERSION=v_2025_0_0&OP_SYSTEM=LINUX&DISTRIBUTION=PIP#).
|
70 |
+
|
71 |
+
```
|
72 |
+
python3 run_server.py -p 9090 -b openvino
|
73 |
+
```
|
74 |
+
|
75 |
+
|
76 |
+
#### Controlling OpenMP Threads
|
77 |
+
To control the number of threads used by OpenMP, you can set the `OMP_NUM_THREADS` environment variable. This is useful for managing CPU resources and ensuring consistent performance. If not specified, `OMP_NUM_THREADS` is set to `1` by default. You can change this by using the `--omp_num_threads` argument:
|
78 |
+
```bash
|
79 |
+
python3 run_server.py --port 9090 \
|
80 |
+
--backend faster_whisper \
|
81 |
+
--omp_num_threads 4
|
82 |
+
```
|
83 |
+
|
84 |
+
#### Single model mode
|
85 |
+
By default, when running the server without specifying a model, the server will instantiate a new whisper model for every client connection. This has the advantage, that the server can use different model sizes, based on the client's requested model size. On the other hand, it also means you have to wait for the model to be loaded upon client connection and you will have increased (V)RAM usage.
|
86 |
+
|
87 |
+
When serving a custom TensorRT model using the `-trt` or a custom faster_whisper model using the `-fw` option, the server will instead only instantiate the custom model once and then reuse it for all client connections.
|
88 |
+
|
89 |
+
If you don't want this, set `--no_single_model`.
|
90 |
+
|
91 |
+
|
92 |
+
### Running the Client
|
93 |
+
- Initializing the client with below parameters:
|
94 |
+
- `lang`: Language of the input audio, applicable only if using a multilingual model.
|
95 |
+
- `translate`: If set to `True` then translate from any language to `en`.
|
96 |
+
- `model`: Whisper model size.
|
97 |
+
- `use_vad`: Whether to use `Voice Activity Detection` on the server.
|
98 |
+
- `save_output_recording`: Set to True to save the microphone input as a `.wav` file during live transcription. This option is helpful for recording sessions for later playback or analysis. Defaults to `False`.
|
99 |
+
- `output_recording_filename`: Specifies the `.wav` file path where the microphone input will be saved if `save_output_recording` is set to `True`.
|
100 |
+
- `max_clients`: Specifies the maximum number of clients the server should allow. Defaults to 4.
|
101 |
+
- `max_connection_time`: Maximum connection time for each client in seconds. Defaults to 600.
|
102 |
+
- `mute_audio_playback`: Whether to mute audio playback when transcribing an audio file. Defaults to False.
|
103 |
+
|
104 |
+
```python
|
105 |
+
from whisper_live.client import TranscriptionClient
|
106 |
+
client = TranscriptionClient(
|
107 |
+
"localhost",
|
108 |
+
9090,
|
109 |
+
lang="en",
|
110 |
+
translate=False,
|
111 |
+
model="small", # also support hf_model => `Systran/faster-whisper-small`
|
112 |
+
use_vad=False,
|
113 |
+
save_output_recording=True, # Only used for microphone input, False by Default
|
114 |
+
output_recording_filename="./output_recording.wav", # Only used for microphone input
|
115 |
+
max_clients=4,
|
116 |
+
max_connection_time=600,
|
117 |
+
mute_audio_playback=False, # Only used for file input, False by Default
|
118 |
+
)
|
119 |
+
```
|
120 |
+
It connects to the server running on localhost at port 9090. Using a multilingual model, language for the transcription will be automatically detected. You can also use the language option to specify the target language for the transcription, in this case, English ("en"). The translate option should be set to `True` if we want to translate from the source language to English and `False` if we want to transcribe in the source language.
|
121 |
+
|
122 |
+
- Transcribe an audio file:
|
123 |
+
```python
|
124 |
+
client("tests/jfk.wav")
|
125 |
+
```
|
126 |
+
|
127 |
+
- To transcribe from microphone:
|
128 |
+
```python
|
129 |
+
client()
|
130 |
+
```
|
131 |
+
|
132 |
+
- To transcribe from a RTSP stream:
|
133 |
+
```python
|
134 |
+
client(rtsp_url="rtsp://admin:[email protected]/rtsp")
|
135 |
+
```
|
136 |
+
|
137 |
+
- To transcribe from a HLS stream:
|
138 |
+
```python
|
139 |
+
client(hls_url="http://as-hls-ww-live.akamaized.net/pool_904/live/ww/bbc_1xtra/bbc_1xtra.isml/bbc_1xtra-audio%3d96000.norewind.m3u8")
|
140 |
+
```
|
141 |
+
|
142 |
+
## Browser Extensions
|
143 |
+
- Run the server with your desired backend as shown [here](https://github.com/collabora/WhisperLive?tab=readme-ov-file#running-the-server).
|
144 |
+
- Transcribe audio directly from your browser using our Chrome or Firefox extensions. Refer to [Audio-Transcription-Chrome](https://github.com/collabora/whisper-live/tree/main/Audio-Transcription-Chrome#readme) and https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md
|
145 |
+
|
146 |
+
## Whisper Live Server in Docker
|
147 |
+
- GPU
|
148 |
+
- Faster-Whisper
|
149 |
+
```bash
|
150 |
+
docker run -it --gpus all -p 9090:9090 ghcr.io/collabora/whisperlive-gpu:latest
|
151 |
+
```
|
152 |
+
|
153 |
+
- TensorRT. Refer to [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) for setup and more tensorrt backend configurations.
|
154 |
+
```bash
|
155 |
+
docker build . -f docker/Dockerfile.tensorrt -t whisperlive-tensorrt
|
156 |
+
docker run -p 9090:9090 --runtime=nvidia --gpus all --entrypoint /bin/bash -it whisperlive-tensorrt
|
157 |
+
|
158 |
+
# Build small.en engine
|
159 |
+
bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en # float16
|
160 |
+
bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en int8 # int8 weight only quantization
|
161 |
+
bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en int4 # int4 weight only quantization
|
162 |
+
|
163 |
+
# Run server with small.en
|
164 |
+
python3 run_server.py --port 9090 \
|
165 |
+
--backend tensorrt \
|
166 |
+
--trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_en_float16"
|
167 |
+
--trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_en_int8"
|
168 |
+
--trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_en_int4"
|
169 |
+
```
|
170 |
+
|
171 |
+
- OpenVINO
|
172 |
+
```
|
173 |
+
docker run -it --device=/dev/dri -p 9090:9090 ghcr.io/collabora/whisperlive-openvino
|
174 |
+
```
|
175 |
+
|
176 |
+
- CPU
|
177 |
+
- Faster-whisper
|
178 |
+
```bash
|
179 |
+
docker run -it -p 9090:9090 ghcr.io/collabora/whisperlive-cpu:latest
|
180 |
+
```
|
181 |
+
|
182 |
+
## Future Work
|
183 |
+
- [ ] Add translation to other languages on top of transcription.
|
184 |
+
|
185 |
+
## Contact
|
186 |
+
|
187 |
+
We are available to help you with both Open Source and proprietary AI projects. You can reach us via the Collabora website or [[email protected]](mailto:[email protected]) and [[email protected]](mailto:[email protected]).
|
188 |
+
|
189 |
+
## Citations
|
190 |
+
```bibtex
|
191 |
+
@article{Whisper
|
192 |
+
title = {Robust Speech Recognition via Large-Scale Weak Supervision},
|
193 |
+
url = {https://arxiv.org/abs/2212.04356},
|
194 |
+
author = {Radford, Alec and Kim, Jong Wook and Xu, Tao and Brockman, Greg and McLeavey, Christine and Sutskever, Ilya},
|
195 |
+
publisher = {arXiv},
|
196 |
+
year = {2022},
|
197 |
+
}
|
198 |
+
```
|
199 |
+
|
200 |
+
```bibtex
|
201 |
+
@misc{Silero VAD,
|
202 |
+
author = {Silero Team},
|
203 |
+
title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
|
204 |
+
year = {2021},
|
205 |
+
publisher = {GitHub},
|
206 |
+
journal = {GitHub repository},
|
207 |
+
howpublished = {\url{https://github.com/snakers4/silero-vad}},
|
208 |
+
email = {[email protected]}
|
209 |
+
}
|
TensorRT_whisper.md
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# WhisperLive-TensorRT
|
2 |
+
We have only tested the TensorRT backend in docker so, we recommend docker for a smooth TensorRT backend setup.
|
3 |
+
**Note**: We use `tensorrt_llm==0.18.2`
|
4 |
+
|
5 |
+
## Installation
|
6 |
+
- Install [docker](https://docs.docker.com/engine/install/)
|
7 |
+
- Install [nvidia-container-toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
|
8 |
+
|
9 |
+
- Run WhisperLive TensorRT in docker
|
10 |
+
```bash
|
11 |
+
docker build . -f docker/Dockerfile.tensorrt -t whisperlive-tensorrt
|
12 |
+
docker run -p 9090:9090 --runtime=nvidia --gpus all --entrypoint /bin/bash -it whisperlive-tensorrt
|
13 |
+
```
|
14 |
+
|
15 |
+
## Whisper TensorRT Engine
|
16 |
+
- We build `small.en` and `small` multilingual TensorRT engine as examples below. The script logs the path of the directory with Whisper TensorRT engine. We need that model_path to run the server.
|
17 |
+
```bash
|
18 |
+
# convert small.en
|
19 |
+
bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en # float16
|
20 |
+
bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en int8 # int8 weight only quantization
|
21 |
+
bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en int4 # int4 weight only quantization
|
22 |
+
|
23 |
+
# convert small multilingual model
|
24 |
+
bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small
|
25 |
+
```
|
26 |
+
|
27 |
+
## Run WhisperLive Server with TensorRT Backend
|
28 |
+
```bash
|
29 |
+
# Run English only model
|
30 |
+
python3 run_server.py --port 9090 \
|
31 |
+
--backend tensorrt \
|
32 |
+
--trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_en_float16"
|
33 |
+
|
34 |
+
# Run Multilingual model
|
35 |
+
python3 run_server.py --port 9090 \
|
36 |
+
--backend tensorrt \
|
37 |
+
--trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_float16" \
|
38 |
+
--trt_multilingual
|
39 |
+
```
|
40 |
+
|
41 |
+
By default trt_backend uses cpp_session, to use python session pass `--trt_py_session` to run_server.py
|
42 |
+
```bash
|
43 |
+
python3 run_server.py --port 9090 \
|
44 |
+
--backend tensorrt \
|
45 |
+
--trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_float16" \
|
46 |
+
--trt_py_session
|
47 |
+
```
|
app.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI
|
2 |
+
import uvicorn
|
3 |
+
from whisper_live.server import TranscriptionServer
|
4 |
+
|
5 |
+
app = FastAPI(title="Whisper Live Server")
|
6 |
+
|
7 |
+
@app.on_event("startup")
|
8 |
+
async def startup_event():
|
9 |
+
# Start the transcription server in the background
|
10 |
+
server = TranscriptionServer()
|
11 |
+
server.run(
|
12 |
+
host="0.0.0.0",
|
13 |
+
port=7860, # Hugging Face Spaces uses port 7860
|
14 |
+
backend="faster_whisper", # Using faster_whisper as the backend
|
15 |
+
single_model=True # Use single model mode for better resource usage
|
16 |
+
)
|
17 |
+
|
18 |
+
@app.get("/health")
|
19 |
+
def health_check():
|
20 |
+
return {"status": "healthy"}
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
docker/Dockerfile.cpu
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10-bookworm
|
2 |
+
|
3 |
+
ARG DEBIAN_FRONTEND=noninteractive
|
4 |
+
|
5 |
+
# install lib required for pyaudio
|
6 |
+
RUN apt update && apt install -y portaudio19-dev && apt-get clean && rm -rf /var/lib/apt/lists/*
|
7 |
+
|
8 |
+
# update pip to support for whl.metadata -> less downloading
|
9 |
+
RUN pip install --no-cache-dir -U "pip>=24"
|
10 |
+
|
11 |
+
# create a working directory
|
12 |
+
RUN mkdir /app
|
13 |
+
WORKDIR /app
|
14 |
+
|
15 |
+
# install pytorch, but without the nvidia-libs that are only necessary for gpu
|
16 |
+
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
|
17 |
+
|
18 |
+
# install the requirements for running the whisper-live server
|
19 |
+
COPY requirements/server.txt /app/
|
20 |
+
RUN pip install --no-cache-dir -r server.txt && rm server.txt
|
21 |
+
|
22 |
+
COPY whisper_live /app/whisper_live
|
23 |
+
COPY run_server.py /app
|
24 |
+
|
25 |
+
CMD ["python", "run_server.py"]
|
docker/Dockerfile.gpu
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10-bookworm
|
2 |
+
|
3 |
+
ARG DEBIAN_FRONTEND=noninteractive
|
4 |
+
|
5 |
+
# install lib required for pyaudio
|
6 |
+
RUN apt update && apt install -y portaudio19-dev && apt-get clean && rm -rf /var/lib/apt/lists/*
|
7 |
+
|
8 |
+
# update pip to support for whl.metadata -> less downloading
|
9 |
+
RUN pip install --no-cache-dir -U "pip>=24"
|
10 |
+
|
11 |
+
# create a working directory
|
12 |
+
RUN mkdir /app
|
13 |
+
WORKDIR /app
|
14 |
+
|
15 |
+
# install the requirements for running the whisper-live server
|
16 |
+
COPY requirements/server.txt /app/
|
17 |
+
RUN pip install --no-cache-dir -r server.txt && rm server.txt
|
18 |
+
|
19 |
+
# make the paths of the nvidia libs installed as wheels visible. equivalent to:
|
20 |
+
# export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'`
|
21 |
+
ENV LD_LIBRARY_PATH="/usr/local/lib/python3.10/site-packages/nvidia/cublas/lib:/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib"
|
22 |
+
|
23 |
+
COPY whisper_live /app/whisper_live
|
24 |
+
COPY run_server.py /app
|
25 |
+
|
26 |
+
CMD ["python", "run_server.py"]
|
docker/Dockerfile.openvino
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM openvino/ubuntu22_runtime:latest
|
2 |
+
|
3 |
+
ARG DEBIAN_FRONTEND=noninteractive
|
4 |
+
|
5 |
+
USER root
|
6 |
+
|
7 |
+
RUN apt update && apt install -y portaudio19-dev python-is-python3 && apt-get clean && rm -rf /var/lib/apt/lists/*
|
8 |
+
|
9 |
+
RUN pip install --no-cache-dir -U "pip>=24"
|
10 |
+
|
11 |
+
RUN mkdir /app
|
12 |
+
WORKDIR /app
|
13 |
+
|
14 |
+
COPY requirements/server.txt /app/
|
15 |
+
RUN pip install --no-cache-dir -r server.txt && rm server.txt
|
16 |
+
|
17 |
+
COPY whisper_live /app/whisper_live
|
18 |
+
COPY run_server.py /app
|
19 |
+
CMD ["python", "run_server.py", "--backend", "openvino"]
|
docker/Dockerfile.tensorrt
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:12.8.1-base-ubuntu22.04 AS base
|
2 |
+
|
3 |
+
ARG DEBIAN_FRONTEND=noninteractive
|
4 |
+
|
5 |
+
RUN apt-get update && apt-get install -y \
|
6 |
+
python3.10 python3-pip openmpi-bin libopenmpi-dev git git-lfs wget \
|
7 |
+
&& apt install python-is-python3 \
|
8 |
+
&& pip install --upgrade pip setuptools \
|
9 |
+
&& rm -rf /var/lib/apt/lists/*
|
10 |
+
|
11 |
+
FROM base AS devel
|
12 |
+
RUN pip install --no-cache-dir -U tensorrt_llm==0.18.2 --extra-index-url https://pypi.nvidia.com
|
13 |
+
WORKDIR /app
|
14 |
+
RUN git clone -b v0.18.2 https://github.com/NVIDIA/TensorRT-LLM.git \
|
15 |
+
&& mv TensorRT-LLM/examples ./TensorRT-LLM-examples \
|
16 |
+
&& rm -rf TensorRT-LLM
|
17 |
+
|
18 |
+
FROM devel AS release
|
19 |
+
WORKDIR /app
|
20 |
+
COPY assets/ ./assets
|
21 |
+
RUN wget -nc -P assets/ https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz
|
22 |
+
|
23 |
+
COPY scripts/setup.sh ./
|
24 |
+
RUN apt update && bash setup.sh && rm setup.sh
|
25 |
+
|
26 |
+
COPY requirements/server.txt .
|
27 |
+
RUN pip install --no-cache-dir -r server.txt && rm server.txt
|
28 |
+
COPY whisper_live ./whisper_live
|
29 |
+
COPY scripts/build_whisper_tensorrt.sh .
|
30 |
+
COPY run_server.py .
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
faster-whisper
|
2 |
+
numpy
|
3 |
+
websockets
|
4 |
+
pyaudio
|
5 |
+
soundfile
|
6 |
+
torch
|
7 |
+
torchaudio
|
requirements/client.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PyAudio
|
2 |
+
av
|
3 |
+
scipy
|
4 |
+
websocket-client
|
requirements/server.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
faster-whisper==1.1.0
|
2 |
+
websockets
|
3 |
+
onnxruntime==1.17.0
|
4 |
+
numba
|
5 |
+
kaldialign
|
6 |
+
soundfile
|
7 |
+
scipy
|
8 |
+
av
|
9 |
+
jiwer
|
10 |
+
evaluate
|
11 |
+
numpy<2
|
12 |
+
openai-whisper==20240930
|
13 |
+
tokenizers==0.20.3
|
14 |
+
|
15 |
+
# openvino
|
16 |
+
librosa
|
17 |
+
openvino
|
18 |
+
openvino-genai
|
19 |
+
openvino-tokenizers
|
20 |
+
optimum
|
21 |
+
optimum-intel
|
run_server.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
if __name__ == "__main__":
|
5 |
+
parser = argparse.ArgumentParser()
|
6 |
+
parser.add_argument('--port', '-p',
|
7 |
+
type=int,
|
8 |
+
default=9090,
|
9 |
+
help="Websocket port to run the server on.")
|
10 |
+
parser.add_argument('--backend', '-b',
|
11 |
+
type=str,
|
12 |
+
default='faster_whisper',
|
13 |
+
help='Backends from ["tensorrt", "faster_whisper", "openvino"]')
|
14 |
+
parser.add_argument('--faster_whisper_custom_model_path', '-fw',
|
15 |
+
type=str, default=None,
|
16 |
+
help="Custom Faster Whisper Model")
|
17 |
+
parser.add_argument('--trt_model_path', '-trt',
|
18 |
+
type=str,
|
19 |
+
default=None,
|
20 |
+
help='Whisper TensorRT model path')
|
21 |
+
parser.add_argument('--trt_multilingual', '-m',
|
22 |
+
action="store_true",
|
23 |
+
help='Boolean only for TensorRT model. True if multilingual.')
|
24 |
+
parser.add_argument('--trt_py_session',
|
25 |
+
action="store_true",
|
26 |
+
help='Boolean only for TensorRT model. Use python session or cpp session, By default uses Cpp.')
|
27 |
+
parser.add_argument('--omp_num_threads', '-omp',
|
28 |
+
type=int,
|
29 |
+
default=1,
|
30 |
+
help="Number of threads to use for OpenMP")
|
31 |
+
parser.add_argument('--no_single_model', '-nsm',
|
32 |
+
action='store_true',
|
33 |
+
help='Set this if every connection should instantiate its own model. Only relevant for custom model, passed using -trt or -fw.')
|
34 |
+
args = parser.parse_args()
|
35 |
+
|
36 |
+
if args.backend == "tensorrt":
|
37 |
+
if args.trt_model_path is None:
|
38 |
+
raise ValueError("Please Provide a valid tensorrt model path")
|
39 |
+
|
40 |
+
if "OMP_NUM_THREADS" not in os.environ:
|
41 |
+
os.environ["OMP_NUM_THREADS"] = str(args.omp_num_threads)
|
42 |
+
|
43 |
+
from whisper_live.server import TranscriptionServer
|
44 |
+
server = TranscriptionServer()
|
45 |
+
server.run(
|
46 |
+
"0.0.0.0",
|
47 |
+
port=args.port,
|
48 |
+
backend=args.backend,
|
49 |
+
faster_whisper_custom_model_path=args.faster_whisper_custom_model_path,
|
50 |
+
whisper_tensorrt_path=args.trt_model_path,
|
51 |
+
trt_multilingual=args.trt_multilingual,
|
52 |
+
trt_py_session=args.trt_py_session,
|
53 |
+
single_model=not args.no_single_model,
|
54 |
+
)
|
scripts/build_whisper_tensorrt.sh
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
download_and_build_model() {
|
4 |
+
local model_name="$1"
|
5 |
+
local model_url=""
|
6 |
+
|
7 |
+
case "$model_name" in
|
8 |
+
"tiny.en")
|
9 |
+
model_url="https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt"
|
10 |
+
;;
|
11 |
+
"tiny")
|
12 |
+
model_url="https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt"
|
13 |
+
;;
|
14 |
+
"base.en")
|
15 |
+
model_url="https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt"
|
16 |
+
;;
|
17 |
+
"base")
|
18 |
+
model_url="https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt"
|
19 |
+
;;
|
20 |
+
"small.en")
|
21 |
+
model_url="https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt"
|
22 |
+
;;
|
23 |
+
"small")
|
24 |
+
model_url="https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt"
|
25 |
+
;;
|
26 |
+
"medium.en")
|
27 |
+
model_url="https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt"
|
28 |
+
;;
|
29 |
+
"medium")
|
30 |
+
model_url="https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt"
|
31 |
+
;;
|
32 |
+
"large-v1")
|
33 |
+
model_url="https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt"
|
34 |
+
;;
|
35 |
+
"large-v2")
|
36 |
+
model_url="https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt"
|
37 |
+
;;
|
38 |
+
"large-v3" | "large")
|
39 |
+
model_url="https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt"
|
40 |
+
;;
|
41 |
+
"large-v3-turbo" | "turbo")
|
42 |
+
model_url="https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt"
|
43 |
+
;;
|
44 |
+
*)
|
45 |
+
echo "Invalid model name: $model_name"
|
46 |
+
exit 1
|
47 |
+
;;
|
48 |
+
esac
|
49 |
+
|
50 |
+
if [ "$model_name" == "turbo" ]; then
|
51 |
+
model_name="large-v3-turbo"
|
52 |
+
fi
|
53 |
+
|
54 |
+
local inference_precision="float16"
|
55 |
+
local weight_only_precision="${2:-float16}"
|
56 |
+
local max_beam_width=4
|
57 |
+
local max_batch_size=4
|
58 |
+
|
59 |
+
echo "Downloading $model_name..."
|
60 |
+
# wget --directory-prefix=assets "$model_url"
|
61 |
+
# echo "Download completed: ${model_name}.pt"
|
62 |
+
if [ ! -f "assets/${model_name}.pt" ]; then
|
63 |
+
wget --directory-prefix=assets "$model_url"
|
64 |
+
echo "Download completed: ${model_name}.pt"
|
65 |
+
else
|
66 |
+
echo "${model_name}.pt already exists in assets directory."
|
67 |
+
fi
|
68 |
+
|
69 |
+
local sanitized_model_name="${model_name//./_}"
|
70 |
+
local checkpoint_dir="whisper_${sanitized_model_name}_weights_${weight_only_precision}"
|
71 |
+
local output_dir="whisper_${sanitized_model_name}_${weight_only_precision}"
|
72 |
+
echo "$output_dir"
|
73 |
+
echo "Converting model weights for $model_name..."
|
74 |
+
python3 convert_checkpoint.py \
|
75 |
+
$( [[ "$weight_only_precision" == "int8" || "$weight_only_precision" == "int4" ]] && echo "--use_weight_only --weight_only_precision $weight_only_precision" ) \
|
76 |
+
--output_dir "$checkpoint_dir" --model_name "$model_name"
|
77 |
+
|
78 |
+
echo "Building encoder for $model_name..."
|
79 |
+
trtllm-build \
|
80 |
+
--checkpoint_dir "${checkpoint_dir}/encoder" \
|
81 |
+
--output_dir "${output_dir}/encoder" \
|
82 |
+
--moe_plugin disable \
|
83 |
+
--max_batch_size "$max_batch_size" \
|
84 |
+
--gemm_plugin disable \
|
85 |
+
--bert_attention_plugin "$inference_precision" \
|
86 |
+
--max_input_len 3000 \
|
87 |
+
--max_seq_len 3000
|
88 |
+
|
89 |
+
echo "Building decoder for $model_name..."
|
90 |
+
trtllm-build \
|
91 |
+
--checkpoint_dir "${checkpoint_dir}/decoder" \
|
92 |
+
--output_dir "${output_dir}/decoder" \
|
93 |
+
--moe_plugin disable \
|
94 |
+
--max_beam_width "$max_beam_width" \
|
95 |
+
--max_batch_size "$max_batch_size" \
|
96 |
+
--max_seq_len 225 \
|
97 |
+
--max_input_len 32 \
|
98 |
+
--max_encoder_input_len 3000 \
|
99 |
+
--gemm_plugin "$inference_precision" \
|
100 |
+
--bert_attention_plugin "$inference_precision" \
|
101 |
+
--gpt_attention_plugin "$inference_precision"
|
102 |
+
|
103 |
+
echo "TensorRT LLM engine built for $model_name."
|
104 |
+
echo "========================================="
|
105 |
+
echo "Model is located at: $(pwd)/$output_dir"
|
106 |
+
}
|
107 |
+
|
108 |
+
if [ "$#" -lt 1 ]; then
|
109 |
+
echo "Usage: $0 <path-to-tensorrt-examples-dir> [model-name]"
|
110 |
+
exit 1
|
111 |
+
fi
|
112 |
+
|
113 |
+
tensorrt_examples_dir="$1"
|
114 |
+
model_name="${2:-small.en}"
|
115 |
+
weight_only_precision="${3:-float16}" # Default to float16 if not provided
|
116 |
+
|
117 |
+
cd $tensorrt_examples_dir/whisper
|
118 |
+
pip install --no-deps -r requirements.txt
|
119 |
+
|
120 |
+
download_and_build_model "$model_name" "$weight_only_precision"
|
scripts/setup.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
#! /bin/bash
|
2 |
+
|
3 |
+
apt-get install portaudio19-dev wget -y
|
setup.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pathlib
|
2 |
+
from setuptools import find_packages, setup
|
3 |
+
from whisper_live.__version__ import __version__
|
4 |
+
|
5 |
+
|
6 |
+
# The directory containing this file
|
7 |
+
HERE = pathlib.Path(__file__).parent
|
8 |
+
|
9 |
+
# The text of the README file
|
10 |
+
README = (HERE / "README.md").read_text()
|
11 |
+
|
12 |
+
# This call to setup() does all the work
|
13 |
+
setup(
|
14 |
+
name="whisper_live",
|
15 |
+
version=__version__,
|
16 |
+
description="A nearly-live implementation of OpenAI's Whisper.",
|
17 |
+
long_description=README,
|
18 |
+
long_description_content_type="text/markdown",
|
19 |
+
include_package_data=True,
|
20 |
+
url="https://github.com/collabora/WhisperLive",
|
21 |
+
author="Collabora Ltd",
|
22 |
+
author_email="[email protected]",
|
23 |
+
license="MIT",
|
24 |
+
classifiers=[
|
25 |
+
"Development Status :: 4 - Beta",
|
26 |
+
"Intended Audience :: Developers",
|
27 |
+
"Intended Audience :: Science/Research",
|
28 |
+
"License :: OSI Approved :: MIT License",
|
29 |
+
"Programming Language :: Python :: 3",
|
30 |
+
"Programming Language :: Python :: 3 :: Only",
|
31 |
+
"Programming Language :: Python :: 3.8",
|
32 |
+
"Programming Language :: Python :: 3.9",
|
33 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
34 |
+
],
|
35 |
+
packages=find_packages(
|
36 |
+
exclude=(
|
37 |
+
"examples",
|
38 |
+
"Audio-Transcription-Chrome",
|
39 |
+
"Audio-Transcription-Firefox",
|
40 |
+
"requirements",
|
41 |
+
"whisper-finetuning"
|
42 |
+
)
|
43 |
+
),
|
44 |
+
install_requires=[
|
45 |
+
"PyAudio",
|
46 |
+
"faster-whisper==1.1.0",
|
47 |
+
"torch",
|
48 |
+
"torchaudio",
|
49 |
+
"websockets",
|
50 |
+
"onnxruntime==1.17.0",
|
51 |
+
"scipy",
|
52 |
+
"websocket-client",
|
53 |
+
"numba",
|
54 |
+
"openai-whisper==20240930",
|
55 |
+
"kaldialign",
|
56 |
+
"soundfile",
|
57 |
+
"tokenizers==0.20.3",
|
58 |
+
"librosa",
|
59 |
+
"numpy==1.26.4",
|
60 |
+
"openvino",
|
61 |
+
"openvino-genai",
|
62 |
+
"openvino-tokenizers",
|
63 |
+
"optimum",
|
64 |
+
"optimum-intel",
|
65 |
+
],
|
66 |
+
python_requires=">=3.9"
|
67 |
+
)
|
tests/__init__.py
ADDED
File without changes
|
tests/test_client.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import scipy
|
4 |
+
import websocket
|
5 |
+
import copy
|
6 |
+
import unittest
|
7 |
+
from unittest.mock import patch, MagicMock
|
8 |
+
from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient
|
9 |
+
from whisper_live.utils import resample
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
|
13 |
+
class BaseTestCase(unittest.TestCase):
|
14 |
+
@patch('whisper_live.client.websocket.WebSocketApp')
|
15 |
+
@patch('whisper_live.client.pyaudio.PyAudio')
|
16 |
+
def setUp(self, mock_pyaudio, mock_websocket):
|
17 |
+
self.mock_pyaudio_instance = MagicMock()
|
18 |
+
mock_pyaudio.return_value = self.mock_pyaudio_instance
|
19 |
+
self.mock_stream = MagicMock()
|
20 |
+
self.mock_pyaudio_instance.open.return_value = self.mock_stream
|
21 |
+
|
22 |
+
self.mock_ws_app = mock_websocket.return_value
|
23 |
+
self.mock_ws_app.send = MagicMock()
|
24 |
+
|
25 |
+
self.client = TranscriptionClient(host='localhost', port=9090, lang="en").client
|
26 |
+
|
27 |
+
self.mock_pyaudio = mock_pyaudio
|
28 |
+
self.mock_websocket = mock_websocket
|
29 |
+
self.mock_audio_packet = b'\x00\x01\x02\x03'
|
30 |
+
|
31 |
+
def tearDown(self):
|
32 |
+
self.client.close_websocket()
|
33 |
+
self.mock_pyaudio.stop()
|
34 |
+
self.mock_websocket.stop()
|
35 |
+
del self.client
|
36 |
+
|
37 |
+
class TestClientWebSocketCommunication(BaseTestCase):
|
38 |
+
def test_websocket_communication(self):
|
39 |
+
expected_url = 'ws://localhost:9090'
|
40 |
+
self.mock_websocket.assert_called()
|
41 |
+
self.assertEqual(self.mock_websocket.call_args[0][0], expected_url)
|
42 |
+
|
43 |
+
|
44 |
+
class TestClientCallbacks(BaseTestCase):
|
45 |
+
def test_on_open(self):
|
46 |
+
expected_message = json.dumps({
|
47 |
+
"uid": self.client.uid,
|
48 |
+
"language": self.client.language,
|
49 |
+
"task": self.client.task,
|
50 |
+
"model": self.client.model,
|
51 |
+
"use_vad": True,
|
52 |
+
"max_clients": 4,
|
53 |
+
"max_connection_time": 600,
|
54 |
+
"send_last_n_segments": 10,
|
55 |
+
"no_speech_thresh": 0.45,
|
56 |
+
"clip_audio": False,
|
57 |
+
"same_output_threshold": 10,
|
58 |
+
})
|
59 |
+
self.client.on_open(self.mock_ws_app)
|
60 |
+
self.mock_ws_app.send.assert_called_with(expected_message)
|
61 |
+
|
62 |
+
def test_on_message(self):
|
63 |
+
message = json.dumps(
|
64 |
+
{
|
65 |
+
"uid": self.client.uid,
|
66 |
+
"message": "SERVER_READY",
|
67 |
+
"backend": "faster_whisper"
|
68 |
+
}
|
69 |
+
)
|
70 |
+
self.client.on_message(self.mock_ws_app, message)
|
71 |
+
|
72 |
+
message = json.dumps({
|
73 |
+
"uid": self.client.uid,
|
74 |
+
"segments": [
|
75 |
+
{"start": 0, "end": 1, "text": "Test transcript", "completed": True},
|
76 |
+
{"start": 1, "end": 2, "text": "Test transcript 2", "completed": True},
|
77 |
+
{"start": 2, "end": 3, "text": "Test transcript 3", "completed": True}
|
78 |
+
]
|
79 |
+
})
|
80 |
+
self.client.on_message(self.mock_ws_app, message)
|
81 |
+
|
82 |
+
# Assert that the transcript was updated correctly
|
83 |
+
self.assertEqual(len(self.client.transcript), 3)
|
84 |
+
self.assertEqual(self.client.transcript[1]['text'], "Test transcript 2")
|
85 |
+
|
86 |
+
def test_on_close(self):
|
87 |
+
close_status_code = 1000
|
88 |
+
close_msg = "Normal closure"
|
89 |
+
self.client.on_close(self.mock_ws_app, close_status_code, close_msg)
|
90 |
+
|
91 |
+
self.assertFalse(self.client.recording)
|
92 |
+
self.assertFalse(self.client.server_error)
|
93 |
+
self.assertFalse(self.client.waiting)
|
94 |
+
|
95 |
+
def test_on_error(self):
|
96 |
+
error_message = "Test Error"
|
97 |
+
self.client.on_error(self.mock_ws_app, error_message)
|
98 |
+
|
99 |
+
self.assertTrue(self.client.server_error)
|
100 |
+
self.assertEqual(self.client.error_message, error_message)
|
101 |
+
|
102 |
+
|
103 |
+
class TestAudioResampling(unittest.TestCase):
|
104 |
+
def test_resample_audio(self):
|
105 |
+
original_audio = "assets/jfk.flac"
|
106 |
+
expected_sr = 16000
|
107 |
+
resampled_audio = resample(original_audio, expected_sr)
|
108 |
+
|
109 |
+
sr, _ = scipy.io.wavfile.read(resampled_audio)
|
110 |
+
self.assertEqual(sr, expected_sr)
|
111 |
+
|
112 |
+
os.remove(resampled_audio)
|
113 |
+
|
114 |
+
|
115 |
+
class TestSendingAudioPacket(BaseTestCase):
|
116 |
+
def test_send_packet(self):
|
117 |
+
self.client.send_packet_to_server(self.mock_audio_packet)
|
118 |
+
self.client.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)
|
119 |
+
|
120 |
+
class TestTee(BaseTestCase):
|
121 |
+
@patch('whisper_live.client.websocket.WebSocketApp')
|
122 |
+
@patch('whisper_live.client.pyaudio.PyAudio')
|
123 |
+
def setUp(self, mock_audio, mock_websocket):
|
124 |
+
super().setUp()
|
125 |
+
self.client2 = Client(host='localhost', port=9090, lang="es", translate=False, srt_file_path="transcript.srt")
|
126 |
+
self.client3 = Client(host='localhost', port=9090, lang="es", translate=True, srt_file_path="translation.srt")
|
127 |
+
# need a separate mock for each websocket
|
128 |
+
self.client3.client_socket = copy.deepcopy(self.client3.client_socket)
|
129 |
+
self.tee = TranscriptionTeeClient([self.client2, self.client3])
|
130 |
+
|
131 |
+
def tearDown(self):
|
132 |
+
self.tee.close_all_clients()
|
133 |
+
del self.tee
|
134 |
+
super().tearDown()
|
135 |
+
|
136 |
+
def test_invalid_constructor(self):
|
137 |
+
with self.assertRaises(Exception) as context:
|
138 |
+
TranscriptionTeeClient([])
|
139 |
+
|
140 |
+
def test_multicast_unconditional(self):
|
141 |
+
self.tee.multicast_packet(self.mock_audio_packet, True)
|
142 |
+
for client in self.tee.clients:
|
143 |
+
client.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)
|
144 |
+
|
145 |
+
def test_multicast_conditional(self):
|
146 |
+
self.client2.recording = False
|
147 |
+
self.client3.recording = True
|
148 |
+
self.tee.multicast_packet(self.mock_audio_packet, False)
|
149 |
+
self.client2.client_socket.send.assert_not_called()
|
150 |
+
self.client3.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)
|
151 |
+
|
152 |
+
def test_close_all(self):
|
153 |
+
self.tee.close_all_clients()
|
154 |
+
for client in self.tee.clients:
|
155 |
+
client.client_socket.close.assert_called()
|
156 |
+
|
157 |
+
def test_write_all_srt(self):
|
158 |
+
for client in self.tee.clients:
|
159 |
+
client.server_backend = "faster_whisper"
|
160 |
+
self.tee.write_all_clients_srt()
|
161 |
+
self.assertTrue(Path("transcript.srt").is_file())
|
162 |
+
self.assertTrue(Path("translation.srt").is_file())
|
tests/test_server.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
import time
|
3 |
+
import json
|
4 |
+
import unittest
|
5 |
+
from unittest import mock
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import jiwer
|
9 |
+
|
10 |
+
from websockets.exceptions import ConnectionClosed
|
11 |
+
from whisper_live.server import TranscriptionServer, BackendType, ClientManager
|
12 |
+
from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient
|
13 |
+
from whisper.normalizers import EnglishTextNormalizer
|
14 |
+
|
15 |
+
|
16 |
+
class TestTranscriptionServerInitialization(unittest.TestCase):
|
17 |
+
def test_initialization(self):
|
18 |
+
server = TranscriptionServer()
|
19 |
+
server.client_manager = ClientManager(max_clients=4, max_connection_time=600)
|
20 |
+
self.assertEqual(server.client_manager.max_clients, 4)
|
21 |
+
self.assertEqual(server.client_manager.max_connection_time, 600)
|
22 |
+
self.assertDictEqual(server.client_manager.clients, {})
|
23 |
+
self.assertDictEqual(server.client_manager.start_times, {})
|
24 |
+
|
25 |
+
|
26 |
+
class TestGetWaitTime(unittest.TestCase):
|
27 |
+
def setUp(self):
|
28 |
+
self.server = TranscriptionServer()
|
29 |
+
self.server.client_manager = ClientManager(max_clients=4, max_connection_time=600)
|
30 |
+
self.server.client_manager.start_times = {
|
31 |
+
'client1': time.time() - 120,
|
32 |
+
'client2': time.time() - 300
|
33 |
+
}
|
34 |
+
self.server.client_manager.max_connection_time = 600
|
35 |
+
|
36 |
+
def test_get_wait_time(self):
|
37 |
+
expected_wait_time = (600 - (time.time() - self.server.client_manager.start_times['client2'])) / 60
|
38 |
+
print(self.server.client_manager.get_wait_time(), expected_wait_time)
|
39 |
+
self.assertAlmostEqual(self.server.client_manager.get_wait_time(), expected_wait_time, places=2)
|
40 |
+
|
41 |
+
|
42 |
+
class TestServerConnection(unittest.TestCase):
|
43 |
+
def setUp(self):
|
44 |
+
self.server = TranscriptionServer()
|
45 |
+
|
46 |
+
@mock.patch('websockets.WebSocketCommonProtocol')
|
47 |
+
def test_connection(self, mock_websocket):
|
48 |
+
mock_websocket.recv.return_value = json.dumps({
|
49 |
+
'uid': 'test_client',
|
50 |
+
'language': 'en',
|
51 |
+
'task': 'transcribe',
|
52 |
+
'model': 'tiny.en'
|
53 |
+
})
|
54 |
+
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
|
55 |
+
|
56 |
+
@mock.patch('websockets.WebSocketCommonProtocol')
|
57 |
+
def test_recv_audio_exception_handling(self, mock_websocket):
|
58 |
+
mock_websocket.recv.side_effect = [json.dumps({
|
59 |
+
'uid': 'test_client',
|
60 |
+
'language': 'en',
|
61 |
+
'task': 'transcribe',
|
62 |
+
'model': 'tiny.en'
|
63 |
+
}), np.array([1, 2, 3]).tobytes()]
|
64 |
+
|
65 |
+
with self.assertLogs(level="ERROR"):
|
66 |
+
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
|
67 |
+
|
68 |
+
self.assertNotIn(mock_websocket, self.server.client_manager.clients)
|
69 |
+
|
70 |
+
|
71 |
+
class TestServerInferenceAccuracy(unittest.TestCase):
|
72 |
+
@classmethod
|
73 |
+
def setUpClass(cls):
|
74 |
+
cls.mock_pyaudio_patch = mock.patch('pyaudio.PyAudio')
|
75 |
+
cls.mock_pyaudio = cls.mock_pyaudio_patch.start()
|
76 |
+
cls.mock_pyaudio.return_value.open.return_value = mock.MagicMock()
|
77 |
+
|
78 |
+
cls.server_process = subprocess.Popen(["python", "run_server.py"])
|
79 |
+
time.sleep(2)
|
80 |
+
|
81 |
+
@classmethod
|
82 |
+
def tearDownClass(cls):
|
83 |
+
cls.server_process.terminate()
|
84 |
+
cls.server_process.wait()
|
85 |
+
|
86 |
+
def setUp(self):
|
87 |
+
self.normalizer = EnglishTextNormalizer()
|
88 |
+
|
89 |
+
def check_prediction(self, srt_path):
|
90 |
+
gt = "And so my fellow Americans, ask not, what your country can do for you. Ask what you can do for your country!"
|
91 |
+
with open(srt_path, "r") as f:
|
92 |
+
lines = f.readlines()
|
93 |
+
prediction = " ".join([line.strip() for line in lines[2::4]])
|
94 |
+
prediction_normalized = self.normalizer(prediction)
|
95 |
+
gt_normalized = self.normalizer(gt)
|
96 |
+
|
97 |
+
# calculate WER
|
98 |
+
wer_score = jiwer.wer(gt_normalized, prediction_normalized)
|
99 |
+
self.assertLess(wer_score, 0.05)
|
100 |
+
|
101 |
+
def test_inference(self):
|
102 |
+
client = TranscriptionClient(
|
103 |
+
"localhost", "9090", model="base.en", lang="en",
|
104 |
+
)
|
105 |
+
client("assets/jfk.flac")
|
106 |
+
self.check_prediction("output.srt")
|
107 |
+
|
108 |
+
def test_simultaneous_inference(self):
|
109 |
+
client1 = Client(
|
110 |
+
"localhost", "9090", model="base.en", lang="en", srt_file_path="transcript1.srt")
|
111 |
+
client2 = Client(
|
112 |
+
"localhost", "9090", model="base.en", lang="en", srt_file_path="transcript2.srt")
|
113 |
+
tee = TranscriptionTeeClient([client1, client2])
|
114 |
+
tee("assets/jfk.flac")
|
115 |
+
self.check_prediction("transcript1.srt")
|
116 |
+
self.check_prediction("transcript2.srt")
|
117 |
+
|
118 |
+
|
119 |
+
class TestExceptionHandling(unittest.TestCase):
|
120 |
+
def setUp(self):
|
121 |
+
self.server = TranscriptionServer()
|
122 |
+
|
123 |
+
@mock.patch('websockets.WebSocketCommonProtocol')
|
124 |
+
def test_connection_closed_exception(self, mock_websocket):
|
125 |
+
mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed", rcvd_then_sent=mock.Mock())
|
126 |
+
|
127 |
+
with self.assertLogs(level="INFO") as log:
|
128 |
+
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
|
129 |
+
self.assertTrue(any("Connection closed by client" in message for message in log.output))
|
130 |
+
|
131 |
+
@mock.patch('websockets.WebSocketCommonProtocol')
|
132 |
+
def test_json_decode_exception(self, mock_websocket):
|
133 |
+
mock_websocket.recv.return_value = "invalid json"
|
134 |
+
|
135 |
+
with self.assertLogs(level="ERROR") as log:
|
136 |
+
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
|
137 |
+
self.assertTrue(any("Failed to decode JSON from client" in message for message in log.output))
|
138 |
+
|
139 |
+
@mock.patch('websockets.WebSocketCommonProtocol')
|
140 |
+
def test_unexpected_exception_handling(self, mock_websocket):
|
141 |
+
mock_websocket.recv.side_effect = RuntimeError("Unexpected error")
|
142 |
+
|
143 |
+
with self.assertLogs(level="ERROR") as log:
|
144 |
+
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
|
145 |
+
for message in log.output:
|
146 |
+
print(message)
|
147 |
+
print()
|
148 |
+
self.assertTrue(any("Unexpected error" in message for message in log.output))
|
tests/test_vad.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
import numpy as np
|
3 |
+
from whisper_live.transcriber.tensorrt_utils import load_audio
|
4 |
+
from whisper_live.vad import VoiceActivityDetector
|
5 |
+
|
6 |
+
|
7 |
+
class TestVoiceActivityDetection(unittest.TestCase):
|
8 |
+
def setUp(self):
|
9 |
+
self.vad = VoiceActivityDetector()
|
10 |
+
self.sample_rate = 16000
|
11 |
+
|
12 |
+
def generate_silence(self, duration_seconds):
|
13 |
+
return np.zeros(int(self.sample_rate * duration_seconds), dtype=np.float32)
|
14 |
+
|
15 |
+
def load_speech_segment(self, filepath):
|
16 |
+
return load_audio(filepath)
|
17 |
+
|
18 |
+
def test_vad_silence_detection(self):
|
19 |
+
silence = self.generate_silence(3)
|
20 |
+
is_speech_present = self.vad(silence.copy())
|
21 |
+
self.assertFalse(is_speech_present, "VAD incorrectly identified silence as speech.")
|
22 |
+
|
23 |
+
def test_vad_speech_detection(self):
|
24 |
+
audio_tensor = load_audio("assets/jfk.flac")
|
25 |
+
is_speech_present = self.vad(audio_tensor)
|
26 |
+
self.assertTrue(is_speech_present, "VAD failed to identify speech segment.")
|
whisper_live/__init__.py
ADDED
File without changes
|
whisper_live/__version__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__version__ = "0.7.1"
|
whisper_live/backend/__init__.py
ADDED
File without changes
|
whisper_live/backend/base.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import threading
|
4 |
+
import time
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
class ServeClientBase(object):
|
9 |
+
RATE = 16000
|
10 |
+
SERVER_READY = "SERVER_READY"
|
11 |
+
DISCONNECT = "DISCONNECT"
|
12 |
+
|
13 |
+
client_uid: str
|
14 |
+
"""A unique identifier for the client."""
|
15 |
+
websocket: object
|
16 |
+
"""The WebSocket connection for the client."""
|
17 |
+
send_last_n_segments: int
|
18 |
+
"""Number of most recent segments to send to the client."""
|
19 |
+
no_speech_thresh: float
|
20 |
+
"""Segments with no speech probability above this threshold will be discarded."""
|
21 |
+
clip_audio: bool
|
22 |
+
"""Whether to clip audio with no valid segments."""
|
23 |
+
same_output_threshold: int
|
24 |
+
"""Number of repeated outputs before considering it as a valid segment."""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
client_uid,
|
29 |
+
websocket,
|
30 |
+
send_last_n_segments=10,
|
31 |
+
no_speech_thresh=0.45,
|
32 |
+
clip_audio=False,
|
33 |
+
same_output_threshold=10,
|
34 |
+
):
|
35 |
+
self.client_uid = client_uid
|
36 |
+
self.websocket = websocket
|
37 |
+
self.send_last_n_segments = send_last_n_segments
|
38 |
+
self.no_speech_thresh = no_speech_thresh
|
39 |
+
self.clip_audio = clip_audio
|
40 |
+
self.same_output_threshold = same_output_threshold
|
41 |
+
|
42 |
+
self.frames = b""
|
43 |
+
self.timestamp_offset = 0.0
|
44 |
+
self.frames_np = None
|
45 |
+
self.frames_offset = 0.0
|
46 |
+
self.text = []
|
47 |
+
self.current_out = ""
|
48 |
+
self.prev_out = ""
|
49 |
+
self.exit = False
|
50 |
+
self.same_output_count = 0
|
51 |
+
self.transcript = []
|
52 |
+
self.end_time_for_same_output = None
|
53 |
+
|
54 |
+
# threading
|
55 |
+
self.lock = threading.Lock()
|
56 |
+
|
57 |
+
def speech_to_text(self):
|
58 |
+
"""
|
59 |
+
Process an audio stream in an infinite loop, continuously transcribing the speech.
|
60 |
+
|
61 |
+
This method continuously receives audio frames, performs real-time transcription, and sends
|
62 |
+
transcribed segments to the client via a WebSocket connection.
|
63 |
+
|
64 |
+
If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction.
|
65 |
+
It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments
|
66 |
+
are sent to the client in real-time, and a history of segments is maintained to provide context.
|
67 |
+
|
68 |
+
Raises:
|
69 |
+
Exception: If there is an issue with audio processing or WebSocket communication.
|
70 |
+
|
71 |
+
"""
|
72 |
+
while True:
|
73 |
+
if self.exit:
|
74 |
+
logging.info("Exiting speech to text thread")
|
75 |
+
break
|
76 |
+
|
77 |
+
if self.frames_np is None:
|
78 |
+
continue
|
79 |
+
|
80 |
+
if self.clip_audio:
|
81 |
+
self.clip_audio_if_no_valid_segment()
|
82 |
+
|
83 |
+
input_bytes, duration = self.get_audio_chunk_for_processing()
|
84 |
+
if duration < 1.0:
|
85 |
+
time.sleep(0.1) # wait for audio chunks to arrive
|
86 |
+
continue
|
87 |
+
try:
|
88 |
+
input_sample = input_bytes.copy()
|
89 |
+
result = self.transcribe_audio(input_sample)
|
90 |
+
|
91 |
+
if result is None or self.language is None:
|
92 |
+
self.timestamp_offset += duration
|
93 |
+
time.sleep(0.25) # wait for voice activity, result is None when no voice activity
|
94 |
+
continue
|
95 |
+
self.handle_transcription_output(result, duration)
|
96 |
+
|
97 |
+
except Exception as e:
|
98 |
+
logging.error(f"[ERROR]: Failed to transcribe audio chunk: {e}")
|
99 |
+
time.sleep(0.01)
|
100 |
+
|
101 |
+
def transcribe_audio(self):
|
102 |
+
raise NotImplementedError
|
103 |
+
|
104 |
+
def handle_transcription_output(self, result, duration):
|
105 |
+
raise NotImplementedError
|
106 |
+
|
107 |
+
def format_segment(self, start, end, text, completed=False):
|
108 |
+
"""
|
109 |
+
Formats a transcription segment with precise start and end times alongside the transcribed text.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
start (float): The start time of the transcription segment in seconds.
|
113 |
+
end (float): The end time of the transcription segment in seconds.
|
114 |
+
text (str): The transcribed text corresponding to the segment.
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
dict: A dictionary representing the formatted transcription segment, including
|
118 |
+
'start' and 'end' times as strings with three decimal places and the 'text'
|
119 |
+
of the transcription.
|
120 |
+
"""
|
121 |
+
return {
|
122 |
+
'start': "{:.3f}".format(start),
|
123 |
+
'end': "{:.3f}".format(end),
|
124 |
+
'text': text,
|
125 |
+
'completed': completed
|
126 |
+
}
|
127 |
+
|
128 |
+
def add_frames(self, frame_np):
|
129 |
+
"""
|
130 |
+
Add audio frames to the ongoing audio stream buffer.
|
131 |
+
|
132 |
+
This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
|
133 |
+
of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
|
134 |
+
to prevent excessive memory usage.
|
135 |
+
|
136 |
+
If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
|
137 |
+
of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided
|
138 |
+
audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
frame_np (numpy.ndarray): The audio frame data as a NumPy array.
|
142 |
+
|
143 |
+
"""
|
144 |
+
self.lock.acquire()
|
145 |
+
if self.frames_np is not None and self.frames_np.shape[0] > 45*self.RATE:
|
146 |
+
self.frames_offset += 30.0
|
147 |
+
self.frames_np = self.frames_np[int(30*self.RATE):]
|
148 |
+
# check timestamp offset(should be >= self.frame_offset)
|
149 |
+
# this basically means that there is no speech as timestamp offset hasnt updated
|
150 |
+
# and is less than frame_offset
|
151 |
+
if self.timestamp_offset < self.frames_offset:
|
152 |
+
self.timestamp_offset = self.frames_offset
|
153 |
+
if self.frames_np is None:
|
154 |
+
self.frames_np = frame_np.copy()
|
155 |
+
else:
|
156 |
+
self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
|
157 |
+
self.lock.release()
|
158 |
+
|
159 |
+
def clip_audio_if_no_valid_segment(self):
|
160 |
+
"""
|
161 |
+
Update the timestamp offset based on audio buffer status.
|
162 |
+
Clip audio if the current chunk exceeds 30 seconds, this basically implies that
|
163 |
+
no valid segment for the last 30 seconds from whisper
|
164 |
+
"""
|
165 |
+
with self.lock:
|
166 |
+
if self.frames_np[int((self.timestamp_offset - self.frames_offset)*self.RATE):].shape[0] > 25 * self.RATE:
|
167 |
+
duration = self.frames_np.shape[0] / self.RATE
|
168 |
+
self.timestamp_offset = self.frames_offset + duration - 5
|
169 |
+
|
170 |
+
def get_audio_chunk_for_processing(self):
|
171 |
+
"""
|
172 |
+
Retrieves the next chunk of audio data for processing based on the current offsets.
|
173 |
+
|
174 |
+
Calculates which part of the audio data should be processed next, based on
|
175 |
+
the difference between the current timestamp offset and the frame's offset, scaled by
|
176 |
+
the audio sample rate (RATE). It then returns this chunk of audio data along with its
|
177 |
+
duration in seconds.
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
tuple: A tuple containing:
|
181 |
+
- input_bytes (np.ndarray): The next chunk of audio data to be processed.
|
182 |
+
- duration (float): The duration of the audio chunk in seconds.
|
183 |
+
"""
|
184 |
+
with self.lock:
|
185 |
+
samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
|
186 |
+
input_bytes = self.frames_np[int(samples_take):].copy()
|
187 |
+
duration = input_bytes.shape[0] / self.RATE
|
188 |
+
return input_bytes, duration
|
189 |
+
|
190 |
+
def prepare_segments(self, last_segment=None):
|
191 |
+
"""
|
192 |
+
Prepares the segments of transcribed text to be sent to the client.
|
193 |
+
|
194 |
+
This method compiles the recent segments of transcribed text, ensuring that only the
|
195 |
+
specified number of the most recent segments are included. It also appends the most
|
196 |
+
recent segment of text if provided (which is considered incomplete because of the possibility
|
197 |
+
of the last word being truncated in the audio chunk).
|
198 |
+
|
199 |
+
Args:
|
200 |
+
last_segment (str, optional): The most recent segment of transcribed text to be added
|
201 |
+
to the list of segments. Defaults to None.
|
202 |
+
|
203 |
+
Returns:
|
204 |
+
list: A list of transcribed text segments to be sent to the client.
|
205 |
+
"""
|
206 |
+
segments = []
|
207 |
+
if len(self.transcript) >= self.send_last_n_segments:
|
208 |
+
segments = self.transcript[-self.send_last_n_segments:].copy()
|
209 |
+
else:
|
210 |
+
segments = self.transcript.copy()
|
211 |
+
if last_segment is not None:
|
212 |
+
segments = segments + [last_segment]
|
213 |
+
return segments
|
214 |
+
|
215 |
+
def get_audio_chunk_duration(self, input_bytes):
|
216 |
+
"""
|
217 |
+
Calculates the duration of the provided audio chunk.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
input_bytes (numpy.ndarray): The audio chunk for which to calculate the duration.
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
float: The duration of the audio chunk in seconds.
|
224 |
+
"""
|
225 |
+
return input_bytes.shape[0] / self.RATE
|
226 |
+
|
227 |
+
def send_transcription_to_client(self, segments):
|
228 |
+
"""
|
229 |
+
Sends the specified transcription segments to the client over the websocket connection.
|
230 |
+
|
231 |
+
This method formats the transcription segments into a JSON object and attempts to send
|
232 |
+
this object to the client. If an error occurs during the send operation, it logs the error.
|
233 |
+
|
234 |
+
Returns:
|
235 |
+
segments (list): A list of transcription segments to be sent to the client.
|
236 |
+
"""
|
237 |
+
try:
|
238 |
+
self.websocket.send(
|
239 |
+
json.dumps({
|
240 |
+
"uid": self.client_uid,
|
241 |
+
"segments": segments,
|
242 |
+
})
|
243 |
+
)
|
244 |
+
except Exception as e:
|
245 |
+
logging.error(f"[ERROR]: Sending data to client: {e}")
|
246 |
+
|
247 |
+
def disconnect(self):
|
248 |
+
"""
|
249 |
+
Notify the client of disconnection and send a disconnect message.
|
250 |
+
|
251 |
+
This method sends a disconnect message to the client via the WebSocket connection to notify them
|
252 |
+
that the transcription service is disconnecting gracefully.
|
253 |
+
|
254 |
+
"""
|
255 |
+
self.websocket.send(json.dumps({
|
256 |
+
"uid": self.client_uid,
|
257 |
+
"message": self.DISCONNECT
|
258 |
+
}))
|
259 |
+
|
260 |
+
def cleanup(self):
|
261 |
+
"""
|
262 |
+
Perform cleanup tasks before exiting the transcription service.
|
263 |
+
|
264 |
+
This method performs necessary cleanup tasks, including stopping the transcription thread, marking
|
265 |
+
the exit flag to indicate the transcription thread should exit gracefully, and destroying resources
|
266 |
+
associated with the transcription process.
|
267 |
+
|
268 |
+
"""
|
269 |
+
logging.info("Cleaning up.")
|
270 |
+
self.exit = True
|
271 |
+
|
272 |
+
def get_segment_no_speech_prob(self, segment):
|
273 |
+
return getattr(segment, "no_speech_prob", 0)
|
274 |
+
|
275 |
+
def get_segment_start(self, segment):
|
276 |
+
return getattr(segment, "start", getattr(segment, "start_ts", 0))
|
277 |
+
|
278 |
+
def get_segment_end(self, segment):
|
279 |
+
return getattr(segment, "end", getattr(segment, "end_ts", 0))
|
280 |
+
|
281 |
+
def update_segments(self, segments, duration):
|
282 |
+
"""
|
283 |
+
Processes the segments from Whisper and updates the transcript.
|
284 |
+
Uses helper methods to account for differences between backends.
|
285 |
+
|
286 |
+
Args:
|
287 |
+
segments (list): List of segments returned by the transcriber.
|
288 |
+
duration (float): Duration of the current audio chunk.
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
dict or None: The last processed segment (if any).
|
292 |
+
"""
|
293 |
+
offset = None
|
294 |
+
self.current_out = ''
|
295 |
+
last_segment = None
|
296 |
+
|
297 |
+
# Process complete segments only if there are more than one
|
298 |
+
# and if the last segment's no_speech_prob is below the threshold.
|
299 |
+
if len(segments) > 1 and self.get_segment_no_speech_prob(segments[-1]) <= self.no_speech_thresh:
|
300 |
+
for s in segments[:-1]:
|
301 |
+
text_ = s.text
|
302 |
+
self.text.append(text_)
|
303 |
+
with self.lock:
|
304 |
+
start = self.timestamp_offset + self.get_segment_start(s)
|
305 |
+
end = self.timestamp_offset + min(duration, self.get_segment_end(s))
|
306 |
+
if start >= end:
|
307 |
+
continue
|
308 |
+
if self.get_segment_no_speech_prob(s) > self.no_speech_thresh:
|
309 |
+
continue
|
310 |
+
self.transcript.append(self.format_segment(start, end, text_, completed=True))
|
311 |
+
offset = min(duration, self.get_segment_end(s))
|
312 |
+
|
313 |
+
# Process the last segment if its no_speech_prob is acceptable.
|
314 |
+
if self.get_segment_no_speech_prob(segments[-1]) <= self.no_speech_thresh:
|
315 |
+
self.current_out += segments[-1].text
|
316 |
+
with self.lock:
|
317 |
+
last_segment = self.format_segment(
|
318 |
+
self.timestamp_offset + self.get_segment_start(segments[-1]),
|
319 |
+
self.timestamp_offset + min(duration, self.get_segment_end(segments[-1])),
|
320 |
+
self.current_out,
|
321 |
+
completed=False
|
322 |
+
)
|
323 |
+
|
324 |
+
# Handle repeated output logic.
|
325 |
+
if self.current_out.strip() == self.prev_out.strip() and self.current_out != '':
|
326 |
+
self.same_output_count += 1
|
327 |
+
|
328 |
+
# if we remove the audio because of same output on the nth reptition we might remove the
|
329 |
+
# audio thats not yet transcribed so, capturing the time when it was repeated for the first time
|
330 |
+
if self.end_time_for_same_output is None:
|
331 |
+
self.end_time_for_same_output = self.get_segment_end(segments[-1])
|
332 |
+
time.sleep(0.1) # wait briefly for any new voice activity
|
333 |
+
else:
|
334 |
+
self.same_output_count = 0
|
335 |
+
self.end_time_for_same_output = None
|
336 |
+
|
337 |
+
# If the same incomplete segment is repeated too many times,
|
338 |
+
# append it to the transcript and update the offset.
|
339 |
+
if self.same_output_count > self.same_output_threshold:
|
340 |
+
if not self.text or self.text[-1].strip().lower() != self.current_out.strip().lower():
|
341 |
+
self.text.append(self.current_out)
|
342 |
+
with self.lock:
|
343 |
+
self.transcript.append(self.format_segment(
|
344 |
+
self.timestamp_offset,
|
345 |
+
self.timestamp_offset + min(duration, self.end_time_for_same_output),
|
346 |
+
self.current_out,
|
347 |
+
completed=True
|
348 |
+
))
|
349 |
+
self.current_out = ''
|
350 |
+
offset = min(duration, self.end_time_for_same_output)
|
351 |
+
self.same_output_count = 0
|
352 |
+
last_segment = None
|
353 |
+
self.end_time_for_same_output = None
|
354 |
+
else:
|
355 |
+
self.prev_out = self.current_out
|
356 |
+
|
357 |
+
if offset is not None:
|
358 |
+
with self.lock:
|
359 |
+
self.timestamp_offset += offset
|
360 |
+
|
361 |
+
return last_segment
|
whisper_live/backend/faster_whisper_backend.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import threading
|
4 |
+
import time
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from whisper_live.transcriber.transcriber_faster_whisper import WhisperModel
|
8 |
+
from whisper_live.backend.base import ServeClientBase
|
9 |
+
|
10 |
+
|
11 |
+
class ServeClientFasterWhisper(ServeClientBase):
|
12 |
+
SINGLE_MODEL = None
|
13 |
+
SINGLE_MODEL_LOCK = threading.Lock()
|
14 |
+
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
websocket,
|
18 |
+
task="transcribe",
|
19 |
+
device=None,
|
20 |
+
language=None,
|
21 |
+
client_uid=None,
|
22 |
+
model="small.en",
|
23 |
+
initial_prompt=None,
|
24 |
+
vad_parameters=None,
|
25 |
+
use_vad=True,
|
26 |
+
single_model=False,
|
27 |
+
send_last_n_segments=10,
|
28 |
+
no_speech_thresh=0.45,
|
29 |
+
clip_audio=False,
|
30 |
+
same_output_threshold=10,
|
31 |
+
):
|
32 |
+
"""
|
33 |
+
Initialize a ServeClient instance.
|
34 |
+
The Whisper model is initialized based on the client's language and device availability.
|
35 |
+
The transcription thread is started upon initialization. A "SERVER_READY" message is sent
|
36 |
+
to the client to indicate that the server is ready.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
websocket (WebSocket): The WebSocket connection for the client.
|
40 |
+
task (str, optional): The task type, e.g., "transcribe". Defaults to "transcribe".
|
41 |
+
device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None.
|
42 |
+
language (str, optional): The language for transcription. Defaults to None.
|
43 |
+
client_uid (str, optional): A unique identifier for the client. Defaults to None.
|
44 |
+
model (str, optional): The whisper model size. Defaults to 'small.en'
|
45 |
+
initial_prompt (str, optional): Prompt for whisper inference. Defaults to None.
|
46 |
+
single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False.
|
47 |
+
send_last_n_segments (int, optional): Number of most recent segments to send to the client. Defaults to 10.
|
48 |
+
no_speech_thresh (float, optional): Segments with no speech probability above this threshold will be discarded. Defaults to 0.45.
|
49 |
+
clip_audio (bool, optional): Whether to clip audio with no valid segments. Defaults to False.
|
50 |
+
same_output_threshold (int, optional): Number of repeated outputs before considering it as a valid segment. Defaults to 10.
|
51 |
+
|
52 |
+
"""
|
53 |
+
super().__init__(
|
54 |
+
client_uid,
|
55 |
+
websocket,
|
56 |
+
send_last_n_segments,
|
57 |
+
no_speech_thresh,
|
58 |
+
clip_audio,
|
59 |
+
same_output_threshold,
|
60 |
+
)
|
61 |
+
self.model_sizes = [
|
62 |
+
"tiny", "tiny.en", "base", "base.en", "small", "small.en",
|
63 |
+
"medium", "medium.en", "large-v2", "large-v3", "distil-small.en",
|
64 |
+
"distil-medium.en", "distil-large-v2", "distil-large-v3",
|
65 |
+
"large-v3-turbo", "turbo"
|
66 |
+
]
|
67 |
+
|
68 |
+
self.model_size_or_path = model
|
69 |
+
self.language = "en" if self.model_size_or_path.endswith("en") else language
|
70 |
+
self.task = task
|
71 |
+
self.initial_prompt = initial_prompt
|
72 |
+
self.vad_parameters = vad_parameters or {"onset": 0.5}
|
73 |
+
|
74 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
75 |
+
if device == "cuda":
|
76 |
+
major, _ = torch.cuda.get_device_capability(device)
|
77 |
+
self.compute_type = "float16" if major >= 7 else "float32"
|
78 |
+
else:
|
79 |
+
self.compute_type = "int8"
|
80 |
+
|
81 |
+
if self.model_size_or_path is None:
|
82 |
+
return
|
83 |
+
logging.info(f"Using Device={device} with precision {self.compute_type}")
|
84 |
+
|
85 |
+
try:
|
86 |
+
if single_model:
|
87 |
+
if ServeClientFasterWhisper.SINGLE_MODEL is None:
|
88 |
+
self.create_model(device)
|
89 |
+
ServeClientFasterWhisper.SINGLE_MODEL = self.transcriber
|
90 |
+
else:
|
91 |
+
self.transcriber = ServeClientFasterWhisper.SINGLE_MODEL
|
92 |
+
else:
|
93 |
+
self.create_model(device)
|
94 |
+
except Exception as e:
|
95 |
+
logging.error(f"Failed to load model: {e}")
|
96 |
+
self.websocket.send(json.dumps({
|
97 |
+
"uid": self.client_uid,
|
98 |
+
"status": "ERROR",
|
99 |
+
"message": f"Failed to load model: {str(self.model_size_or_path)}"
|
100 |
+
}))
|
101 |
+
self.websocket.close()
|
102 |
+
return
|
103 |
+
|
104 |
+
self.use_vad = use_vad
|
105 |
+
|
106 |
+
# threading
|
107 |
+
self.trans_thread = threading.Thread(target=self.speech_to_text)
|
108 |
+
self.trans_thread.start()
|
109 |
+
self.websocket.send(
|
110 |
+
json.dumps(
|
111 |
+
{
|
112 |
+
"uid": self.client_uid,
|
113 |
+
"message": self.SERVER_READY,
|
114 |
+
"backend": "faster_whisper"
|
115 |
+
}
|
116 |
+
)
|
117 |
+
)
|
118 |
+
|
119 |
+
def create_model(self, device):
|
120 |
+
"""
|
121 |
+
Instantiates a new model, sets it as the transcriber.
|
122 |
+
"""
|
123 |
+
self.transcriber = WhisperModel(
|
124 |
+
self.model_size_or_path,
|
125 |
+
device=device,
|
126 |
+
compute_type=self.compute_type,
|
127 |
+
local_files_only=False,
|
128 |
+
)
|
129 |
+
|
130 |
+
def check_valid_model(self, model_size):
|
131 |
+
"""
|
132 |
+
Check if it's a valid whisper model size.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
model_size (str): The name of the model size to check.
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
str: The model size if valid, None otherwise.
|
139 |
+
"""
|
140 |
+
if model_size not in self.model_sizes:
|
141 |
+
self.websocket.send(
|
142 |
+
json.dumps(
|
143 |
+
{
|
144 |
+
"uid": self.client_uid,
|
145 |
+
"status": "ERROR",
|
146 |
+
"message": f"Invalid model size {model_size}. Available choices: {self.model_sizes}"
|
147 |
+
}
|
148 |
+
)
|
149 |
+
)
|
150 |
+
return None
|
151 |
+
return model_size
|
152 |
+
|
153 |
+
def set_language(self, info):
|
154 |
+
"""
|
155 |
+
Updates the language attribute based on the detected language information.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
info (object): An object containing the detected language and its probability. This object
|
159 |
+
must have at least two attributes: `language`, a string indicating the detected
|
160 |
+
language, and `language_probability`, a float representing the confidence level
|
161 |
+
of the language detection.
|
162 |
+
"""
|
163 |
+
if info.language_probability > 0.5:
|
164 |
+
self.language = info.language
|
165 |
+
logging.info(f"Detected language {self.language} with probability {info.language_probability}")
|
166 |
+
self.websocket.send(json.dumps(
|
167 |
+
{"uid": self.client_uid, "language": self.language, "language_prob": info.language_probability}))
|
168 |
+
|
169 |
+
def transcribe_audio(self, input_sample):
|
170 |
+
"""
|
171 |
+
Transcribes the provided audio sample using the configured transcriber instance.
|
172 |
+
|
173 |
+
If the language has not been set, it updates the session's language based on the transcription
|
174 |
+
information.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
input_sample (np.array): The audio chunk to be transcribed. This should be a NumPy
|
178 |
+
array representing the audio data.
|
179 |
+
|
180 |
+
Returns:
|
181 |
+
The transcription result from the transcriber. The exact format of this result
|
182 |
+
depends on the implementation of the `transcriber.transcribe` method but typically
|
183 |
+
includes the transcribed text.
|
184 |
+
"""
|
185 |
+
if ServeClientFasterWhisper.SINGLE_MODEL:
|
186 |
+
ServeClientFasterWhisper.SINGLE_MODEL_LOCK.acquire()
|
187 |
+
result, info = self.transcriber.transcribe(
|
188 |
+
input_sample,
|
189 |
+
initial_prompt=self.initial_prompt,
|
190 |
+
language=self.language,
|
191 |
+
task=self.task,
|
192 |
+
vad_filter=self.use_vad,
|
193 |
+
vad_parameters=self.vad_parameters if self.use_vad else None)
|
194 |
+
if ServeClientFasterWhisper.SINGLE_MODEL:
|
195 |
+
ServeClientFasterWhisper.SINGLE_MODEL_LOCK.release()
|
196 |
+
|
197 |
+
if self.language is None and info is not None:
|
198 |
+
self.set_language(info)
|
199 |
+
return result
|
200 |
+
|
201 |
+
def handle_transcription_output(self, result, duration):
|
202 |
+
"""
|
203 |
+
Handle the transcription output, updating the transcript and sending data to the client.
|
204 |
+
|
205 |
+
Args:
|
206 |
+
result (str): The result from whisper inference i.e. the list of segments.
|
207 |
+
duration (float): Duration of the transcribed audio chunk.
|
208 |
+
"""
|
209 |
+
segments = []
|
210 |
+
if len(result):
|
211 |
+
self.t_start = None
|
212 |
+
last_segment = self.update_segments(result, duration)
|
213 |
+
segments = self.prepare_segments(last_segment)
|
214 |
+
|
215 |
+
if len(segments):
|
216 |
+
self.send_transcription_to_client(segments)
|
whisper_live/backend/openvino_backend.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import threading
|
4 |
+
import time
|
5 |
+
|
6 |
+
from openvino import Core
|
7 |
+
from whisper_live.backend.base import ServeClientBase
|
8 |
+
from whisper_live.transcriber.transcriber_openvino import WhisperOpenVINO
|
9 |
+
|
10 |
+
|
11 |
+
class ServeClientOpenVINO(ServeClientBase):
|
12 |
+
SINGLE_MODEL = None
|
13 |
+
SINGLE_MODEL_LOCK = threading.Lock()
|
14 |
+
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
websocket,
|
18 |
+
task="transcribe",
|
19 |
+
device=None,
|
20 |
+
language=None,
|
21 |
+
client_uid=None,
|
22 |
+
model="small.en",
|
23 |
+
initial_prompt=None,
|
24 |
+
vad_parameters=None,
|
25 |
+
use_vad=True,
|
26 |
+
single_model=False,
|
27 |
+
send_last_n_segments=10,
|
28 |
+
no_speech_thresh=0.45,
|
29 |
+
clip_audio=False,
|
30 |
+
same_output_threshold=10,
|
31 |
+
):
|
32 |
+
"""
|
33 |
+
Initialize a ServeClient instance.
|
34 |
+
The Whisper model is initialized based on the client's language and device availability.
|
35 |
+
The transcription thread is started upon initialization. A "SERVER_READY" message is sent
|
36 |
+
to the client to indicate that the server is ready.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
websocket (WebSocket): The WebSocket connection for the client.
|
40 |
+
task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe".
|
41 |
+
device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None.
|
42 |
+
language (str, optional): The language for transcription. Defaults to None.
|
43 |
+
client_uid (str, optional): A unique identifier for the client. Defaults to None.
|
44 |
+
model (str, optional): Huggingface model_id for a valid OpenVINO model.
|
45 |
+
initial_prompt (str, optional): Prompt for whisper inference. Defaults to None.
|
46 |
+
single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False.
|
47 |
+
send_last_n_segments (int, optional): Number of most recent segments to send to the client. Defaults to 10.
|
48 |
+
no_speech_thresh (float, optional): Segments with no speech probability above this threshold will be discarded. Defaults to 0.45.
|
49 |
+
clip_audio (bool, optional): Whether to clip audio with no valid segments. Defaults to False.
|
50 |
+
same_output_threshold (int, optional): Number of repeated outputs before considering it as a valid segment. Defaults to 10.
|
51 |
+
"""
|
52 |
+
super().__init__(
|
53 |
+
client_uid,
|
54 |
+
websocket,
|
55 |
+
send_last_n_segments,
|
56 |
+
no_speech_thresh,
|
57 |
+
clip_audio,
|
58 |
+
same_output_threshold,
|
59 |
+
)
|
60 |
+
self.language = "en" if language is None else language
|
61 |
+
if not self.language.startswith("<|"):
|
62 |
+
self.language = f"<|{self.language}|>"
|
63 |
+
|
64 |
+
self.task = "transcribe" if task is None else task
|
65 |
+
|
66 |
+
self.clip_audio = True
|
67 |
+
|
68 |
+
core = Core()
|
69 |
+
available_devices = core.available_devices
|
70 |
+
if 'GPU' in available_devices:
|
71 |
+
selected_device = 'GPU'
|
72 |
+
else:
|
73 |
+
gpu_devices = [d for d in available_devices if d.startswith('GPU')]
|
74 |
+
selected_device = gpu_devices[0] if gpu_devices else 'CPU'
|
75 |
+
self.device = selected_device
|
76 |
+
|
77 |
+
|
78 |
+
if single_model:
|
79 |
+
if ServeClientOpenVINO.SINGLE_MODEL is None:
|
80 |
+
self.create_model(model)
|
81 |
+
ServeClientOpenVINO.SINGLE_MODEL = self.transcriber
|
82 |
+
else:
|
83 |
+
self.transcriber = ServeClientOpenVINO.SINGLE_MODEL
|
84 |
+
else:
|
85 |
+
self.create_model(model)
|
86 |
+
|
87 |
+
# threading
|
88 |
+
self.trans_thread = threading.Thread(target=self.speech_to_text)
|
89 |
+
self.trans_thread.start()
|
90 |
+
|
91 |
+
self.websocket.send(json.dumps({
|
92 |
+
"uid": self.client_uid,
|
93 |
+
"message": self.SERVER_READY,
|
94 |
+
"backend": "openvino"
|
95 |
+
}))
|
96 |
+
logging.info(f"Using OpenVINO device: {self.device}")
|
97 |
+
logging.info(f"Running OpenVINO backend with language: {self.language} and task: {self.task}")
|
98 |
+
|
99 |
+
def create_model(self, model_id):
|
100 |
+
"""
|
101 |
+
Instantiates a new model, sets it as the transcriber.
|
102 |
+
"""
|
103 |
+
self.transcriber = WhisperOpenVINO(
|
104 |
+
model_id,
|
105 |
+
device=self.device,
|
106 |
+
language=self.language,
|
107 |
+
task=self.task
|
108 |
+
)
|
109 |
+
|
110 |
+
def transcribe_audio(self, input_sample):
|
111 |
+
"""
|
112 |
+
Transcribes the provided audio sample using the configured transcriber instance.
|
113 |
+
|
114 |
+
If the language has not been set, it updates the session's language based on the transcription
|
115 |
+
information.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
input_sample (np.array): The audio chunk to be transcribed. This should be a NumPy
|
119 |
+
array representing the audio data.
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
The transcription result from the transcriber. The exact format of this result
|
123 |
+
depends on the implementation of the `transcriber.transcribe` method but typically
|
124 |
+
includes the transcribed text.
|
125 |
+
"""
|
126 |
+
if ServeClientOpenVINO.SINGLE_MODEL:
|
127 |
+
ServeClientOpenVINO.SINGLE_MODEL_LOCK.acquire()
|
128 |
+
result = self.transcriber.transcribe(input_sample)
|
129 |
+
if ServeClientOpenVINO.SINGLE_MODEL:
|
130 |
+
ServeClientOpenVINO.SINGLE_MODEL_LOCK.release()
|
131 |
+
return result
|
132 |
+
|
133 |
+
def handle_transcription_output(self, result, duration):
|
134 |
+
"""
|
135 |
+
Handle the transcription output, updating the transcript and sending data to the client.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
result (str): The result from whisper inference i.e. the list of segments.
|
139 |
+
duration (float): Duration of the transcribed audio chunk.
|
140 |
+
"""
|
141 |
+
segments = []
|
142 |
+
if len(result):
|
143 |
+
self.t_start = None
|
144 |
+
last_segment = self.update_segments(result, duration)
|
145 |
+
segments = self.prepare_segments(last_segment)
|
146 |
+
|
147 |
+
if len(segments):
|
148 |
+
self.send_transcription_to_client(segments)
|
whisper_live/backend/trt_backend.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import threading
|
4 |
+
import time
|
5 |
+
|
6 |
+
from whisper_live.backend.base import ServeClientBase
|
7 |
+
from whisper_live.transcriber.transcriber_tensorrt import WhisperTRTLLM
|
8 |
+
|
9 |
+
|
10 |
+
class ServeClientTensorRT(ServeClientBase):
|
11 |
+
SINGLE_MODEL = None
|
12 |
+
SINGLE_MODEL_LOCK = threading.Lock()
|
13 |
+
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
websocket,
|
17 |
+
task="transcribe",
|
18 |
+
multilingual=False,
|
19 |
+
language=None,
|
20 |
+
client_uid=None,
|
21 |
+
model=None,
|
22 |
+
single_model=False,
|
23 |
+
use_py_session=False,
|
24 |
+
max_new_tokens=225,
|
25 |
+
send_last_n_segments=10,
|
26 |
+
no_speech_thresh=0.45,
|
27 |
+
clip_audio=False,
|
28 |
+
same_output_threshold=10,
|
29 |
+
):
|
30 |
+
"""
|
31 |
+
Initialize a ServeClient instance.
|
32 |
+
The Whisper model is initialized based on the client's language and device availability.
|
33 |
+
The transcription thread is started upon initialization. A "SERVER_READY" message is sent
|
34 |
+
to the client to indicate that the server is ready.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
websocket (WebSocket): The WebSocket connection for the client.
|
38 |
+
task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe".
|
39 |
+
device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None.
|
40 |
+
multilingual (bool, optional): Whether the client supports multilingual transcription. Defaults to False.
|
41 |
+
language (str, optional): The language for transcription. Defaults to None.
|
42 |
+
client_uid (str, optional): A unique identifier for the client. Defaults to None.
|
43 |
+
single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False.
|
44 |
+
use_py_session (bool, optional): Use python session or cpp session. Defaults to Cpp Session.
|
45 |
+
max_new_tokens (int, optional): Max number of tokens to generate.
|
46 |
+
send_last_n_segments (int, optional): Number of most recent segments to send to the client. Defaults to 10.
|
47 |
+
no_speech_thresh (float, optional): Segments with no speech probability above this threshold will be discarded. Defaults to 0.45.
|
48 |
+
clip_audio (bool, optional): Whether to clip audio with no valid segments. Defaults to False.
|
49 |
+
same_output_threshold (int, optional): Number of repeated outputs before considering it as a valid segment. Defaults to 10.
|
50 |
+
"""
|
51 |
+
super().__init__(
|
52 |
+
client_uid,
|
53 |
+
websocket,
|
54 |
+
send_last_n_segments,
|
55 |
+
no_speech_thresh,
|
56 |
+
clip_audio,
|
57 |
+
same_output_threshold,
|
58 |
+
)
|
59 |
+
|
60 |
+
self.language = language if multilingual else "en"
|
61 |
+
self.task = task
|
62 |
+
self.eos = False
|
63 |
+
self.max_new_tokens = max_new_tokens
|
64 |
+
|
65 |
+
if single_model:
|
66 |
+
if ServeClientTensorRT.SINGLE_MODEL is None:
|
67 |
+
self.create_model(model, multilingual, use_py_session=use_py_session)
|
68 |
+
ServeClientTensorRT.SINGLE_MODEL = self.transcriber
|
69 |
+
else:
|
70 |
+
self.transcriber = ServeClientTensorRT.SINGLE_MODEL
|
71 |
+
else:
|
72 |
+
self.create_model(model, multilingual, use_py_session=use_py_session)
|
73 |
+
|
74 |
+
# threading
|
75 |
+
self.trans_thread = threading.Thread(target=self.speech_to_text)
|
76 |
+
self.trans_thread.start()
|
77 |
+
|
78 |
+
self.websocket.send(json.dumps({
|
79 |
+
"uid": self.client_uid,
|
80 |
+
"message": self.SERVER_READY,
|
81 |
+
"backend": "tensorrt"
|
82 |
+
}))
|
83 |
+
|
84 |
+
def create_model(self, model, multilingual, warmup=True, use_py_session=False):
|
85 |
+
"""
|
86 |
+
Instantiates a new model, sets it as the transcriber and does warmup if desired.
|
87 |
+
"""
|
88 |
+
self.transcriber = WhisperTRTLLM(
|
89 |
+
model,
|
90 |
+
assets_dir="assets",
|
91 |
+
device="cuda",
|
92 |
+
is_multilingual=multilingual,
|
93 |
+
language=self.language,
|
94 |
+
task=self.task,
|
95 |
+
use_py_session=use_py_session,
|
96 |
+
max_output_len=self.max_new_tokens,
|
97 |
+
)
|
98 |
+
if warmup:
|
99 |
+
self.warmup()
|
100 |
+
|
101 |
+
def warmup(self, warmup_steps=10):
|
102 |
+
"""
|
103 |
+
Warmup TensorRT since first few inferences are slow.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
warmup_steps (int): Number of steps to warm up the model for.
|
107 |
+
"""
|
108 |
+
logging.info("[INFO:] Warming up TensorRT engine..")
|
109 |
+
mel, _ = self.transcriber.log_mel_spectrogram("assets/jfk.flac")
|
110 |
+
for i in range(warmup_steps):
|
111 |
+
self.transcriber.transcribe(mel)
|
112 |
+
|
113 |
+
def set_eos(self, eos):
|
114 |
+
"""
|
115 |
+
Sets the End of Speech (EOS) flag.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
eos (bool): The value to set for the EOS flag.
|
119 |
+
"""
|
120 |
+
self.lock.acquire()
|
121 |
+
self.eos = eos
|
122 |
+
self.lock.release()
|
123 |
+
|
124 |
+
def handle_transcription_output(self, last_segment, duration):
|
125 |
+
"""
|
126 |
+
Handle the transcription output, updating the transcript and sending data to the client.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
last_segment (str): The last segment from the whisper output which is considered to be incomplete because
|
130 |
+
of the possibility of word being truncated.
|
131 |
+
duration (float): Duration of the transcribed audio chunk.
|
132 |
+
"""
|
133 |
+
segments = self.prepare_segments({"text": last_segment})
|
134 |
+
self.send_transcription_to_client(segments)
|
135 |
+
if self.eos:
|
136 |
+
self.update_timestamp_offset(last_segment, duration)
|
137 |
+
|
138 |
+
def transcribe_audio(self, input_bytes):
|
139 |
+
"""
|
140 |
+
Transcribe the audio chunk and send the results to the client.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
input_bytes (np.array): The audio chunk to transcribe.
|
144 |
+
"""
|
145 |
+
if ServeClientTensorRT.SINGLE_MODEL:
|
146 |
+
ServeClientTensorRT.SINGLE_MODEL_LOCK.acquire()
|
147 |
+
logging.info(f"[WhisperTensorRT:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
|
148 |
+
mel, duration = self.transcriber.log_mel_spectrogram(input_bytes)
|
149 |
+
last_segment = self.transcriber.transcribe(
|
150 |
+
mel,
|
151 |
+
text_prefix=f"<|startoftranscript|><|{self.language}|><|{self.task}|><|notimestamps|>",
|
152 |
+
)
|
153 |
+
if ServeClientTensorRT.SINGLE_MODEL:
|
154 |
+
ServeClientTensorRT.SINGLE_MODEL_LOCK.release()
|
155 |
+
if last_segment:
|
156 |
+
self.handle_transcription_output(last_segment, duration)
|
157 |
+
|
158 |
+
def update_timestamp_offset(self, last_segment, duration):
|
159 |
+
"""
|
160 |
+
Update timestamp offset and transcript.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
last_segment (str): Last transcribed audio from the whisper model.
|
164 |
+
duration (float): Duration of the last audio chunk.
|
165 |
+
"""
|
166 |
+
if not len(self.transcript):
|
167 |
+
self.transcript.append({"text": last_segment + " "})
|
168 |
+
elif self.transcript[-1]["text"].strip() != last_segment:
|
169 |
+
self.transcript.append({"text": last_segment + " "})
|
170 |
+
|
171 |
+
with self.lock:
|
172 |
+
self.timestamp_offset += duration
|
173 |
+
|
174 |
+
def speech_to_text(self):
|
175 |
+
"""
|
176 |
+
Process an audio stream in an infinite loop, continuously transcribing the speech.
|
177 |
+
|
178 |
+
This method continuously receives audio frames, performs real-time transcription, and sends
|
179 |
+
transcribed segments to the client via a WebSocket connection.
|
180 |
+
|
181 |
+
If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction.
|
182 |
+
It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments
|
183 |
+
are sent to the client in real-time, and a history of segments is maintained to provide context.
|
184 |
+
|
185 |
+
Raises:
|
186 |
+
Exception: If there is an issue with audio processing or WebSocket communication.
|
187 |
+
|
188 |
+
"""
|
189 |
+
while True:
|
190 |
+
if self.exit:
|
191 |
+
logging.info("Exiting speech to text thread")
|
192 |
+
break
|
193 |
+
|
194 |
+
if self.frames_np is None:
|
195 |
+
time.sleep(0.02) # wait for any audio to arrive
|
196 |
+
continue
|
197 |
+
|
198 |
+
self.clip_audio_if_no_valid_segment()
|
199 |
+
|
200 |
+
input_bytes, duration = self.get_audio_chunk_for_processing()
|
201 |
+
if duration < 0.4:
|
202 |
+
continue
|
203 |
+
|
204 |
+
try:
|
205 |
+
input_sample = input_bytes.copy()
|
206 |
+
logging.info(f"[WhisperTensorRT:] Processing audio with duration: {duration}")
|
207 |
+
self.transcribe_audio(input_sample)
|
208 |
+
|
209 |
+
except Exception as e:
|
210 |
+
logging.error(f"[ERROR]: {e}")
|
whisper_live/client.py
ADDED
@@ -0,0 +1,782 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import wave
|
4 |
+
|
5 |
+
import logging
|
6 |
+
import numpy as np
|
7 |
+
import pyaudio
|
8 |
+
import threading
|
9 |
+
import json
|
10 |
+
import websocket
|
11 |
+
import uuid
|
12 |
+
import time
|
13 |
+
import av
|
14 |
+
import whisper_live.utils as utils
|
15 |
+
|
16 |
+
|
17 |
+
class Client:
|
18 |
+
"""
|
19 |
+
Handles communication with a server using WebSocket.
|
20 |
+
"""
|
21 |
+
INSTANCES = {}
|
22 |
+
END_OF_AUDIO = "END_OF_AUDIO"
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
host=None,
|
27 |
+
port=None,
|
28 |
+
lang=None,
|
29 |
+
translate=False,
|
30 |
+
model="small",
|
31 |
+
srt_file_path="output.srt",
|
32 |
+
use_vad=True,
|
33 |
+
use_wss=False,
|
34 |
+
log_transcription=True,
|
35 |
+
max_clients=4,
|
36 |
+
max_connection_time=600,
|
37 |
+
send_last_n_segments=10,
|
38 |
+
no_speech_thresh=0.45,
|
39 |
+
clip_audio=False,
|
40 |
+
same_output_threshold=10,
|
41 |
+
transcription_callback=None,
|
42 |
+
):
|
43 |
+
"""
|
44 |
+
Initializes a Client instance for audio recording and streaming to a server.
|
45 |
+
|
46 |
+
If host and port are not provided, the WebSocket connection will not be established.
|
47 |
+
When translate is True, the task will be set to "translate" instead of "transcribe".
|
48 |
+
he audio recording starts immediately upon initialization.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
host (str): The hostname or IP address of the server.
|
52 |
+
port (int): The port number for the WebSocket server.
|
53 |
+
lang (str, optional): The selected language for transcription. Default is None.
|
54 |
+
translate (bool, optional): Specifies if the task is translation. Default is False.
|
55 |
+
model (str, optional): The whisper model to use (e.g., "small", "medium", "large"). Default is "small".
|
56 |
+
srt_file_path (str, optional): The file path to save the output SRT file. Default is "output.srt".
|
57 |
+
use_vad (bool, optional): Whether to enable voice activity detection. Default is True.
|
58 |
+
log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
|
59 |
+
max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
|
60 |
+
max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
|
61 |
+
send_last_n_segments (int, optional): Number of most recent segments to send to the client. Defaults to 10.
|
62 |
+
no_speech_thresh (float, optional): Segments with no speech probability above this threshold will be discarded. Defaults to 0.45.
|
63 |
+
clip_audio (bool, optional): Whether to clip audio with no valid segments. Defaults to False.
|
64 |
+
same_output_threshold (int, optional): Number of repeated outputs before considering it as a valid segment. Defaults to 10.
|
65 |
+
transcription_callback (callable, optional): A callback function to handle transcription results. Default is None.
|
66 |
+
"""
|
67 |
+
self.recording = False
|
68 |
+
self.task = "transcribe"
|
69 |
+
self.uid = str(uuid.uuid4())
|
70 |
+
self.waiting = False
|
71 |
+
self.last_response_received = None
|
72 |
+
self.disconnect_if_no_response_for = 15
|
73 |
+
self.language = lang
|
74 |
+
self.model = model
|
75 |
+
self.server_error = False
|
76 |
+
self.srt_file_path = srt_file_path
|
77 |
+
self.use_vad = use_vad
|
78 |
+
self.use_wss = use_wss
|
79 |
+
self.last_segment = None
|
80 |
+
self.last_received_segment = None
|
81 |
+
self.log_transcription = log_transcription
|
82 |
+
self.max_clients = max_clients
|
83 |
+
self.max_connection_time = max_connection_time
|
84 |
+
self.send_last_n_segments = send_last_n_segments
|
85 |
+
self.no_speech_thresh = no_speech_thresh
|
86 |
+
self.clip_audio = clip_audio
|
87 |
+
self.same_output_threshold = same_output_threshold
|
88 |
+
self.transcription_callback = transcription_callback
|
89 |
+
|
90 |
+
if translate:
|
91 |
+
self.task = "translate"
|
92 |
+
|
93 |
+
self.audio_bytes = None
|
94 |
+
|
95 |
+
if host is not None and port is not None:
|
96 |
+
socket_protocol = 'wss' if self.use_wss else "ws"
|
97 |
+
socket_url = f"{socket_protocol}://{host}:{port}"
|
98 |
+
self.client_socket = websocket.WebSocketApp(
|
99 |
+
socket_url,
|
100 |
+
on_open=lambda ws: self.on_open(ws),
|
101 |
+
on_message=lambda ws, message: self.on_message(ws, message),
|
102 |
+
on_error=lambda ws, error: self.on_error(ws, error),
|
103 |
+
on_close=lambda ws, close_status_code, close_msg: self.on_close(
|
104 |
+
ws, close_status_code, close_msg
|
105 |
+
),
|
106 |
+
)
|
107 |
+
else:
|
108 |
+
print("[ERROR]: No host or port specified.")
|
109 |
+
return
|
110 |
+
|
111 |
+
Client.INSTANCES[self.uid] = self
|
112 |
+
|
113 |
+
# start websocket client in a thread
|
114 |
+
self.ws_thread = threading.Thread(target=self.client_socket.run_forever)
|
115 |
+
self.ws_thread.daemon = True
|
116 |
+
self.ws_thread.start()
|
117 |
+
|
118 |
+
self.transcript = []
|
119 |
+
print("[INFO]: * recording")
|
120 |
+
|
121 |
+
def handle_status_messages(self, message_data):
|
122 |
+
"""Handles server status messages."""
|
123 |
+
status = message_data["status"]
|
124 |
+
if status == "WAIT":
|
125 |
+
self.waiting = True
|
126 |
+
print(f"[INFO]: Server is full. Estimated wait time {round(message_data['message'])} minutes.")
|
127 |
+
elif status == "ERROR":
|
128 |
+
print(f"Message from Server: {message_data['message']}")
|
129 |
+
self.server_error = True
|
130 |
+
elif status == "WARNING":
|
131 |
+
print(f"Message from Server: {message_data['message']}")
|
132 |
+
|
133 |
+
def process_segments(self, segments):
|
134 |
+
"""Processes transcript segments."""
|
135 |
+
text = []
|
136 |
+
for i, seg in enumerate(segments):
|
137 |
+
if not text or text[-1] != seg["text"]:
|
138 |
+
text.append(seg["text"])
|
139 |
+
if i == len(segments) - 1 and not seg.get("completed", False):
|
140 |
+
self.last_segment = seg
|
141 |
+
elif (self.server_backend == "faster_whisper" and seg.get("completed", False) and
|
142 |
+
(not self.transcript or
|
143 |
+
float(seg['start']) >= float(self.transcript[-1]['end']))):
|
144 |
+
self.transcript.append(seg)
|
145 |
+
# update last received segment and last valid response time
|
146 |
+
if self.last_received_segment is None or self.last_received_segment != segments[-1]["text"]:
|
147 |
+
self.last_response_received = time.time()
|
148 |
+
self.last_received_segment = segments[-1]["text"]
|
149 |
+
|
150 |
+
# call the transcription callback if provided
|
151 |
+
if self.transcription_callback and callable(self.transcription_callback):
|
152 |
+
try:
|
153 |
+
self.transcription_callback(" ".join(text), segments) # string, list
|
154 |
+
except Exception as e:
|
155 |
+
print(f"[WARN] transcription_callback raised: {e}")
|
156 |
+
return
|
157 |
+
|
158 |
+
if self.log_transcription:
|
159 |
+
# Truncate to last 3 entries for brevity.
|
160 |
+
text = text[-3:]
|
161 |
+
utils.clear_screen()
|
162 |
+
utils.print_transcript(text)
|
163 |
+
|
164 |
+
def on_message(self, ws, message):
|
165 |
+
"""
|
166 |
+
Callback function called when a message is received from the server.
|
167 |
+
|
168 |
+
It updates various attributes of the client based on the received message, including
|
169 |
+
recording status, language detection, and server messages. If a disconnect message
|
170 |
+
is received, it sets the recording status to False.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
ws (websocket.WebSocketApp): The WebSocket client instance.
|
174 |
+
message (str): The received message from the server.
|
175 |
+
|
176 |
+
"""
|
177 |
+
message = json.loads(message)
|
178 |
+
|
179 |
+
if self.uid != message.get("uid"):
|
180 |
+
print("[ERROR]: invalid client uid")
|
181 |
+
return
|
182 |
+
|
183 |
+
if "status" in message.keys():
|
184 |
+
self.handle_status_messages(message)
|
185 |
+
return
|
186 |
+
|
187 |
+
if "message" in message.keys() and message["message"] == "DISCONNECT":
|
188 |
+
print("[INFO]: Server disconnected due to overtime.")
|
189 |
+
self.recording = False
|
190 |
+
|
191 |
+
if "message" in message.keys() and message["message"] == "SERVER_READY":
|
192 |
+
self.last_response_received = time.time()
|
193 |
+
self.recording = True
|
194 |
+
self.server_backend = message["backend"]
|
195 |
+
print(f"[INFO]: Server Running with backend {self.server_backend}")
|
196 |
+
return
|
197 |
+
|
198 |
+
if "language" in message.keys():
|
199 |
+
self.language = message.get("language")
|
200 |
+
lang_prob = message.get("language_prob")
|
201 |
+
print(
|
202 |
+
f"[INFO]: Server detected language {self.language} with probability {lang_prob}"
|
203 |
+
)
|
204 |
+
return
|
205 |
+
|
206 |
+
if "segments" in message.keys():
|
207 |
+
self.process_segments(message["segments"])
|
208 |
+
|
209 |
+
def on_error(self, ws, error):
|
210 |
+
print(f"[ERROR] WebSocket Error: {error}")
|
211 |
+
self.server_error = True
|
212 |
+
self.error_message = error
|
213 |
+
|
214 |
+
def on_close(self, ws, close_status_code, close_msg):
|
215 |
+
print(f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}")
|
216 |
+
self.recording = False
|
217 |
+
self.waiting = False
|
218 |
+
|
219 |
+
def on_open(self, ws):
|
220 |
+
"""
|
221 |
+
Callback function called when the WebSocket connection is successfully opened.
|
222 |
+
|
223 |
+
Sends an initial configuration message to the server, including client UID,
|
224 |
+
language selection, and task type.
|
225 |
+
|
226 |
+
Args:
|
227 |
+
ws (websocket.WebSocketApp): The WebSocket client instance.
|
228 |
+
|
229 |
+
"""
|
230 |
+
print("[INFO]: Opened connection")
|
231 |
+
ws.send(
|
232 |
+
json.dumps(
|
233 |
+
{
|
234 |
+
"uid": self.uid,
|
235 |
+
"language": self.language,
|
236 |
+
"task": self.task,
|
237 |
+
"model": self.model,
|
238 |
+
"use_vad": self.use_vad,
|
239 |
+
"max_clients": self.max_clients,
|
240 |
+
"max_connection_time": self.max_connection_time,
|
241 |
+
"send_last_n_segments": self.send_last_n_segments,
|
242 |
+
"no_speech_thresh": self.no_speech_thresh,
|
243 |
+
"clip_audio": self.clip_audio,
|
244 |
+
"same_output_threshold": self.same_output_threshold,
|
245 |
+
}
|
246 |
+
)
|
247 |
+
)
|
248 |
+
|
249 |
+
def send_packet_to_server(self, message):
|
250 |
+
"""
|
251 |
+
Send an audio packet to the server using WebSocket.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
message (bytes): The audio data packet in bytes to be sent to the server.
|
255 |
+
|
256 |
+
"""
|
257 |
+
try:
|
258 |
+
self.client_socket.send(message, websocket.ABNF.OPCODE_BINARY)
|
259 |
+
except Exception as e:
|
260 |
+
print(e)
|
261 |
+
|
262 |
+
def close_websocket(self):
|
263 |
+
"""
|
264 |
+
Close the WebSocket connection and join the WebSocket thread.
|
265 |
+
|
266 |
+
First attempts to close the WebSocket connection using `self.client_socket.close()`. After
|
267 |
+
closing the connection, it joins the WebSocket thread to ensure proper termination.
|
268 |
+
|
269 |
+
"""
|
270 |
+
try:
|
271 |
+
self.client_socket.close()
|
272 |
+
except Exception as e:
|
273 |
+
print("[ERROR]: Error closing WebSocket:", e)
|
274 |
+
|
275 |
+
try:
|
276 |
+
self.ws_thread.join()
|
277 |
+
except Exception as e:
|
278 |
+
print("[ERROR:] Error joining WebSocket thread:", e)
|
279 |
+
|
280 |
+
def get_client_socket(self):
|
281 |
+
"""
|
282 |
+
Get the WebSocket client socket instance.
|
283 |
+
|
284 |
+
Returns:
|
285 |
+
WebSocketApp: The WebSocket client socket instance currently in use by the client.
|
286 |
+
"""
|
287 |
+
return self.client_socket
|
288 |
+
|
289 |
+
def write_srt_file(self, output_path="output.srt"):
|
290 |
+
"""
|
291 |
+
Writes out the transcript in .srt format.
|
292 |
+
|
293 |
+
Args:
|
294 |
+
message (output_path, optional): The path to the target file. Default is "output.srt".
|
295 |
+
|
296 |
+
"""
|
297 |
+
if self.server_backend == "faster_whisper":
|
298 |
+
if not self.transcript and self.last_segment is not None:
|
299 |
+
self.transcript.append(self.last_segment)
|
300 |
+
elif self.last_segment and self.transcript[-1]["text"] != self.last_segment["text"]:
|
301 |
+
self.transcript.append(self.last_segment)
|
302 |
+
utils.create_srt_file(self.transcript, output_path)
|
303 |
+
|
304 |
+
def wait_before_disconnect(self):
|
305 |
+
"""Waits a bit before disconnecting in order to process pending responses."""
|
306 |
+
assert self.last_response_received
|
307 |
+
while time.time() - self.last_response_received < self.disconnect_if_no_response_for:
|
308 |
+
continue
|
309 |
+
|
310 |
+
|
311 |
+
class TranscriptionTeeClient:
|
312 |
+
"""
|
313 |
+
Client for handling audio recording, streaming, and transcription tasks via one or more
|
314 |
+
WebSocket connections.
|
315 |
+
|
316 |
+
Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
|
317 |
+
to send audio data for transcription to one or more servers, and receive transcribed text segments.
|
318 |
+
Args:
|
319 |
+
clients (list): one or more previously initialized Client instances
|
320 |
+
|
321 |
+
Attributes:
|
322 |
+
clients (list): the underlying Client instances responsible for handling WebSocket connections.
|
323 |
+
"""
|
324 |
+
def __init__(self, clients, save_output_recording=False, output_recording_filename="./output_recording.wav", mute_audio_playback=False):
|
325 |
+
self.clients = clients
|
326 |
+
if not self.clients:
|
327 |
+
raise Exception("At least one client is required.")
|
328 |
+
self.chunk = 4096
|
329 |
+
self.format = pyaudio.paInt16
|
330 |
+
self.channels = 1
|
331 |
+
self.rate = 16000
|
332 |
+
self.record_seconds = 60000
|
333 |
+
self.save_output_recording = save_output_recording
|
334 |
+
self.output_recording_filename = output_recording_filename
|
335 |
+
self.mute_audio_playback = mute_audio_playback
|
336 |
+
self.frames = b""
|
337 |
+
self.p = pyaudio.PyAudio()
|
338 |
+
try:
|
339 |
+
self.stream = self.p.open(
|
340 |
+
format=self.format,
|
341 |
+
channels=self.channels,
|
342 |
+
rate=self.rate,
|
343 |
+
input=True,
|
344 |
+
frames_per_buffer=self.chunk,
|
345 |
+
)
|
346 |
+
except OSError as error:
|
347 |
+
print(f"[WARN]: Unable to access microphone. {error}")
|
348 |
+
self.stream = None
|
349 |
+
|
350 |
+
def __call__(self, audio=None, rtsp_url=None, hls_url=None, save_file=None):
|
351 |
+
"""
|
352 |
+
Start the transcription process.
|
353 |
+
|
354 |
+
Initiates the transcription process by connecting to the server via a WebSocket. It waits for the server
|
355 |
+
to be ready to receive audio data and then sends audio for transcription. If an audio file is provided, it
|
356 |
+
will be played and streamed to the server; otherwise, it will perform live recording.
|
357 |
+
|
358 |
+
Args:
|
359 |
+
audio (str, optional): Path to an audio file for transcription. Default is None, which triggers live recording.
|
360 |
+
|
361 |
+
"""
|
362 |
+
assert sum(
|
363 |
+
source is not None for source in [audio, rtsp_url, hls_url]
|
364 |
+
) <= 1, 'You must provide only one selected source'
|
365 |
+
|
366 |
+
print("[INFO]: Waiting for server ready ...")
|
367 |
+
for client in self.clients:
|
368 |
+
while not client.recording:
|
369 |
+
if client.waiting or client.server_error:
|
370 |
+
self.close_all_clients()
|
371 |
+
return
|
372 |
+
|
373 |
+
print("[INFO]: Server Ready!")
|
374 |
+
if hls_url is not None:
|
375 |
+
self.process_hls_stream(hls_url, save_file)
|
376 |
+
elif audio is not None:
|
377 |
+
resampled_file = utils.resample(audio)
|
378 |
+
self.play_file(resampled_file)
|
379 |
+
elif rtsp_url is not None:
|
380 |
+
self.process_rtsp_stream(rtsp_url)
|
381 |
+
else:
|
382 |
+
self.record()
|
383 |
+
|
384 |
+
def close_all_clients(self):
|
385 |
+
"""Closes all client websockets."""
|
386 |
+
for client in self.clients:
|
387 |
+
client.close_websocket()
|
388 |
+
|
389 |
+
def write_all_clients_srt(self):
|
390 |
+
"""Writes out .srt files for all clients."""
|
391 |
+
for client in self.clients:
|
392 |
+
client.write_srt_file(client.srt_file_path)
|
393 |
+
|
394 |
+
def multicast_packet(self, packet, unconditional=False):
|
395 |
+
"""
|
396 |
+
Sends an identical packet via all clients.
|
397 |
+
|
398 |
+
Args:
|
399 |
+
packet (bytes): The audio data packet in bytes to be sent.
|
400 |
+
unconditional (bool, optional): If true, send regardless of whether clients are recording. Default is False.
|
401 |
+
"""
|
402 |
+
for client in self.clients:
|
403 |
+
if (unconditional or client.recording):
|
404 |
+
client.send_packet_to_server(packet)
|
405 |
+
|
406 |
+
def play_file(self, filename):
|
407 |
+
"""
|
408 |
+
Play an audio file and send it to the server for processing.
|
409 |
+
|
410 |
+
Reads an audio file, plays it through the audio output, and simultaneously sends
|
411 |
+
the audio data to the server for processing. It uses PyAudio to create an audio
|
412 |
+
stream for playback. The audio data is read from the file in chunks, converted to
|
413 |
+
floating-point format, and sent to the server using WebSocket communication.
|
414 |
+
This method is typically used when you want to process pre-recorded audio and send it
|
415 |
+
to the server in real-time.
|
416 |
+
|
417 |
+
Args:
|
418 |
+
filename (str): The path to the audio file to be played and sent to the server.
|
419 |
+
"""
|
420 |
+
|
421 |
+
# read audio and create pyaudio stream
|
422 |
+
with wave.open(filename, "rb") as wavfile:
|
423 |
+
self.stream = self.p.open(
|
424 |
+
format=self.p.get_format_from_width(wavfile.getsampwidth()),
|
425 |
+
channels=wavfile.getnchannels(),
|
426 |
+
rate=wavfile.getframerate(),
|
427 |
+
input=True,
|
428 |
+
output=True,
|
429 |
+
frames_per_buffer=self.chunk,
|
430 |
+
)
|
431 |
+
chunk_duration = self.chunk / float(wavfile.getframerate())
|
432 |
+
try:
|
433 |
+
while any(client.recording for client in self.clients):
|
434 |
+
data = wavfile.readframes(self.chunk)
|
435 |
+
if data == b"":
|
436 |
+
break
|
437 |
+
|
438 |
+
audio_array = self.bytes_to_float_array(data)
|
439 |
+
self.multicast_packet(audio_array.tobytes())
|
440 |
+
if self.mute_audio_playback:
|
441 |
+
time.sleep(chunk_duration)
|
442 |
+
else:
|
443 |
+
self.stream.write(data)
|
444 |
+
|
445 |
+
wavfile.close()
|
446 |
+
|
447 |
+
for client in self.clients:
|
448 |
+
client.wait_before_disconnect()
|
449 |
+
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
450 |
+
self.write_all_clients_srt()
|
451 |
+
self.stream.close()
|
452 |
+
self.close_all_clients()
|
453 |
+
|
454 |
+
except KeyboardInterrupt:
|
455 |
+
wavfile.close()
|
456 |
+
self.stream.stop_stream()
|
457 |
+
self.stream.close()
|
458 |
+
self.p.terminate()
|
459 |
+
self.close_all_clients()
|
460 |
+
self.write_all_clients_srt()
|
461 |
+
print("[INFO]: Keyboard interrupt.")
|
462 |
+
|
463 |
+
def process_rtsp_stream(self, rtsp_url):
|
464 |
+
"""
|
465 |
+
Connect to an RTSP source, process the audio stream, and send it for transcription.
|
466 |
+
|
467 |
+
Args:
|
468 |
+
rtsp_url (str): The URL of the RTSP stream source.
|
469 |
+
"""
|
470 |
+
print("[INFO]: Connecting to RTSP stream...")
|
471 |
+
try:
|
472 |
+
container = av.open(rtsp_url, format="rtsp", options={"rtsp_transport": "tcp"})
|
473 |
+
self.process_av_stream(container, stream_type="RTSP")
|
474 |
+
except Exception as e:
|
475 |
+
print(f"[ERROR]: Failed to process RTSP stream: {e}")
|
476 |
+
finally:
|
477 |
+
for client in self.clients:
|
478 |
+
client.wait_before_disconnect()
|
479 |
+
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
480 |
+
self.close_all_clients()
|
481 |
+
self.write_all_clients_srt()
|
482 |
+
print("[INFO]: RTSP stream processing finished.")
|
483 |
+
|
484 |
+
def process_hls_stream(self, hls_url, save_file=None):
|
485 |
+
"""
|
486 |
+
Connect to an HLS source, process the audio stream, and send it for transcription.
|
487 |
+
|
488 |
+
Args:
|
489 |
+
hls_url (str): The URL of the HLS stream source.
|
490 |
+
save_file (str, optional): Local path to save the network stream.
|
491 |
+
"""
|
492 |
+
print("[INFO]: Connecting to HLS stream...")
|
493 |
+
try:
|
494 |
+
container = av.open(hls_url, format="hls")
|
495 |
+
self.process_av_stream(container, stream_type="HLS", save_file=save_file)
|
496 |
+
except Exception as e:
|
497 |
+
print(f"[ERROR]: Failed to process HLS stream: {e}")
|
498 |
+
finally:
|
499 |
+
for client in self.clients:
|
500 |
+
client.wait_before_disconnect()
|
501 |
+
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
502 |
+
self.close_all_clients()
|
503 |
+
self.write_all_clients_srt()
|
504 |
+
print("[INFO]: HLS stream processing finished.")
|
505 |
+
|
506 |
+
def process_av_stream(self, container, stream_type, save_file=None):
|
507 |
+
"""
|
508 |
+
Process an AV container stream and send audio packets to the server.
|
509 |
+
|
510 |
+
Args:
|
511 |
+
container (av.container.InputContainer): The input container to process.
|
512 |
+
stream_type (str): The type of stream being processed ("RTSP" or "HLS").
|
513 |
+
save_file (str, optional): Local path to save the stream. Default is None.
|
514 |
+
"""
|
515 |
+
audio_stream = next((s for s in container.streams if s.type == "audio"), None)
|
516 |
+
if not audio_stream:
|
517 |
+
print(f"[ERROR]: No audio stream found in {stream_type} source.")
|
518 |
+
return
|
519 |
+
|
520 |
+
output_container = None
|
521 |
+
if save_file:
|
522 |
+
output_container = av.open(save_file, mode="w")
|
523 |
+
output_audio_stream = output_container.add_stream(codec_name="pcm_s16le", rate=self.rate)
|
524 |
+
|
525 |
+
try:
|
526 |
+
for packet in container.demux(audio_stream):
|
527 |
+
for frame in packet.decode():
|
528 |
+
audio_data = frame.to_ndarray().tobytes()
|
529 |
+
self.multicast_packet(audio_data)
|
530 |
+
|
531 |
+
if save_file:
|
532 |
+
output_container.mux(frame)
|
533 |
+
except Exception as e:
|
534 |
+
print(f"[ERROR]: Error during {stream_type} stream processing: {e}")
|
535 |
+
finally:
|
536 |
+
# Wait for server to send any leftover transcription.
|
537 |
+
time.sleep(5)
|
538 |
+
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
539 |
+
if output_container:
|
540 |
+
output_container.close()
|
541 |
+
container.close()
|
542 |
+
|
543 |
+
def save_chunk(self, n_audio_file):
|
544 |
+
"""
|
545 |
+
Saves the current audio frames to a WAV file in a separate thread.
|
546 |
+
|
547 |
+
Args:
|
548 |
+
n_audio_file (int): The index of the audio file which determines the filename.
|
549 |
+
This helps in maintaining the order and uniqueness of each chunk.
|
550 |
+
"""
|
551 |
+
t = threading.Thread(
|
552 |
+
target=self.write_audio_frames_to_file,
|
553 |
+
args=(self.frames[:], f"chunks/{n_audio_file}.wav",),
|
554 |
+
)
|
555 |
+
t.start()
|
556 |
+
|
557 |
+
def finalize_recording(self, n_audio_file):
|
558 |
+
"""
|
559 |
+
Finalizes the recording process by saving any remaining audio frames,
|
560 |
+
closing the audio stream, and terminating the process.
|
561 |
+
|
562 |
+
Args:
|
563 |
+
n_audio_file (int): The file index to be used if there are remaining audio frames to be saved.
|
564 |
+
This index is incremented before use if the last chunk is saved.
|
565 |
+
"""
|
566 |
+
if self.save_output_recording and len(self.frames):
|
567 |
+
self.write_audio_frames_to_file(
|
568 |
+
self.frames[:], f"chunks/{n_audio_file}.wav"
|
569 |
+
)
|
570 |
+
n_audio_file += 1
|
571 |
+
self.stream.stop_stream()
|
572 |
+
self.stream.close()
|
573 |
+
self.p.terminate()
|
574 |
+
self.close_all_clients()
|
575 |
+
if self.save_output_recording:
|
576 |
+
self.write_output_recording(n_audio_file)
|
577 |
+
self.write_all_clients_srt()
|
578 |
+
|
579 |
+
def record(self):
|
580 |
+
"""
|
581 |
+
Record audio data from the input stream and save it to a WAV file.
|
582 |
+
|
583 |
+
Continuously records audio data from the input stream, sends it to the server via a WebSocket
|
584 |
+
connection, and simultaneously saves it to multiple WAV files in chunks. It stops recording when
|
585 |
+
the `RECORD_SECONDS` duration is reached or when the `RECORDING` flag is set to `False`.
|
586 |
+
|
587 |
+
Audio data is saved in chunks to the "chunks" directory. Each chunk is saved as a separate WAV file.
|
588 |
+
The recording will continue until the specified duration is reached or until the `RECORDING` flag is set to `False`.
|
589 |
+
The recording process can be interrupted by sending a KeyboardInterrupt (e.g., pressing Ctrl+C). After recording,
|
590 |
+
the method combines all the saved audio chunks into the specified `out_file`.
|
591 |
+
"""
|
592 |
+
n_audio_file = 0
|
593 |
+
if self.save_output_recording:
|
594 |
+
if os.path.exists("chunks"):
|
595 |
+
shutil.rmtree("chunks")
|
596 |
+
os.makedirs("chunks")
|
597 |
+
try:
|
598 |
+
for _ in range(0, int(self.rate / self.chunk * self.record_seconds)):
|
599 |
+
if not any(client.recording for client in self.clients):
|
600 |
+
break
|
601 |
+
data = self.stream.read(self.chunk, exception_on_overflow=False)
|
602 |
+
self.frames += data
|
603 |
+
|
604 |
+
audio_array = self.bytes_to_float_array(data)
|
605 |
+
|
606 |
+
self.multicast_packet(audio_array.tobytes())
|
607 |
+
|
608 |
+
# save frames if more than a minute
|
609 |
+
if len(self.frames) > 60 * self.rate:
|
610 |
+
if self.save_output_recording:
|
611 |
+
self.save_chunk(n_audio_file)
|
612 |
+
n_audio_file += 1
|
613 |
+
self.frames = b""
|
614 |
+
self.write_all_clients_srt()
|
615 |
+
|
616 |
+
except KeyboardInterrupt:
|
617 |
+
self.finalize_recording(n_audio_file)
|
618 |
+
|
619 |
+
def write_audio_frames_to_file(self, frames, file_name):
|
620 |
+
"""
|
621 |
+
Write audio frames to a WAV file.
|
622 |
+
|
623 |
+
The WAV file is created or overwritten with the specified name. The audio frames should be
|
624 |
+
in the correct format and match the specified channel, sample width, and sample rate.
|
625 |
+
|
626 |
+
Args:
|
627 |
+
frames (bytes): The audio frames to be written to the file.
|
628 |
+
file_name (str): The name of the WAV file to which the frames will be written.
|
629 |
+
|
630 |
+
"""
|
631 |
+
with wave.open(file_name, "wb") as wavfile:
|
632 |
+
wavfile: wave.Wave_write
|
633 |
+
wavfile.setnchannels(self.channels)
|
634 |
+
wavfile.setsampwidth(2)
|
635 |
+
wavfile.setframerate(self.rate)
|
636 |
+
wavfile.writeframes(frames)
|
637 |
+
|
638 |
+
def write_output_recording(self, n_audio_file):
|
639 |
+
"""
|
640 |
+
Combine and save recorded audio chunks into a single WAV file.
|
641 |
+
|
642 |
+
The individual audio chunk files are expected to be located in the "chunks" directory. Reads each chunk
|
643 |
+
file, appends its audio data to the final recording, and then deletes the chunk file. After combining
|
644 |
+
and saving, the final recording is stored in the specified `out_file`.
|
645 |
+
|
646 |
+
|
647 |
+
Args:
|
648 |
+
n_audio_file (int): The number of audio chunk files to combine.
|
649 |
+
out_file (str): The name of the output WAV file to save the final recording.
|
650 |
+
|
651 |
+
"""
|
652 |
+
input_files = [
|
653 |
+
f"chunks/{i}.wav"
|
654 |
+
for i in range(n_audio_file)
|
655 |
+
if os.path.exists(f"chunks/{i}.wav")
|
656 |
+
]
|
657 |
+
with wave.open(self.output_recording_filename, "wb") as wavfile:
|
658 |
+
wavfile: wave.Wave_write
|
659 |
+
wavfile.setnchannels(self.channels)
|
660 |
+
wavfile.setsampwidth(2)
|
661 |
+
wavfile.setframerate(self.rate)
|
662 |
+
for in_file in input_files:
|
663 |
+
with wave.open(in_file, "rb") as wav_in:
|
664 |
+
while True:
|
665 |
+
data = wav_in.readframes(self.chunk)
|
666 |
+
if data == b"":
|
667 |
+
break
|
668 |
+
wavfile.writeframes(data)
|
669 |
+
# remove this file
|
670 |
+
os.remove(in_file)
|
671 |
+
wavfile.close()
|
672 |
+
# clean up temporary directory to store chunks
|
673 |
+
if os.path.exists("chunks"):
|
674 |
+
shutil.rmtree("chunks")
|
675 |
+
|
676 |
+
@staticmethod
|
677 |
+
def bytes_to_float_array(audio_bytes):
|
678 |
+
"""
|
679 |
+
Convert audio data from bytes to a NumPy float array.
|
680 |
+
|
681 |
+
It assumes that the audio data is in 16-bit PCM format. The audio data is normalized to
|
682 |
+
have values between -1 and 1.
|
683 |
+
|
684 |
+
Args:
|
685 |
+
audio_bytes (bytes): Audio data in bytes.
|
686 |
+
|
687 |
+
Returns:
|
688 |
+
np.ndarray: A NumPy array containing the audio data as float values normalized between -1 and 1.
|
689 |
+
"""
|
690 |
+
raw_data = np.frombuffer(buffer=audio_bytes, dtype=np.int16)
|
691 |
+
return raw_data.astype(np.float32) / 32768.0
|
692 |
+
|
693 |
+
|
694 |
+
class TranscriptionClient(TranscriptionTeeClient):
|
695 |
+
"""
|
696 |
+
Client for handling audio transcription tasks via a single WebSocket connection.
|
697 |
+
|
698 |
+
Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
|
699 |
+
to send audio data for transcription to a server and receive transcribed text segments.
|
700 |
+
|
701 |
+
Args:
|
702 |
+
host (str): The hostname or IP address of the server.
|
703 |
+
port (int): The port number to connect to on the server.
|
704 |
+
lang (str, optional): The primary language for transcription. Default is None, which defaults to English ('en').
|
705 |
+
translate (bool, optional): If True, the task will be translation instead of transcription. Default is False.
|
706 |
+
model (str, optional): The whisper model to use (e.g., "small", "base"). Default is "small".
|
707 |
+
use_vad (bool, optional): Whether to enable voice activity detection. Default is True.
|
708 |
+
save_output_recording (bool, optional): Whether to save the microphone recording. Default is False.
|
709 |
+
output_recording_filename (str, optional): Path to save the output recording WAV file. Default is "./output_recording.wav".
|
710 |
+
output_transcription_path (str, optional): File path to save the output transcription (SRT file). Default is "./output.srt".
|
711 |
+
log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
|
712 |
+
max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
|
713 |
+
max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
|
714 |
+
mute_audio_playback (bool, optional): If True, mutes audio playback during file playback. Default is False.
|
715 |
+
send_last_n_segments (int, optional): Number of most recent segments to send to the client. Defaults to 10.
|
716 |
+
no_speech_thresh (float, optional): Segments with no speech probability above this threshold will be discarded. Defaults to 0.45.
|
717 |
+
clip_audio (bool, optional): Whether to clip audio with no valid segments. Defaults to False.
|
718 |
+
same_output_threshold (int, optional): Number of repeated outputs before considering it as a valid segment. Defaults to 10.
|
719 |
+
transcription_callback (callable, optional): A callback function to handle transcription results. Default is None.
|
720 |
+
|
721 |
+
Attributes:
|
722 |
+
client (Client): An instance of the underlying Client class responsible for handling the WebSocket connection.
|
723 |
+
|
724 |
+
Example:
|
725 |
+
To create a TranscriptionClient and start transcription on microphone audio:
|
726 |
+
```python
|
727 |
+
transcription_client = TranscriptionClient(host="localhost", port=9090)
|
728 |
+
transcription_client()
|
729 |
+
```
|
730 |
+
"""
|
731 |
+
def __init__(
|
732 |
+
self,
|
733 |
+
host,
|
734 |
+
port,
|
735 |
+
lang=None,
|
736 |
+
translate=False,
|
737 |
+
model="small",
|
738 |
+
use_vad=True,
|
739 |
+
use_wss=False,
|
740 |
+
save_output_recording=False,
|
741 |
+
output_recording_filename="./output_recording.wav",
|
742 |
+
output_transcription_path="./output.srt",
|
743 |
+
log_transcription=True,
|
744 |
+
max_clients=4,
|
745 |
+
max_connection_time=600,
|
746 |
+
mute_audio_playback=False,
|
747 |
+
send_last_n_segments=10,
|
748 |
+
no_speech_thresh=0.45,
|
749 |
+
clip_audio=False,
|
750 |
+
same_output_threshold=10,
|
751 |
+
transcription_callback=None,
|
752 |
+
):
|
753 |
+
self.client = Client(
|
754 |
+
host,
|
755 |
+
port,
|
756 |
+
lang,
|
757 |
+
translate,
|
758 |
+
model,
|
759 |
+
srt_file_path=output_transcription_path,
|
760 |
+
use_vad=use_vad,
|
761 |
+
use_wss=use_wss,
|
762 |
+
log_transcription=log_transcription,
|
763 |
+
max_clients=max_clients,
|
764 |
+
max_connection_time=max_connection_time,
|
765 |
+
send_last_n_segments=send_last_n_segments,
|
766 |
+
no_speech_thresh=no_speech_thresh,
|
767 |
+
clip_audio=clip_audio,
|
768 |
+
same_output_threshold=same_output_threshold,
|
769 |
+
transcription_callback=transcription_callback,
|
770 |
+
)
|
771 |
+
|
772 |
+
if save_output_recording and not output_recording_filename.endswith(".wav"):
|
773 |
+
raise ValueError(f"Please provide a valid `output_recording_filename`: {output_recording_filename}")
|
774 |
+
if not output_transcription_path.endswith(".srt"):
|
775 |
+
raise ValueError(f"Please provide a valid `output_transcription_path`: {output_transcription_path}. The file extension should be `.srt`.")
|
776 |
+
TranscriptionTeeClient.__init__(
|
777 |
+
self,
|
778 |
+
[self.client],
|
779 |
+
save_output_recording=save_output_recording,
|
780 |
+
output_recording_filename=output_recording_filename,
|
781 |
+
mute_audio_playback=mute_audio_playback
|
782 |
+
)
|
whisper_live/server.py
ADDED
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import threading
|
4 |
+
import json
|
5 |
+
import functools
|
6 |
+
import logging
|
7 |
+
from enum import Enum
|
8 |
+
from typing import List, Optional
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
from websockets.sync.server import serve
|
12 |
+
from websockets.exceptions import ConnectionClosed
|
13 |
+
from whisper_live.vad import VoiceActivityDetector
|
14 |
+
from whisper_live.backend.base import ServeClientBase
|
15 |
+
|
16 |
+
logging.basicConfig(level=logging.INFO)
|
17 |
+
|
18 |
+
|
19 |
+
class ClientManager:
|
20 |
+
def __init__(self, max_clients=4, max_connection_time=600):
|
21 |
+
"""
|
22 |
+
Initializes the ClientManager with specified limits on client connections and connection durations.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
max_clients (int, optional): The maximum number of simultaneous client connections allowed. Defaults to 4.
|
26 |
+
max_connection_time (int, optional): The maximum duration (in seconds) a client can stay connected. Defaults
|
27 |
+
to 600 seconds (10 minutes).
|
28 |
+
"""
|
29 |
+
self.clients = {}
|
30 |
+
self.start_times = {}
|
31 |
+
self.max_clients = max_clients
|
32 |
+
self.max_connection_time = max_connection_time
|
33 |
+
|
34 |
+
def add_client(self, websocket, client):
|
35 |
+
"""
|
36 |
+
Adds a client and their connection start time to the tracking dictionaries.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
websocket: The websocket associated with the client to add.
|
40 |
+
client: The client object to be added and tracked.
|
41 |
+
"""
|
42 |
+
self.clients[websocket] = client
|
43 |
+
self.start_times[websocket] = time.time()
|
44 |
+
|
45 |
+
def get_client(self, websocket):
|
46 |
+
"""
|
47 |
+
Retrieves a client associated with the given websocket.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
websocket: The websocket associated with the client to retrieve.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
The client object if found, False otherwise.
|
54 |
+
"""
|
55 |
+
if websocket in self.clients:
|
56 |
+
return self.clients[websocket]
|
57 |
+
return False
|
58 |
+
|
59 |
+
def remove_client(self, websocket):
|
60 |
+
"""
|
61 |
+
Removes a client and their connection start time from the tracking dictionaries. Performs cleanup on the
|
62 |
+
client if necessary.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
websocket: The websocket associated with the client to be removed.
|
66 |
+
"""
|
67 |
+
client = self.clients.pop(websocket, None)
|
68 |
+
if client:
|
69 |
+
client.cleanup()
|
70 |
+
self.start_times.pop(websocket, None)
|
71 |
+
|
72 |
+
def get_wait_time(self):
|
73 |
+
"""
|
74 |
+
Calculates the estimated wait time for new clients based on the remaining connection times of current clients.
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
The estimated wait time in minutes for new clients to connect. Returns 0 if there are available slots.
|
78 |
+
"""
|
79 |
+
wait_time = None
|
80 |
+
for start_time in self.start_times.values():
|
81 |
+
current_client_time_remaining = self.max_connection_time - (time.time() - start_time)
|
82 |
+
if wait_time is None or current_client_time_remaining < wait_time:
|
83 |
+
wait_time = current_client_time_remaining
|
84 |
+
return wait_time / 60 if wait_time is not None else 0
|
85 |
+
|
86 |
+
def is_server_full(self, websocket, options):
|
87 |
+
"""
|
88 |
+
Checks if the server is at its maximum client capacity and sends a wait message to the client if necessary.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
websocket: The websocket of the client attempting to connect.
|
92 |
+
options: A dictionary of options that may include the client's unique identifier.
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
True if the server is full, False otherwise.
|
96 |
+
"""
|
97 |
+
if len(self.clients) >= self.max_clients:
|
98 |
+
wait_time = self.get_wait_time()
|
99 |
+
response = {"uid": options["uid"], "status": "WAIT", "message": wait_time}
|
100 |
+
websocket.send(json.dumps(response))
|
101 |
+
return True
|
102 |
+
return False
|
103 |
+
|
104 |
+
def is_client_timeout(self, websocket):
|
105 |
+
"""
|
106 |
+
Checks if a client has exceeded the maximum allowed connection time and disconnects them if so, issuing a warning.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
websocket: The websocket associated with the client to check.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
True if the client's connection time has exceeded the maximum limit, False otherwise.
|
113 |
+
"""
|
114 |
+
elapsed_time = time.time() - self.start_times[websocket]
|
115 |
+
if elapsed_time >= self.max_connection_time:
|
116 |
+
self.clients[websocket].disconnect()
|
117 |
+
logging.warning(f"Client with uid '{self.clients[websocket].client_uid}' disconnected due to overtime.")
|
118 |
+
return True
|
119 |
+
return False
|
120 |
+
|
121 |
+
|
122 |
+
class BackendType(Enum):
|
123 |
+
FASTER_WHISPER = "faster_whisper"
|
124 |
+
TENSORRT = "tensorrt"
|
125 |
+
OPENVINO = "openvino"
|
126 |
+
|
127 |
+
@staticmethod
|
128 |
+
def valid_types() -> List[str]:
|
129 |
+
return [backend_type.value for backend_type in BackendType]
|
130 |
+
|
131 |
+
@staticmethod
|
132 |
+
def is_valid(backend: str) -> bool:
|
133 |
+
return backend in BackendType.valid_types()
|
134 |
+
|
135 |
+
def is_faster_whisper(self) -> bool:
|
136 |
+
return self == BackendType.FASTER_WHISPER
|
137 |
+
|
138 |
+
def is_tensorrt(self) -> bool:
|
139 |
+
return self == BackendType.TENSORRT
|
140 |
+
|
141 |
+
def is_openvino(self) -> bool:
|
142 |
+
return self == BackendType.OPENVINO
|
143 |
+
|
144 |
+
|
145 |
+
class TranscriptionServer:
|
146 |
+
RATE = 16000
|
147 |
+
|
148 |
+
def __init__(self):
|
149 |
+
self.client_manager = None
|
150 |
+
self.no_voice_activity_chunks = 0
|
151 |
+
self.use_vad = True
|
152 |
+
self.single_model = False
|
153 |
+
|
154 |
+
def initialize_client(
|
155 |
+
self, websocket, options, faster_whisper_custom_model_path,
|
156 |
+
whisper_tensorrt_path, trt_multilingual, trt_py_session=False,
|
157 |
+
):
|
158 |
+
client: Optional[ServeClientBase] = None
|
159 |
+
|
160 |
+
if self.backend.is_tensorrt():
|
161 |
+
try:
|
162 |
+
from whisper_live.backend.trt_backend import ServeClientTensorRT
|
163 |
+
client = ServeClientTensorRT(
|
164 |
+
websocket,
|
165 |
+
multilingual=trt_multilingual,
|
166 |
+
language=options["language"],
|
167 |
+
task=options["task"],
|
168 |
+
client_uid=options["uid"],
|
169 |
+
model=whisper_tensorrt_path,
|
170 |
+
single_model=self.single_model,
|
171 |
+
use_py_session=trt_py_session,
|
172 |
+
send_last_n_segments=options.get("send_last_n_segments", 10),
|
173 |
+
no_speech_thresh=options.get("no_speech_thresh", 0.45),
|
174 |
+
clip_audio=options.get("clip_audio", False),
|
175 |
+
same_output_threshold=options.get("same_output_threshold", 10),
|
176 |
+
)
|
177 |
+
logging.info("Running TensorRT backend.")
|
178 |
+
except Exception as e:
|
179 |
+
logging.error(f"TensorRT-LLM not supported: {e}")
|
180 |
+
self.client_uid = options["uid"]
|
181 |
+
websocket.send(json.dumps({
|
182 |
+
"uid": self.client_uid,
|
183 |
+
"status": "WARNING",
|
184 |
+
"message": "TensorRT-LLM not supported on Server yet. "
|
185 |
+
"Reverting to available backend: 'faster_whisper'"
|
186 |
+
}))
|
187 |
+
self.backend = BackendType.FASTER_WHISPER
|
188 |
+
|
189 |
+
if self.backend.is_openvino():
|
190 |
+
try:
|
191 |
+
from whisper_live.backend.openvino_backend import ServeClientOpenVINO
|
192 |
+
client = ServeClientOpenVINO(
|
193 |
+
websocket,
|
194 |
+
language=options["language"],
|
195 |
+
task=options["task"],
|
196 |
+
client_uid=options["uid"],
|
197 |
+
model=options["model"],
|
198 |
+
single_model=self.single_model,
|
199 |
+
send_last_n_segments=options.get("send_last_n_segments", 10),
|
200 |
+
no_speech_thresh=options.get("no_speech_thresh", 0.45),
|
201 |
+
clip_audio=options.get("clip_audio", False),
|
202 |
+
same_output_threshold=options.get("same_output_threshold", 10),
|
203 |
+
)
|
204 |
+
logging.info("Running OpenVINO backend.")
|
205 |
+
except Exception as e:
|
206 |
+
logging.error(f"OpenVINO not supported: {e}")
|
207 |
+
self.backend = BackendType.FASTER_WHISPER
|
208 |
+
self.client_uid = options["uid"]
|
209 |
+
websocket.send(json.dumps({
|
210 |
+
"uid": self.client_uid,
|
211 |
+
"status": "WARNING",
|
212 |
+
"message": "OpenVINO not supported on Server yet. "
|
213 |
+
"Reverting to available backend: 'faster_whisper'"
|
214 |
+
}))
|
215 |
+
|
216 |
+
try:
|
217 |
+
if self.backend.is_faster_whisper():
|
218 |
+
from whisper_live.backend.faster_whisper_backend import ServeClientFasterWhisper
|
219 |
+
if faster_whisper_custom_model_path is not None and os.path.exists(faster_whisper_custom_model_path):
|
220 |
+
logging.info(f"Using custom model {faster_whisper_custom_model_path}")
|
221 |
+
options["model"] = faster_whisper_custom_model_path
|
222 |
+
client = ServeClientFasterWhisper(
|
223 |
+
websocket,
|
224 |
+
language=options["language"],
|
225 |
+
task=options["task"],
|
226 |
+
client_uid=options["uid"],
|
227 |
+
model=options["model"],
|
228 |
+
initial_prompt=options.get("initial_prompt"),
|
229 |
+
vad_parameters=options.get("vad_parameters"),
|
230 |
+
use_vad=self.use_vad,
|
231 |
+
single_model=self.single_model,
|
232 |
+
send_last_n_segments=options.get("send_last_n_segments", 10),
|
233 |
+
no_speech_thresh=options.get("no_speech_thresh", 0.45),
|
234 |
+
clip_audio=options.get("clip_audio", False),
|
235 |
+
same_output_threshold=options.get("same_output_threshold", 10),
|
236 |
+
)
|
237 |
+
|
238 |
+
logging.info("Running faster_whisper backend.")
|
239 |
+
except Exception as e:
|
240 |
+
logging.error(e)
|
241 |
+
return
|
242 |
+
|
243 |
+
if client is None:
|
244 |
+
raise ValueError(f"Backend type {self.backend.value} not recognised or not handled.")
|
245 |
+
|
246 |
+
self.client_manager.add_client(websocket, client)
|
247 |
+
|
248 |
+
def get_audio_from_websocket(self, websocket):
|
249 |
+
"""
|
250 |
+
Receives audio buffer from websocket and creates a numpy array out of it.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
websocket: The websocket to receive audio from.
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
A numpy array containing the audio.
|
257 |
+
"""
|
258 |
+
frame_data = websocket.recv()
|
259 |
+
if frame_data == b"END_OF_AUDIO":
|
260 |
+
return False
|
261 |
+
return np.frombuffer(frame_data, dtype=np.float32)
|
262 |
+
|
263 |
+
def handle_new_connection(self, websocket, faster_whisper_custom_model_path,
|
264 |
+
whisper_tensorrt_path, trt_multilingual, trt_py_session=False):
|
265 |
+
try:
|
266 |
+
logging.info("New client connected")
|
267 |
+
options = websocket.recv()
|
268 |
+
options = json.loads(options)
|
269 |
+
|
270 |
+
if self.client_manager is None:
|
271 |
+
max_clients = options.get('max_clients', 4)
|
272 |
+
max_connection_time = options.get('max_connection_time', 600)
|
273 |
+
self.client_manager = ClientManager(max_clients, max_connection_time)
|
274 |
+
|
275 |
+
self.use_vad = options.get('use_vad')
|
276 |
+
if self.client_manager.is_server_full(websocket, options):
|
277 |
+
websocket.close()
|
278 |
+
return False # Indicates that the connection should not continue
|
279 |
+
|
280 |
+
if self.backend.is_tensorrt():
|
281 |
+
self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)
|
282 |
+
self.initialize_client(websocket, options, faster_whisper_custom_model_path,
|
283 |
+
whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session)
|
284 |
+
return True
|
285 |
+
except json.JSONDecodeError:
|
286 |
+
logging.error("Failed to decode JSON from client")
|
287 |
+
return False
|
288 |
+
except ConnectionClosed:
|
289 |
+
logging.info("Connection closed by client")
|
290 |
+
return False
|
291 |
+
except Exception as e:
|
292 |
+
logging.error(f"Error during new connection initialization: {str(e)}")
|
293 |
+
return False
|
294 |
+
|
295 |
+
def process_audio_frames(self, websocket):
|
296 |
+
frame_np = self.get_audio_from_websocket(websocket)
|
297 |
+
client = self.client_manager.get_client(websocket)
|
298 |
+
if frame_np is False:
|
299 |
+
if self.backend.is_tensorrt():
|
300 |
+
client.set_eos(True)
|
301 |
+
return False
|
302 |
+
|
303 |
+
if self.backend.is_tensorrt():
|
304 |
+
voice_active = self.voice_activity(websocket, frame_np)
|
305 |
+
if voice_active:
|
306 |
+
self.no_voice_activity_chunks = 0
|
307 |
+
client.set_eos(False)
|
308 |
+
if self.use_vad and not voice_active:
|
309 |
+
return True
|
310 |
+
|
311 |
+
client.add_frames(frame_np)
|
312 |
+
return True
|
313 |
+
|
314 |
+
def recv_audio(self,
|
315 |
+
websocket,
|
316 |
+
backend: BackendType = BackendType.FASTER_WHISPER,
|
317 |
+
faster_whisper_custom_model_path=None,
|
318 |
+
whisper_tensorrt_path=None,
|
319 |
+
trt_multilingual=False,
|
320 |
+
trt_py_session=False):
|
321 |
+
"""
|
322 |
+
Receive audio chunks from a client in an infinite loop.
|
323 |
+
|
324 |
+
Continuously receives audio frames from a connected client
|
325 |
+
over a WebSocket connection. It processes the audio frames using a
|
326 |
+
voice activity detection (VAD) model to determine if they contain speech
|
327 |
+
or not. If the audio frame contains speech, it is added to the client's
|
328 |
+
audio data for ASR.
|
329 |
+
If the maximum number of clients is reached, the method sends a
|
330 |
+
"WAIT" status to the client, indicating that they should wait
|
331 |
+
until a slot is available.
|
332 |
+
If a client's connection exceeds the maximum allowed time, it will
|
333 |
+
be disconnected, and the client's resources will be cleaned up.
|
334 |
+
|
335 |
+
Args:
|
336 |
+
websocket (WebSocket): The WebSocket connection for the client.
|
337 |
+
backend (str): The backend to run the server with.
|
338 |
+
faster_whisper_custom_model_path (str): path to custom faster whisper model.
|
339 |
+
whisper_tensorrt_path (str): Required for tensorrt backend.
|
340 |
+
trt_multilingual(bool): Only used for tensorrt, True if multilingual model.
|
341 |
+
|
342 |
+
Raises:
|
343 |
+
Exception: If there is an error during the audio frame processing.
|
344 |
+
"""
|
345 |
+
self.backend = backend
|
346 |
+
if not self.handle_new_connection(websocket, faster_whisper_custom_model_path,
|
347 |
+
whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session):
|
348 |
+
return
|
349 |
+
|
350 |
+
try:
|
351 |
+
while not self.client_manager.is_client_timeout(websocket):
|
352 |
+
if not self.process_audio_frames(websocket):
|
353 |
+
break
|
354 |
+
except ConnectionClosed:
|
355 |
+
logging.info("Connection closed by client")
|
356 |
+
except Exception as e:
|
357 |
+
logging.error(f"Unexpected error: {str(e)}")
|
358 |
+
finally:
|
359 |
+
if self.client_manager.get_client(websocket):
|
360 |
+
self.cleanup(websocket)
|
361 |
+
websocket.close()
|
362 |
+
del websocket
|
363 |
+
|
364 |
+
def run(self,
|
365 |
+
host,
|
366 |
+
port=9090,
|
367 |
+
backend="tensorrt",
|
368 |
+
faster_whisper_custom_model_path=None,
|
369 |
+
whisper_tensorrt_path=None,
|
370 |
+
trt_multilingual=False,
|
371 |
+
trt_py_session=False,
|
372 |
+
single_model=False):
|
373 |
+
"""
|
374 |
+
Run the transcription server.
|
375 |
+
|
376 |
+
Args:
|
377 |
+
host (str): The host address to bind the server.
|
378 |
+
port (int): The port number to bind the server.
|
379 |
+
"""
|
380 |
+
if faster_whisper_custom_model_path is not None and not os.path.exists(faster_whisper_custom_model_path):
|
381 |
+
raise ValueError(f"Custom faster_whisper model '{faster_whisper_custom_model_path}' is not a valid path.")
|
382 |
+
if whisper_tensorrt_path is not None and not os.path.exists(whisper_tensorrt_path):
|
383 |
+
raise ValueError(f"TensorRT model '{whisper_tensorrt_path}' is not a valid path.")
|
384 |
+
if single_model:
|
385 |
+
if faster_whisper_custom_model_path or whisper_tensorrt_path:
|
386 |
+
logging.info("Custom model option was provided. Switching to single model mode.")
|
387 |
+
self.single_model = True
|
388 |
+
# TODO: load model initially
|
389 |
+
else:
|
390 |
+
logging.info("Single model mode currently only works with custom models.")
|
391 |
+
if not BackendType.is_valid(backend):
|
392 |
+
raise ValueError(f"{backend} is not a valid backend type. Choose backend from {BackendType.valid_types()}")
|
393 |
+
with serve(
|
394 |
+
functools.partial(
|
395 |
+
self.recv_audio,
|
396 |
+
backend=BackendType(backend),
|
397 |
+
faster_whisper_custom_model_path=faster_whisper_custom_model_path,
|
398 |
+
whisper_tensorrt_path=whisper_tensorrt_path,
|
399 |
+
trt_multilingual=trt_multilingual,
|
400 |
+
trt_py_session=trt_py_session,
|
401 |
+
),
|
402 |
+
host,
|
403 |
+
port
|
404 |
+
) as server:
|
405 |
+
server.serve_forever()
|
406 |
+
|
407 |
+
def voice_activity(self, websocket, frame_np):
|
408 |
+
"""
|
409 |
+
Evaluates the voice activity in a given audio frame and manages the state of voice activity detection.
|
410 |
+
|
411 |
+
This method uses the configured voice activity detection (VAD) model to assess whether the given audio frame
|
412 |
+
contains speech. If the VAD model detects no voice activity for more than three consecutive frames,
|
413 |
+
it sets an end-of-speech (EOS) flag for the associated client. This method aims to efficiently manage
|
414 |
+
speech detection to improve subsequent processing steps.
|
415 |
+
|
416 |
+
Args:
|
417 |
+
websocket: The websocket associated with the current client. Used to retrieve the client object
|
418 |
+
from the client manager for state management.
|
419 |
+
frame_np (numpy.ndarray): The audio frame to be analyzed. This should be a NumPy array containing
|
420 |
+
the audio data for the current frame.
|
421 |
+
|
422 |
+
Returns:
|
423 |
+
bool: True if voice activity is detected in the current frame, False otherwise. When returning False
|
424 |
+
after detecting no voice activity for more than three consecutive frames, it also triggers the
|
425 |
+
end-of-speech (EOS) flag for the client.
|
426 |
+
"""
|
427 |
+
if not self.vad_detector(frame_np):
|
428 |
+
self.no_voice_activity_chunks += 1
|
429 |
+
if self.no_voice_activity_chunks > 3:
|
430 |
+
client = self.client_manager.get_client(websocket)
|
431 |
+
if not client.eos:
|
432 |
+
client.set_eos(True)
|
433 |
+
time.sleep(0.1) # Sleep 100m; wait some voice activity.
|
434 |
+
return False
|
435 |
+
return True
|
436 |
+
|
437 |
+
def cleanup(self, websocket):
|
438 |
+
"""
|
439 |
+
Cleans up resources associated with a given client's websocket.
|
440 |
+
|
441 |
+
Args:
|
442 |
+
websocket: The websocket associated with the client to be cleaned up.
|
443 |
+
"""
|
444 |
+
if self.client_manager.get_client(websocket):
|
445 |
+
self.client_manager.remove_client(websocket)
|
446 |
+
|
whisper_live/transcriber/__init__.py
ADDED
File without changes
|
whisper_live/transcriber/tensorrt_utils.py
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import logging
|
16 |
+
import os
|
17 |
+
from collections import defaultdict
|
18 |
+
from functools import lru_cache
|
19 |
+
from pathlib import Path
|
20 |
+
from subprocess import CalledProcessError, run
|
21 |
+
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
|
22 |
+
|
23 |
+
import kaldialign
|
24 |
+
import numpy as np
|
25 |
+
import soundfile
|
26 |
+
import av
|
27 |
+
import wave
|
28 |
+
import torch
|
29 |
+
import torch.nn.functional as F
|
30 |
+
from whisper_live.utils import resample
|
31 |
+
|
32 |
+
|
33 |
+
Pathlike = Union[str, Path]
|
34 |
+
|
35 |
+
SAMPLE_RATE = 16000
|
36 |
+
N_FFT = 400
|
37 |
+
HOP_LENGTH = 160
|
38 |
+
CHUNK_LENGTH = 30
|
39 |
+
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
40 |
+
|
41 |
+
|
42 |
+
def load_audio(file: str, sr: int = 16000):
|
43 |
+
"""
|
44 |
+
Open an audio file, resample it, and read as a mono waveform.
|
45 |
+
|
46 |
+
Parameters
|
47 |
+
----------
|
48 |
+
file: str
|
49 |
+
The audio file to open.
|
50 |
+
|
51 |
+
sr: int
|
52 |
+
The sample rate to resample the audio if necessary.
|
53 |
+
|
54 |
+
Returns
|
55 |
+
-------
|
56 |
+
A NumPy array containing the audio waveform, in float32 dtype.
|
57 |
+
"""
|
58 |
+
resampled_file = resample(file, sr)
|
59 |
+
|
60 |
+
with wave.open(resampled_file, "rb") as wav_file:
|
61 |
+
num_frames = wav_file.getnframes()
|
62 |
+
raw_data = wav_file.readframes(num_frames)
|
63 |
+
|
64 |
+
audio_data = np.frombuffer(raw_data, dtype=np.int16)
|
65 |
+
|
66 |
+
audio_data = audio_data.astype(np.float32) / 32768.0
|
67 |
+
|
68 |
+
return audio_data
|
69 |
+
|
70 |
+
|
71 |
+
def load_audio_wav_format(wav_path):
|
72 |
+
# make sure audio in .wav format
|
73 |
+
assert wav_path.endswith(
|
74 |
+
'.wav'), f"Only support .wav format, but got {wav_path}"
|
75 |
+
waveform, sample_rate = soundfile.read(wav_path)
|
76 |
+
assert sample_rate == 16000, f"Only support 16k sample rate, but got {sample_rate}"
|
77 |
+
return waveform, sample_rate
|
78 |
+
|
79 |
+
|
80 |
+
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
81 |
+
"""
|
82 |
+
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
|
83 |
+
"""
|
84 |
+
if torch.is_tensor(array):
|
85 |
+
if array.shape[axis] > length:
|
86 |
+
array = array.index_select(dim=axis,
|
87 |
+
index=torch.arange(length,
|
88 |
+
device=array.device))
|
89 |
+
|
90 |
+
if array.shape[axis] < length:
|
91 |
+
pad_widths = [(0, 0)] * array.ndim
|
92 |
+
pad_widths[axis] = (0, length - array.shape[axis])
|
93 |
+
array = F.pad(array,
|
94 |
+
[pad for sizes in pad_widths[::-1] for pad in sizes])
|
95 |
+
else:
|
96 |
+
if array.shape[axis] > length:
|
97 |
+
array = array.take(indices=range(length), axis=axis)
|
98 |
+
|
99 |
+
if array.shape[axis] < length:
|
100 |
+
pad_widths = [(0, 0)] * array.ndim
|
101 |
+
pad_widths[axis] = (0, length - array.shape[axis])
|
102 |
+
array = np.pad(array, pad_widths)
|
103 |
+
|
104 |
+
return array
|
105 |
+
|
106 |
+
|
107 |
+
@lru_cache(maxsize=None)
|
108 |
+
def mel_filters(device,
|
109 |
+
n_mels: int,
|
110 |
+
mel_filters_dir: str = None) -> torch.Tensor:
|
111 |
+
"""
|
112 |
+
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
113 |
+
Allows decoupling librosa dependency; saved using:
|
114 |
+
|
115 |
+
np.savez_compressed(
|
116 |
+
"mel_filters.npz",
|
117 |
+
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
118 |
+
)
|
119 |
+
"""
|
120 |
+
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
|
121 |
+
if mel_filters_dir is None:
|
122 |
+
mel_filters_path = os.path.join(os.path.dirname(__file__), "assets",
|
123 |
+
"mel_filters.npz")
|
124 |
+
else:
|
125 |
+
mel_filters_path = os.path.join(mel_filters_dir, "mel_filters.npz")
|
126 |
+
with np.load(mel_filters_path) as f:
|
127 |
+
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
128 |
+
|
129 |
+
|
130 |
+
def log_mel_spectrogram(
|
131 |
+
audio: Union[str, np.ndarray, torch.Tensor],
|
132 |
+
n_mels: int,
|
133 |
+
padding: int = 0,
|
134 |
+
device: Optional[Union[str, torch.device]] = None,
|
135 |
+
return_duration: bool = False,
|
136 |
+
mel_filters_dir: str = None,
|
137 |
+
):
|
138 |
+
"""
|
139 |
+
Compute the log-Mel spectrogram of
|
140 |
+
|
141 |
+
Parameters
|
142 |
+
----------
|
143 |
+
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
|
144 |
+
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
145 |
+
|
146 |
+
n_mels: int
|
147 |
+
The number of Mel-frequency filters, only 80 and 128 are supported
|
148 |
+
|
149 |
+
padding: int
|
150 |
+
Number of zero samples to pad to the right
|
151 |
+
|
152 |
+
device: Optional[Union[str, torch.device]]
|
153 |
+
If given, the audio tensor is moved to this device before STFT
|
154 |
+
|
155 |
+
Returns
|
156 |
+
-------
|
157 |
+
torch.Tensor, shape = (80 or 128, n_frames)
|
158 |
+
A Tensor that contains the Mel spectrogram
|
159 |
+
"""
|
160 |
+
if not torch.is_tensor(audio):
|
161 |
+
if isinstance(audio, str):
|
162 |
+
if audio.endswith('.wav'):
|
163 |
+
audio, _ = load_audio_wav_format(audio)
|
164 |
+
else:
|
165 |
+
audio = load_audio(audio)
|
166 |
+
assert isinstance(audio,
|
167 |
+
np.ndarray), f"Unsupported audio type: {type(audio)}"
|
168 |
+
duration = audio.shape[-1] / SAMPLE_RATE
|
169 |
+
audio = pad_or_trim(audio, N_SAMPLES)
|
170 |
+
audio = audio.astype(np.float32)
|
171 |
+
audio = torch.from_numpy(audio)
|
172 |
+
|
173 |
+
if device is not None:
|
174 |
+
audio = audio.to(device)
|
175 |
+
if padding > 0:
|
176 |
+
audio = F.pad(audio, (0, padding))
|
177 |
+
window = torch.hann_window(N_FFT).to(audio.device)
|
178 |
+
stft = torch.stft(audio,
|
179 |
+
N_FFT,
|
180 |
+
HOP_LENGTH,
|
181 |
+
window=window,
|
182 |
+
return_complex=True)
|
183 |
+
magnitudes = stft[..., :-1].abs()**2
|
184 |
+
|
185 |
+
filters = mel_filters(audio.device, n_mels, mel_filters_dir)
|
186 |
+
mel_spec = filters @ magnitudes
|
187 |
+
|
188 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
189 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
190 |
+
log_spec = (log_spec + 4.0) / 4.0
|
191 |
+
if return_duration:
|
192 |
+
return log_spec, duration
|
193 |
+
else:
|
194 |
+
return log_spec
|
195 |
+
|
196 |
+
|
197 |
+
def store_transcripts(filename: Pathlike, texts: Iterable[Tuple[str, str,
|
198 |
+
str]]) -> None:
|
199 |
+
"""Save predicted results and reference transcripts to a file.
|
200 |
+
https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
|
201 |
+
Args:
|
202 |
+
filename:
|
203 |
+
File to save the results to.
|
204 |
+
texts:
|
205 |
+
An iterable of tuples. The first element is the cur_id, the second is
|
206 |
+
the reference transcript and the third element is the predicted result.
|
207 |
+
Returns:
|
208 |
+
Return None.
|
209 |
+
"""
|
210 |
+
with open(filename, "w") as f:
|
211 |
+
for cut_id, ref, hyp in texts:
|
212 |
+
print(f"{cut_id}:\tref={ref}", file=f)
|
213 |
+
print(f"{cut_id}:\thyp={hyp}", file=f)
|
214 |
+
|
215 |
+
|
216 |
+
def write_error_stats( # noqa: C901
|
217 |
+
f: TextIO,
|
218 |
+
test_set_name: str,
|
219 |
+
results: List[Tuple[str, str]],
|
220 |
+
enable_log: bool = True,
|
221 |
+
) -> float:
|
222 |
+
"""Write statistics based on predicted results and reference transcripts.
|
223 |
+
https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
|
224 |
+
It will write the following to the given file:
|
225 |
+
|
226 |
+
- WER
|
227 |
+
- number of insertions, deletions, substitutions, corrects and total
|
228 |
+
reference words. For example::
|
229 |
+
|
230 |
+
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
|
231 |
+
reference words (2337 correct)
|
232 |
+
|
233 |
+
- The difference between the reference transcript and predicted result.
|
234 |
+
An instance is given below::
|
235 |
+
|
236 |
+
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
|
237 |
+
|
238 |
+
The above example shows that the reference word is `EDISON`,
|
239 |
+
but it is predicted to `ADDISON` (a substitution error).
|
240 |
+
|
241 |
+
Another example is::
|
242 |
+
|
243 |
+
FOR THE FIRST DAY (SIR->*) I THINK
|
244 |
+
|
245 |
+
The reference word `SIR` is missing in the predicted
|
246 |
+
results (a deletion error).
|
247 |
+
results:
|
248 |
+
An iterable of tuples. The first element is the cur_id, the second is
|
249 |
+
the reference transcript and the third element is the predicted result.
|
250 |
+
enable_log:
|
251 |
+
If True, also print detailed WER to the console.
|
252 |
+
Otherwise, it is written only to the given file.
|
253 |
+
Returns:
|
254 |
+
Return None.
|
255 |
+
"""
|
256 |
+
subs: Dict[Tuple[str, str], int] = defaultdict(int)
|
257 |
+
ins: Dict[str, int] = defaultdict(int)
|
258 |
+
dels: Dict[str, int] = defaultdict(int)
|
259 |
+
|
260 |
+
# `words` stores counts per word, as follows:
|
261 |
+
# corr, ref_sub, hyp_sub, ins, dels
|
262 |
+
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
263 |
+
num_corr = 0
|
264 |
+
ERR = "*"
|
265 |
+
for cut_id, ref, hyp in results:
|
266 |
+
ali = kaldialign.align(ref, hyp, ERR)
|
267 |
+
for ref_word, hyp_word in ali:
|
268 |
+
if ref_word == ERR:
|
269 |
+
ins[hyp_word] += 1
|
270 |
+
words[hyp_word][3] += 1
|
271 |
+
elif hyp_word == ERR:
|
272 |
+
dels[ref_word] += 1
|
273 |
+
words[ref_word][4] += 1
|
274 |
+
elif hyp_word != ref_word:
|
275 |
+
subs[(ref_word, hyp_word)] += 1
|
276 |
+
words[ref_word][1] += 1
|
277 |
+
words[hyp_word][2] += 1
|
278 |
+
else:
|
279 |
+
words[ref_word][0] += 1
|
280 |
+
num_corr += 1
|
281 |
+
ref_len = sum([len(r) for _, r, _ in results])
|
282 |
+
sub_errs = sum(subs.values())
|
283 |
+
ins_errs = sum(ins.values())
|
284 |
+
del_errs = sum(dels.values())
|
285 |
+
tot_errs = sub_errs + ins_errs + del_errs
|
286 |
+
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
|
287 |
+
|
288 |
+
if enable_log:
|
289 |
+
logging.info(f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
|
290 |
+
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
|
291 |
+
f"{del_errs} del, {sub_errs} sub ]")
|
292 |
+
|
293 |
+
print(f"%WER = {tot_err_rate}", file=f)
|
294 |
+
print(
|
295 |
+
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
|
296 |
+
f"{sub_errs} substitutions, over {ref_len} reference "
|
297 |
+
f"words ({num_corr} correct)",
|
298 |
+
file=f,
|
299 |
+
)
|
300 |
+
print(
|
301 |
+
"Search below for sections starting with PER-UTT DETAILS:, "
|
302 |
+
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
|
303 |
+
file=f,
|
304 |
+
)
|
305 |
+
|
306 |
+
print("", file=f)
|
307 |
+
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
|
308 |
+
for cut_id, ref, hyp in results:
|
309 |
+
ali = kaldialign.align(ref, hyp, ERR)
|
310 |
+
combine_successive_errors = True
|
311 |
+
if combine_successive_errors:
|
312 |
+
ali = [[[x], [y]] for x, y in ali]
|
313 |
+
for i in range(len(ali) - 1):
|
314 |
+
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
|
315 |
+
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
|
316 |
+
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
|
317 |
+
ali[i] = [[], []]
|
318 |
+
ali = [[
|
319 |
+
list(filter(lambda a: a != ERR, x)),
|
320 |
+
list(filter(lambda a: a != ERR, y)),
|
321 |
+
] for x, y in ali]
|
322 |
+
ali = list(filter(lambda x: x != [[], []], ali))
|
323 |
+
ali = [[
|
324 |
+
ERR if x == [] else " ".join(x),
|
325 |
+
ERR if y == [] else " ".join(y),
|
326 |
+
] for x, y in ali]
|
327 |
+
|
328 |
+
print(
|
329 |
+
f"{cut_id}:\t" + " ".join((ref_word if ref_word == hyp_word else
|
330 |
+
f"({ref_word}->{hyp_word})"
|
331 |
+
for ref_word, hyp_word in ali)),
|
332 |
+
file=f,
|
333 |
+
)
|
334 |
+
|
335 |
+
print("", file=f)
|
336 |
+
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
337 |
+
|
338 |
+
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()],
|
339 |
+
reverse=True):
|
340 |
+
print(f"{count} {ref} -> {hyp}", file=f)
|
341 |
+
|
342 |
+
print("", file=f)
|
343 |
+
print("DELETIONS: count ref", file=f)
|
344 |
+
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
|
345 |
+
print(f"{count} {ref}", file=f)
|
346 |
+
|
347 |
+
print("", file=f)
|
348 |
+
print("INSERTIONS: count hyp", file=f)
|
349 |
+
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
|
350 |
+
print(f"{count} {hyp}", file=f)
|
351 |
+
|
352 |
+
print("", file=f)
|
353 |
+
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp",
|
354 |
+
file=f)
|
355 |
+
for _, word, counts in sorted([(sum(v[1:]), k, v)
|
356 |
+
for k, v in words.items()],
|
357 |
+
reverse=True):
|
358 |
+
(corr, ref_sub, hyp_sub, ins, dels) = counts
|
359 |
+
tot_errs = ref_sub + hyp_sub + ins + dels
|
360 |
+
ref_count = corr + ref_sub + dels
|
361 |
+
hyp_count = corr + hyp_sub + ins
|
362 |
+
|
363 |
+
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
364 |
+
return float(tot_err_rate)
|
whisper_live/transcriber/transcriber_faster_whisper.py
ADDED
@@ -0,0 +1,1889 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# original https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py
|
2 |
+
|
3 |
+
import itertools
|
4 |
+
import json
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import zlib
|
8 |
+
|
9 |
+
from dataclasses import asdict, dataclass
|
10 |
+
from inspect import signature
|
11 |
+
from math import ceil
|
12 |
+
from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
|
13 |
+
from warnings import warn
|
14 |
+
|
15 |
+
import ctranslate2
|
16 |
+
import numpy as np
|
17 |
+
import tokenizers
|
18 |
+
|
19 |
+
from tqdm import tqdm
|
20 |
+
|
21 |
+
from faster_whisper.audio import decode_audio, pad_or_trim
|
22 |
+
from faster_whisper.feature_extractor import FeatureExtractor
|
23 |
+
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
|
24 |
+
from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
|
25 |
+
from faster_whisper.vad import (
|
26 |
+
SpeechTimestampsMap,
|
27 |
+
VadOptions,
|
28 |
+
collect_chunks,
|
29 |
+
get_speech_timestamps,
|
30 |
+
merge_segments,
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class Word:
|
36 |
+
start: float
|
37 |
+
end: float
|
38 |
+
word: str
|
39 |
+
probability: float
|
40 |
+
|
41 |
+
def _asdict(self):
|
42 |
+
warn(
|
43 |
+
"Word._asdict() method is deprecated, use dataclasses.asdict(Word) instead",
|
44 |
+
DeprecationWarning,
|
45 |
+
2,
|
46 |
+
)
|
47 |
+
return asdict(self)
|
48 |
+
|
49 |
+
|
50 |
+
@dataclass
|
51 |
+
class Segment:
|
52 |
+
id: int
|
53 |
+
seek: int
|
54 |
+
start: float
|
55 |
+
end: float
|
56 |
+
text: str
|
57 |
+
tokens: List[int]
|
58 |
+
avg_logprob: float
|
59 |
+
compression_ratio: float
|
60 |
+
no_speech_prob: float
|
61 |
+
words: Optional[List[Word]]
|
62 |
+
temperature: Optional[float]
|
63 |
+
|
64 |
+
def _asdict(self):
|
65 |
+
warn(
|
66 |
+
"Segment._asdict() method is deprecated, use dataclasses.asdict(Segment) instead",
|
67 |
+
DeprecationWarning,
|
68 |
+
2,
|
69 |
+
)
|
70 |
+
return asdict(self)
|
71 |
+
|
72 |
+
|
73 |
+
@dataclass
|
74 |
+
class TranscriptionOptions:
|
75 |
+
beam_size: int
|
76 |
+
best_of: int
|
77 |
+
patience: float
|
78 |
+
length_penalty: float
|
79 |
+
repetition_penalty: float
|
80 |
+
no_repeat_ngram_size: int
|
81 |
+
log_prob_threshold: Optional[float]
|
82 |
+
no_speech_threshold: Optional[float]
|
83 |
+
compression_ratio_threshold: Optional[float]
|
84 |
+
condition_on_previous_text: bool
|
85 |
+
prompt_reset_on_temperature: float
|
86 |
+
temperatures: List[float]
|
87 |
+
initial_prompt: Optional[Union[str, Iterable[int]]]
|
88 |
+
prefix: Optional[str]
|
89 |
+
suppress_blank: bool
|
90 |
+
suppress_tokens: Optional[List[int]]
|
91 |
+
without_timestamps: bool
|
92 |
+
max_initial_timestamp: float
|
93 |
+
word_timestamps: bool
|
94 |
+
prepend_punctuations: str
|
95 |
+
append_punctuations: str
|
96 |
+
multilingual: bool
|
97 |
+
max_new_tokens: Optional[int]
|
98 |
+
clip_timestamps: Union[str, List[float]]
|
99 |
+
hallucination_silence_threshold: Optional[float]
|
100 |
+
hotwords: Optional[str]
|
101 |
+
|
102 |
+
|
103 |
+
@dataclass
|
104 |
+
class TranscriptionInfo:
|
105 |
+
language: str
|
106 |
+
language_probability: float
|
107 |
+
duration: float
|
108 |
+
duration_after_vad: float
|
109 |
+
all_language_probs: Optional[List[Tuple[str, float]]]
|
110 |
+
transcription_options: TranscriptionOptions
|
111 |
+
vad_options: VadOptions
|
112 |
+
|
113 |
+
|
114 |
+
class BatchedInferencePipeline:
|
115 |
+
def __init__(
|
116 |
+
self,
|
117 |
+
model,
|
118 |
+
):
|
119 |
+
self.model: WhisperModel = model
|
120 |
+
self.last_speech_timestamp = 0.0
|
121 |
+
|
122 |
+
def forward(self, features, tokenizer, chunks_metadata, options):
|
123 |
+
encoder_output, outputs = self.generate_segment_batched(
|
124 |
+
features, tokenizer, options
|
125 |
+
)
|
126 |
+
|
127 |
+
segmented_outputs = []
|
128 |
+
segment_sizes = []
|
129 |
+
for chunk_metadata, output in zip(chunks_metadata, outputs):
|
130 |
+
duration = chunk_metadata["end_time"] - chunk_metadata["start_time"]
|
131 |
+
segment_size = int(ceil(duration) * self.model.frames_per_second)
|
132 |
+
segment_sizes.append(segment_size)
|
133 |
+
(
|
134 |
+
subsegments,
|
135 |
+
seek,
|
136 |
+
single_timestamp_ending,
|
137 |
+
) = self.model._split_segments_by_timestamps(
|
138 |
+
tokenizer=tokenizer,
|
139 |
+
tokens=output["tokens"],
|
140 |
+
time_offset=chunk_metadata["start_time"],
|
141 |
+
segment_size=segment_size,
|
142 |
+
segment_duration=duration,
|
143 |
+
seek=0,
|
144 |
+
)
|
145 |
+
segmented_outputs.append(
|
146 |
+
[
|
147 |
+
dict(
|
148 |
+
text=tokenizer.decode(subsegment["tokens"]),
|
149 |
+
avg_logprob=output["avg_logprob"],
|
150 |
+
no_speech_prob=output["no_speech_prob"],
|
151 |
+
tokens=subsegment["tokens"],
|
152 |
+
start=subsegment["start"],
|
153 |
+
end=subsegment["end"],
|
154 |
+
compression_ratio=get_compression_ratio(
|
155 |
+
tokenizer.decode(subsegment["tokens"])
|
156 |
+
),
|
157 |
+
seek=int(
|
158 |
+
chunk_metadata["start_time"] * self.model.frames_per_second
|
159 |
+
),
|
160 |
+
)
|
161 |
+
for subsegment in subsegments
|
162 |
+
]
|
163 |
+
)
|
164 |
+
if options.word_timestamps:
|
165 |
+
self.last_speech_timestamp = self.model.add_word_timestamps(
|
166 |
+
segmented_outputs,
|
167 |
+
tokenizer,
|
168 |
+
encoder_output,
|
169 |
+
segment_sizes,
|
170 |
+
options.prepend_punctuations,
|
171 |
+
options.append_punctuations,
|
172 |
+
self.last_speech_timestamp,
|
173 |
+
)
|
174 |
+
|
175 |
+
return segmented_outputs
|
176 |
+
|
177 |
+
def generate_segment_batched(
|
178 |
+
self,
|
179 |
+
features: np.ndarray,
|
180 |
+
tokenizer: Tokenizer,
|
181 |
+
options: TranscriptionOptions,
|
182 |
+
):
|
183 |
+
batch_size = features.shape[0]
|
184 |
+
|
185 |
+
prompt = self.model.get_prompt(
|
186 |
+
tokenizer,
|
187 |
+
previous_tokens=(
|
188 |
+
tokenizer.encode(options.initial_prompt)
|
189 |
+
if options.initial_prompt is not None
|
190 |
+
else []
|
191 |
+
),
|
192 |
+
without_timestamps=options.without_timestamps,
|
193 |
+
hotwords=options.hotwords,
|
194 |
+
)
|
195 |
+
|
196 |
+
if options.max_new_tokens is not None:
|
197 |
+
max_length = len(prompt) + options.max_new_tokens
|
198 |
+
else:
|
199 |
+
max_length = self.model.max_length
|
200 |
+
|
201 |
+
if max_length > self.model.max_length:
|
202 |
+
raise ValueError(
|
203 |
+
f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
|
204 |
+
f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
|
205 |
+
f"and `max_new_tokens` is: {max_length}. This exceeds the "
|
206 |
+
f"`max_length` of the Whisper model: {self.model.max_length}. "
|
207 |
+
"You should either reduce the length of your prompt, or "
|
208 |
+
"reduce the value of `max_new_tokens`, "
|
209 |
+
f"so that their combined length is less that {self.model.max_length}."
|
210 |
+
)
|
211 |
+
|
212 |
+
encoder_output = self.model.encode(features)
|
213 |
+
prompts = [prompt.copy() for _ in range(batch_size)]
|
214 |
+
|
215 |
+
if options.multilingual:
|
216 |
+
language_tokens = [
|
217 |
+
tokenizer.tokenizer.token_to_id(segment_langs[0][0])
|
218 |
+
for segment_langs in self.model.model.detect_language(encoder_output)
|
219 |
+
]
|
220 |
+
language_token_index = prompt.index(tokenizer.language)
|
221 |
+
|
222 |
+
for i, language_token in enumerate(language_tokens):
|
223 |
+
prompts[i][language_token_index] = language_token
|
224 |
+
|
225 |
+
results = self.model.model.generate(
|
226 |
+
encoder_output,
|
227 |
+
prompts,
|
228 |
+
beam_size=options.beam_size,
|
229 |
+
patience=options.patience,
|
230 |
+
length_penalty=options.length_penalty,
|
231 |
+
max_length=max_length,
|
232 |
+
suppress_blank=options.suppress_blank,
|
233 |
+
suppress_tokens=options.suppress_tokens,
|
234 |
+
return_scores=True,
|
235 |
+
return_no_speech_prob=True,
|
236 |
+
sampling_temperature=options.temperatures[0],
|
237 |
+
repetition_penalty=options.repetition_penalty,
|
238 |
+
no_repeat_ngram_size=options.no_repeat_ngram_size,
|
239 |
+
)
|
240 |
+
|
241 |
+
output = []
|
242 |
+
for result in results:
|
243 |
+
# return scores
|
244 |
+
seq_len = len(result.sequences_ids[0])
|
245 |
+
cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
|
246 |
+
|
247 |
+
output.append(
|
248 |
+
dict(
|
249 |
+
avg_logprob=cum_logprob / (seq_len + 1),
|
250 |
+
no_speech_prob=result.no_speech_prob,
|
251 |
+
tokens=result.sequences_ids[0],
|
252 |
+
)
|
253 |
+
)
|
254 |
+
|
255 |
+
return encoder_output, output
|
256 |
+
|
257 |
+
def transcribe(
|
258 |
+
self,
|
259 |
+
audio: Union[str, BinaryIO, np.ndarray],
|
260 |
+
language: Optional[str] = None,
|
261 |
+
task: str = "transcribe",
|
262 |
+
log_progress: bool = False,
|
263 |
+
beam_size: int = 5,
|
264 |
+
best_of: int = 5,
|
265 |
+
patience: float = 1,
|
266 |
+
length_penalty: float = 1,
|
267 |
+
repetition_penalty: float = 1,
|
268 |
+
no_repeat_ngram_size: int = 0,
|
269 |
+
temperature: Union[float, List[float], Tuple[float, ...]] = [
|
270 |
+
0.0,
|
271 |
+
0.2,
|
272 |
+
0.4,
|
273 |
+
0.6,
|
274 |
+
0.8,
|
275 |
+
1.0,
|
276 |
+
],
|
277 |
+
compression_ratio_threshold: Optional[float] = 2.4,
|
278 |
+
log_prob_threshold: Optional[float] = -1.0,
|
279 |
+
no_speech_threshold: Optional[float] = 0.6,
|
280 |
+
condition_on_previous_text: bool = True,
|
281 |
+
prompt_reset_on_temperature: float = 0.5,
|
282 |
+
initial_prompt: Optional[Union[str, Iterable[int]]] = None,
|
283 |
+
prefix: Optional[str] = None,
|
284 |
+
suppress_blank: bool = True,
|
285 |
+
suppress_tokens: Optional[List[int]] = [-1],
|
286 |
+
without_timestamps: bool = True,
|
287 |
+
max_initial_timestamp: float = 1.0,
|
288 |
+
word_timestamps: bool = False,
|
289 |
+
prepend_punctuations: str = "\"'“¿([{-",
|
290 |
+
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
291 |
+
multilingual: bool = False,
|
292 |
+
vad_filter: bool = True,
|
293 |
+
vad_parameters: Optional[Union[dict, VadOptions]] = None,
|
294 |
+
max_new_tokens: Optional[int] = None,
|
295 |
+
chunk_length: Optional[int] = None,
|
296 |
+
clip_timestamps: Optional[List[dict]] = None,
|
297 |
+
hallucination_silence_threshold: Optional[float] = None,
|
298 |
+
batch_size: int = 8,
|
299 |
+
hotwords: Optional[str] = None,
|
300 |
+
language_detection_threshold: Optional[float] = 0.5,
|
301 |
+
language_detection_segments: int = 1,
|
302 |
+
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
|
303 |
+
"""transcribe audio in chunks in batched fashion and return with language info.
|
304 |
+
|
305 |
+
Arguments:
|
306 |
+
audio: Path to the input file (or a file-like object), or the audio waveform.
|
307 |
+
language: The language spoken in the audio. It should be a language code such
|
308 |
+
as "en" or "fr". If not set, the language will be detected in the first 30 seconds
|
309 |
+
of audio.
|
310 |
+
task: Task to execute (transcribe or translate).
|
311 |
+
log_progress: whether to show progress bar or not.
|
312 |
+
beam_size: Beam size to use for decoding.
|
313 |
+
best_of: Number of candidates when sampling with non-zero temperature.
|
314 |
+
patience: Beam search patience factor.
|
315 |
+
length_penalty: Exponential length penalty constant.
|
316 |
+
repetition_penalty: Penalty applied to the score of previously generated tokens
|
317 |
+
(set > 1 to penalize).
|
318 |
+
no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
|
319 |
+
temperature: Temperature for sampling. If a list or tuple is passed,
|
320 |
+
only the first value is used.
|
321 |
+
initial_prompt: Optional text string or iterable of token ids to provide as a
|
322 |
+
prompt for the each window.
|
323 |
+
suppress_blank: Suppress blank outputs at the beginning of the sampling.
|
324 |
+
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
|
325 |
+
of symbols as defined in `tokenizer.non_speech_tokens()`.
|
326 |
+
without_timestamps: Only sample text tokens.
|
327 |
+
word_timestamps: Extract word-level timestamps using the cross-attention pattern
|
328 |
+
and dynamic time warping, and include the timestamps for each word in each segment.
|
329 |
+
Set as False.
|
330 |
+
prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
|
331 |
+
with the next word
|
332 |
+
append_punctuations: If word_timestamps is True, merge these punctuation symbols
|
333 |
+
with the previous word
|
334 |
+
multilingual: Perform language detection on every segment.
|
335 |
+
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
|
336 |
+
without speech. This step is using the Silero VAD model
|
337 |
+
https://github.com/snakers4/silero-vad.
|
338 |
+
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
|
339 |
+
parameters and default values in the class `VadOptions`).
|
340 |
+
max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
|
341 |
+
the maximum will be set by the default max_length.
|
342 |
+
chunk_length: The length of audio segments. If it is not None, it will overwrite the
|
343 |
+
default chunk_length of the FeatureExtractor.
|
344 |
+
clip_timestamps: Optionally provide list of dictionaries each containing "start" and
|
345 |
+
"end" keys that specify the start and end of the voiced region within
|
346 |
+
`chunk_length` boundary. vad_filter will be ignored if clip_timestamps is used.
|
347 |
+
batch_size: the maximum number of parallel requests to model for decoding.
|
348 |
+
hotwords:
|
349 |
+
Hotwords/hint phrases to the model. Has no effect if prefix is not None.
|
350 |
+
language_detection_threshold: If the maximum probability of the language tokens is
|
351 |
+
higher than this value, the language is detected.
|
352 |
+
language_detection_segments: Number of segments to consider for the language detection.
|
353 |
+
|
354 |
+
Unused Arguments
|
355 |
+
compression_ratio_threshold: If the gzip compression ratio is above this value,
|
356 |
+
treat as failed.
|
357 |
+
log_prob_threshold: If the average log probability over sampled tokens is
|
358 |
+
below this value, treat as failed.
|
359 |
+
no_speech_threshold: If the no_speech probability is higher than this value AND
|
360 |
+
the average log probability over sampled tokens is below `log_prob_threshold`,
|
361 |
+
consider the segment as silent.
|
362 |
+
condition_on_previous_text: If True, the previous output of the model is provided
|
363 |
+
as a prompt for the next window; disabling may make the text inconsistent across
|
364 |
+
windows, but the model becomes less prone to getting stuck in a failure loop,
|
365 |
+
such as repetition looping or timestamps going out of sync. Set as False
|
366 |
+
prompt_reset_on_temperature: Resets prompt if temperature is above this value.
|
367 |
+
Arg has effect only if condition_on_previous_text is True. Set at 0.5
|
368 |
+
prefix: Optional text to provide as a prefix at the beginning of each window.
|
369 |
+
max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
|
370 |
+
hallucination_silence_threshold: Optional[float]
|
371 |
+
When word_timestamps is True, skip silent periods longer than this threshold
|
372 |
+
(in seconds) when a possible hallucination is detected. set as None.
|
373 |
+
Returns:
|
374 |
+
A tuple with:
|
375 |
+
|
376 |
+
- a generator over transcribed segments
|
377 |
+
- an instance of TranscriptionInfo
|
378 |
+
"""
|
379 |
+
|
380 |
+
sampling_rate = self.model.feature_extractor.sampling_rate
|
381 |
+
|
382 |
+
if multilingual and not self.model.model.is_multilingual:
|
383 |
+
self.model.logger.warning(
|
384 |
+
"The current model is English-only but the multilingual parameter is set to"
|
385 |
+
"True; setting to False instead."
|
386 |
+
)
|
387 |
+
multilingual = False
|
388 |
+
|
389 |
+
if not isinstance(audio, np.ndarray):
|
390 |
+
audio = decode_audio(audio, sampling_rate=sampling_rate)
|
391 |
+
duration = audio.shape[0] / sampling_rate
|
392 |
+
|
393 |
+
chunk_length = chunk_length or self.model.feature_extractor.chunk_length
|
394 |
+
# if no segment split is provided, use vad_model and generate segments
|
395 |
+
if not clip_timestamps:
|
396 |
+
if vad_filter:
|
397 |
+
if vad_parameters is None:
|
398 |
+
vad_parameters = VadOptions(
|
399 |
+
max_speech_duration_s=chunk_length,
|
400 |
+
min_silence_duration_ms=160,
|
401 |
+
)
|
402 |
+
elif isinstance(vad_parameters, dict):
|
403 |
+
if "max_speech_duration_s" in vad_parameters.keys():
|
404 |
+
vad_parameters.pop("max_speech_duration_s")
|
405 |
+
|
406 |
+
vad_parameters = VadOptions(
|
407 |
+
**vad_parameters, max_speech_duration_s=chunk_length
|
408 |
+
)
|
409 |
+
|
410 |
+
active_segments = get_speech_timestamps(audio, vad_parameters)
|
411 |
+
clip_timestamps = merge_segments(active_segments, vad_parameters)
|
412 |
+
# run the audio if it is less than 30 sec even without clip_timestamps
|
413 |
+
elif duration < chunk_length:
|
414 |
+
clip_timestamps = [{"start": 0, "end": audio.shape[0]}]
|
415 |
+
else:
|
416 |
+
raise RuntimeError(
|
417 |
+
"No clip timestamps found. "
|
418 |
+
"Set 'vad_filter' to True or provide 'clip_timestamps'."
|
419 |
+
)
|
420 |
+
|
421 |
+
duration_after_vad = (
|
422 |
+
sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
|
423 |
+
/ sampling_rate
|
424 |
+
)
|
425 |
+
|
426 |
+
audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
|
427 |
+
features = (
|
428 |
+
[self.model.feature_extractor(chunk)[..., :-1] for chunk in audio_chunks]
|
429 |
+
if duration_after_vad
|
430 |
+
else []
|
431 |
+
)
|
432 |
+
|
433 |
+
all_language_probs = None
|
434 |
+
# detecting the language if not provided
|
435 |
+
if language is None:
|
436 |
+
if not self.model.model.is_multilingual:
|
437 |
+
language = "en"
|
438 |
+
language_probability = 1
|
439 |
+
else:
|
440 |
+
(
|
441 |
+
language,
|
442 |
+
language_probability,
|
443 |
+
all_language_probs,
|
444 |
+
) = self.model.detect_language(
|
445 |
+
features=np.concatenate(
|
446 |
+
features
|
447 |
+
+ [
|
448 |
+
np.full((self.model.model.n_mels, 1), -1.5, dtype="float32")
|
449 |
+
],
|
450 |
+
axis=1,
|
451 |
+
), # add a dummy feature to account for empty audio
|
452 |
+
language_detection_segments=language_detection_segments,
|
453 |
+
language_detection_threshold=language_detection_threshold,
|
454 |
+
)
|
455 |
+
|
456 |
+
self.model.logger.info(
|
457 |
+
"Detected language '%s' with probability %.2f",
|
458 |
+
language,
|
459 |
+
language_probability,
|
460 |
+
)
|
461 |
+
else:
|
462 |
+
if not self.model.model.is_multilingual and language != "en":
|
463 |
+
self.model.logger.warning(
|
464 |
+
"The current model is English-only but the language parameter is set to '%s'; "
|
465 |
+
"using 'en' instead." % language
|
466 |
+
)
|
467 |
+
language = "en"
|
468 |
+
|
469 |
+
language_probability = 1
|
470 |
+
|
471 |
+
tokenizer = Tokenizer(
|
472 |
+
self.model.hf_tokenizer,
|
473 |
+
self.model.model.is_multilingual,
|
474 |
+
task=task,
|
475 |
+
language=language,
|
476 |
+
)
|
477 |
+
|
478 |
+
features = (
|
479 |
+
np.stack([pad_or_trim(feature) for feature in features]) if features else []
|
480 |
+
)
|
481 |
+
|
482 |
+
options = TranscriptionOptions(
|
483 |
+
beam_size=beam_size,
|
484 |
+
best_of=best_of,
|
485 |
+
patience=patience,
|
486 |
+
length_penalty=length_penalty,
|
487 |
+
repetition_penalty=repetition_penalty,
|
488 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
489 |
+
log_prob_threshold=log_prob_threshold,
|
490 |
+
no_speech_threshold=no_speech_threshold,
|
491 |
+
compression_ratio_threshold=compression_ratio_threshold,
|
492 |
+
temperatures=(
|
493 |
+
temperature[:1]
|
494 |
+
if isinstance(temperature, (list, tuple))
|
495 |
+
else [temperature]
|
496 |
+
),
|
497 |
+
initial_prompt=initial_prompt,
|
498 |
+
prefix=prefix,
|
499 |
+
suppress_blank=suppress_blank,
|
500 |
+
suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
|
501 |
+
prepend_punctuations=prepend_punctuations,
|
502 |
+
append_punctuations=append_punctuations,
|
503 |
+
max_new_tokens=max_new_tokens,
|
504 |
+
hotwords=hotwords,
|
505 |
+
word_timestamps=word_timestamps,
|
506 |
+
hallucination_silence_threshold=None,
|
507 |
+
condition_on_previous_text=False,
|
508 |
+
clip_timestamps=clip_timestamps,
|
509 |
+
prompt_reset_on_temperature=0.5,
|
510 |
+
multilingual=multilingual,
|
511 |
+
without_timestamps=without_timestamps,
|
512 |
+
max_initial_timestamp=0.0,
|
513 |
+
)
|
514 |
+
|
515 |
+
info = TranscriptionInfo(
|
516 |
+
language=language,
|
517 |
+
language_probability=language_probability,
|
518 |
+
duration=duration,
|
519 |
+
duration_after_vad=duration_after_vad,
|
520 |
+
transcription_options=options,
|
521 |
+
vad_options=vad_parameters,
|
522 |
+
all_language_probs=all_language_probs,
|
523 |
+
)
|
524 |
+
|
525 |
+
segments = self._batched_segments_generator(
|
526 |
+
features,
|
527 |
+
tokenizer,
|
528 |
+
chunks_metadata,
|
529 |
+
batch_size,
|
530 |
+
options,
|
531 |
+
log_progress,
|
532 |
+
)
|
533 |
+
|
534 |
+
return segments, info
|
535 |
+
|
536 |
+
def _batched_segments_generator(
|
537 |
+
self, features, tokenizer, chunks_metadata, batch_size, options, log_progress
|
538 |
+
):
|
539 |
+
pbar = tqdm(total=len(features), disable=not log_progress, position=0)
|
540 |
+
seg_idx = 0
|
541 |
+
for i in range(0, len(features), batch_size):
|
542 |
+
results = self.forward(
|
543 |
+
features[i : i + batch_size],
|
544 |
+
tokenizer,
|
545 |
+
chunks_metadata[i : i + batch_size],
|
546 |
+
options,
|
547 |
+
)
|
548 |
+
|
549 |
+
for result in results:
|
550 |
+
for segment in result:
|
551 |
+
seg_idx += 1
|
552 |
+
yield Segment(
|
553 |
+
seek=segment["seek"],
|
554 |
+
id=seg_idx,
|
555 |
+
text=segment["text"],
|
556 |
+
start=round(segment["start"], 3),
|
557 |
+
end=round(segment["end"], 3),
|
558 |
+
words=(
|
559 |
+
None
|
560 |
+
if not options.word_timestamps
|
561 |
+
else [Word(**word) for word in segment["words"]]
|
562 |
+
),
|
563 |
+
tokens=segment["tokens"],
|
564 |
+
avg_logprob=segment["avg_logprob"],
|
565 |
+
no_speech_prob=segment["no_speech_prob"],
|
566 |
+
compression_ratio=segment["compression_ratio"],
|
567 |
+
temperature=options.temperatures[0],
|
568 |
+
)
|
569 |
+
|
570 |
+
pbar.update(1)
|
571 |
+
|
572 |
+
pbar.close()
|
573 |
+
self.last_speech_timestamp = 0.0
|
574 |
+
|
575 |
+
|
576 |
+
class WhisperModel:
|
577 |
+
def __init__(
|
578 |
+
self,
|
579 |
+
model_size_or_path: str,
|
580 |
+
device: str = "auto",
|
581 |
+
device_index: Union[int, List[int]] = 0,
|
582 |
+
compute_type: str = "default",
|
583 |
+
cpu_threads: int = 0,
|
584 |
+
num_workers: int = 1,
|
585 |
+
download_root: Optional[str] = None,
|
586 |
+
local_files_only: bool = False,
|
587 |
+
files: dict = None,
|
588 |
+
**model_kwargs,
|
589 |
+
):
|
590 |
+
"""Initializes the Whisper model.
|
591 |
+
|
592 |
+
Args:
|
593 |
+
model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
|
594 |
+
small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1,
|
595 |
+
large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo),
|
596 |
+
a path to a converted model directory, or a CTranslate2-converted Whisper model ID from
|
597 |
+
the HF Hub. When a size or a model ID is configured, the converted model is downloaded
|
598 |
+
from the Hugging Face Hub.
|
599 |
+
device: Device to use for computation ("cpu", "cuda", "auto").
|
600 |
+
device_index: Device ID to use.
|
601 |
+
The model can also be loaded on multiple GPUs by passing a list of IDs
|
602 |
+
(e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
|
603 |
+
when transcribe() is called from multiple Python threads (see also num_workers).
|
604 |
+
compute_type: Type to use for computation.
|
605 |
+
See https://opennmt.net/CTranslate2/quantization.html.
|
606 |
+
cpu_threads: Number of threads to use when running on CPU (4 by default).
|
607 |
+
A non zero value overrides the OMP_NUM_THREADS environment variable.
|
608 |
+
num_workers: When transcribe() is called from multiple Python threads,
|
609 |
+
having multiple workers enables true parallelism when running the model
|
610 |
+
(concurrent calls to self.model.generate() will run in parallel).
|
611 |
+
This can improve the global throughput at the cost of increased memory usage.
|
612 |
+
download_root: Directory where the models should be saved. If not set, the models
|
613 |
+
are saved in the standard Hugging Face cache directory.
|
614 |
+
local_files_only: If True, avoid downloading the file and return the path to the
|
615 |
+
local cached file if it exists.
|
616 |
+
files: Load model files from the memory. This argument is a dictionary mapping file names
|
617 |
+
to file contents as file-like or bytes objects. If this is set, model_path acts as an
|
618 |
+
identifier for this model.
|
619 |
+
"""
|
620 |
+
self.logger = get_logger()
|
621 |
+
|
622 |
+
tokenizer_bytes, preprocessor_bytes = None, None
|
623 |
+
if files:
|
624 |
+
model_path = model_size_or_path
|
625 |
+
tokenizer_bytes = files.pop("tokenizer.json", None)
|
626 |
+
preprocessor_bytes = files.pop("preprocessor_config.json", None)
|
627 |
+
elif os.path.isdir(model_size_or_path):
|
628 |
+
model_path = model_size_or_path
|
629 |
+
else:
|
630 |
+
model_path = download_model(
|
631 |
+
model_size_or_path,
|
632 |
+
local_files_only=local_files_only,
|
633 |
+
cache_dir=download_root,
|
634 |
+
)
|
635 |
+
|
636 |
+
self.model = ctranslate2.models.Whisper(
|
637 |
+
model_path,
|
638 |
+
device=device,
|
639 |
+
device_index=device_index,
|
640 |
+
compute_type=compute_type,
|
641 |
+
intra_threads=cpu_threads,
|
642 |
+
inter_threads=num_workers,
|
643 |
+
files=files,
|
644 |
+
**model_kwargs,
|
645 |
+
)
|
646 |
+
|
647 |
+
tokenizer_file = os.path.join(model_path, "tokenizer.json")
|
648 |
+
if tokenizer_bytes:
|
649 |
+
self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes)
|
650 |
+
elif os.path.isfile(tokenizer_file):
|
651 |
+
self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
|
652 |
+
else:
|
653 |
+
self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
|
654 |
+
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
|
655 |
+
)
|
656 |
+
self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes)
|
657 |
+
self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
|
658 |
+
self.input_stride = 2
|
659 |
+
self.num_samples_per_token = (
|
660 |
+
self.feature_extractor.hop_length * self.input_stride
|
661 |
+
)
|
662 |
+
self.frames_per_second = (
|
663 |
+
self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
|
664 |
+
)
|
665 |
+
self.tokens_per_second = (
|
666 |
+
self.feature_extractor.sampling_rate // self.num_samples_per_token
|
667 |
+
)
|
668 |
+
self.time_precision = 0.02
|
669 |
+
self.max_length = 448
|
670 |
+
|
671 |
+
@property
|
672 |
+
def supported_languages(self) -> List[str]:
|
673 |
+
"""The languages supported by the model."""
|
674 |
+
return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
|
675 |
+
|
676 |
+
def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict:
|
677 |
+
config = {}
|
678 |
+
try:
|
679 |
+
config_path = os.path.join(model_path, "preprocessor_config.json")
|
680 |
+
if preprocessor_bytes:
|
681 |
+
config = json.loads(preprocessor_bytes)
|
682 |
+
elif os.path.isfile(config_path):
|
683 |
+
with open(config_path, "r", encoding="utf-8") as file:
|
684 |
+
config = json.load(file)
|
685 |
+
else:
|
686 |
+
return config
|
687 |
+
valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
|
688 |
+
return {k: v for k, v in config.items() if k in valid_keys}
|
689 |
+
except json.JSONDecodeError as e:
|
690 |
+
self.logger.warning("Could not load preprocessor config: %s", e)
|
691 |
+
|
692 |
+
return config
|
693 |
+
|
694 |
+
def transcribe(
|
695 |
+
self,
|
696 |
+
audio: Union[str, BinaryIO, np.ndarray],
|
697 |
+
language: Optional[str] = None,
|
698 |
+
task: str = "transcribe",
|
699 |
+
log_progress: bool = False,
|
700 |
+
beam_size: int = 5,
|
701 |
+
best_of: int = 5,
|
702 |
+
patience: float = 1,
|
703 |
+
length_penalty: float = 1,
|
704 |
+
repetition_penalty: float = 1,
|
705 |
+
no_repeat_ngram_size: int = 0,
|
706 |
+
temperature: Union[float, List[float], Tuple[float, ...]] = [
|
707 |
+
0.0,
|
708 |
+
0.2,
|
709 |
+
0.4,
|
710 |
+
0.6,
|
711 |
+
0.8,
|
712 |
+
1.0,
|
713 |
+
],
|
714 |
+
compression_ratio_threshold: Optional[float] = 2.4,
|
715 |
+
log_prob_threshold: Optional[float] = -1.0,
|
716 |
+
no_speech_threshold: Optional[float] = 0.6,
|
717 |
+
condition_on_previous_text: bool = True,
|
718 |
+
prompt_reset_on_temperature: float = 0.5,
|
719 |
+
initial_prompt: Optional[Union[str, Iterable[int]]] = None,
|
720 |
+
prefix: Optional[str] = None,
|
721 |
+
suppress_blank: bool = True,
|
722 |
+
suppress_tokens: Optional[List[int]] = [-1],
|
723 |
+
without_timestamps: bool = False,
|
724 |
+
max_initial_timestamp: float = 1.0,
|
725 |
+
word_timestamps: bool = False,
|
726 |
+
prepend_punctuations: str = "\"'“¿([{-",
|
727 |
+
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
728 |
+
multilingual: bool = False,
|
729 |
+
vad_filter: bool = False,
|
730 |
+
vad_parameters: Optional[Union[dict, VadOptions]] = None,
|
731 |
+
max_new_tokens: Optional[int] = None,
|
732 |
+
chunk_length: Optional[int] = None,
|
733 |
+
clip_timestamps: Union[str, List[float]] = "0",
|
734 |
+
hallucination_silence_threshold: Optional[float] = None,
|
735 |
+
hotwords: Optional[str] = None,
|
736 |
+
language_detection_threshold: Optional[float] = 0.5,
|
737 |
+
language_detection_segments: int = 1,
|
738 |
+
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
|
739 |
+
"""Transcribes an input file.
|
740 |
+
|
741 |
+
Arguments:
|
742 |
+
audio: Path to the input file (or a file-like object), or the audio waveform.
|
743 |
+
language: The language spoken in the audio. It should be a language code such
|
744 |
+
as "en" or "fr". If not set, the language will be detected in the first 30 seconds
|
745 |
+
of audio.
|
746 |
+
task: Task to execute (transcribe or translate).
|
747 |
+
log_progress: whether to show progress bar or not.
|
748 |
+
beam_size: Beam size to use for decoding.
|
749 |
+
best_of: Number of candidates when sampling with non-zero temperature.
|
750 |
+
patience: Beam search patience factor.
|
751 |
+
length_penalty: Exponential length penalty constant.
|
752 |
+
repetition_penalty: Penalty applied to the score of previously generated tokens
|
753 |
+
(set > 1 to penalize).
|
754 |
+
no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
|
755 |
+
temperature: Temperature for sampling. It can be a tuple of temperatures,
|
756 |
+
which will be successively used upon failures according to either
|
757 |
+
`compression_ratio_threshold` or `log_prob_threshold`.
|
758 |
+
compression_ratio_threshold: If the gzip compression ratio is above this value,
|
759 |
+
treat as failed.
|
760 |
+
log_prob_threshold: If the average log probability over sampled tokens is
|
761 |
+
below this value, treat as failed.
|
762 |
+
no_speech_threshold: If the no_speech probability is higher than this value AND
|
763 |
+
the average log probability over sampled tokens is below `log_prob_threshold`,
|
764 |
+
consider the segment as silent.
|
765 |
+
condition_on_previous_text: If True, the previous output of the model is provided
|
766 |
+
as a prompt for the next window; disabling may make the text inconsistent across
|
767 |
+
windows, but the model becomes less prone to getting stuck in a failure loop,
|
768 |
+
such as repetition looping or timestamps going out of sync.
|
769 |
+
prompt_reset_on_temperature: Resets prompt if temperature is above this value.
|
770 |
+
Arg has effect only if condition_on_previous_text is True.
|
771 |
+
initial_prompt: Optional text string or iterable of token ids to provide as a
|
772 |
+
prompt for the first window.
|
773 |
+
prefix: Optional text to provide as a prefix for the first window.
|
774 |
+
suppress_blank: Suppress blank outputs at the beginning of the sampling.
|
775 |
+
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
|
776 |
+
of symbols as defined in `tokenizer.non_speech_tokens()`.
|
777 |
+
without_timestamps: Only sample text tokens.
|
778 |
+
max_initial_timestamp: The initial timestamp cannot be later than this.
|
779 |
+
word_timestamps: Extract word-level timestamps using the cross-attention pattern
|
780 |
+
and dynamic time warping, and include the timestamps for each word in each segment.
|
781 |
+
prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
|
782 |
+
with the next word
|
783 |
+
append_punctuations: If word_timestamps is True, merge these punctuation symbols
|
784 |
+
with the previous word
|
785 |
+
multilingual: Perform language detection on every segment.
|
786 |
+
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
|
787 |
+
without speech. This step is using the Silero VAD model
|
788 |
+
https://github.com/snakers4/silero-vad.
|
789 |
+
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
|
790 |
+
parameters and default values in the class `VadOptions`).
|
791 |
+
max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
|
792 |
+
the maximum will be set by the default max_length.
|
793 |
+
chunk_length: The length of audio segments. If it is not None, it will overwrite the
|
794 |
+
default chunk_length of the FeatureExtractor.
|
795 |
+
clip_timestamps:
|
796 |
+
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
|
797 |
+
process. The last end timestamp defaults to the end of the file.
|
798 |
+
vad_filter will be ignored if clip_timestamps is used.
|
799 |
+
hallucination_silence_threshold:
|
800 |
+
When word_timestamps is True, skip silent periods longer than this threshold
|
801 |
+
(in seconds) when a possible hallucination is detected
|
802 |
+
hotwords:
|
803 |
+
Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
|
804 |
+
language_detection_threshold: If the maximum probability of the language tokens is higher
|
805 |
+
than this value, the language is detected.
|
806 |
+
language_detection_segments: Number of segments to consider for the language detection.
|
807 |
+
Returns:
|
808 |
+
A tuple with:
|
809 |
+
|
810 |
+
- a generator over transcribed segments
|
811 |
+
- an instance of TranscriptionInfo
|
812 |
+
"""
|
813 |
+
sampling_rate = self.feature_extractor.sampling_rate
|
814 |
+
|
815 |
+
if multilingual and not self.model.is_multilingual:
|
816 |
+
self.logger.warning(
|
817 |
+
"The current model is English-only but the multilingual parameter is set to"
|
818 |
+
"True; setting to False instead."
|
819 |
+
)
|
820 |
+
multilingual = False
|
821 |
+
|
822 |
+
if not isinstance(audio, np.ndarray):
|
823 |
+
audio = decode_audio(audio, sampling_rate=sampling_rate)
|
824 |
+
|
825 |
+
duration = audio.shape[0] / sampling_rate
|
826 |
+
duration_after_vad = duration
|
827 |
+
|
828 |
+
self.logger.info(
|
829 |
+
"Processing audio with duration %s", format_timestamp(duration)
|
830 |
+
)
|
831 |
+
|
832 |
+
if vad_filter and clip_timestamps == "0":
|
833 |
+
if vad_parameters is None:
|
834 |
+
vad_parameters = VadOptions()
|
835 |
+
elif isinstance(vad_parameters, dict):
|
836 |
+
vad_parameters = VadOptions(**vad_parameters)
|
837 |
+
speech_chunks = get_speech_timestamps(audio, vad_parameters)
|
838 |
+
audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
|
839 |
+
audio = np.concatenate(audio_chunks, axis=0)
|
840 |
+
duration_after_vad = audio.shape[0] / sampling_rate
|
841 |
+
|
842 |
+
self.logger.info(
|
843 |
+
"VAD filter removed %s of audio",
|
844 |
+
format_timestamp(duration - duration_after_vad),
|
845 |
+
)
|
846 |
+
|
847 |
+
if self.logger.isEnabledFor(logging.DEBUG):
|
848 |
+
self.logger.debug(
|
849 |
+
"VAD filter kept the following audio segments: %s",
|
850 |
+
", ".join(
|
851 |
+
"[%s -> %s]"
|
852 |
+
% (
|
853 |
+
format_timestamp(chunk["start"] / sampling_rate),
|
854 |
+
format_timestamp(chunk["end"] / sampling_rate),
|
855 |
+
)
|
856 |
+
for chunk in speech_chunks
|
857 |
+
),
|
858 |
+
)
|
859 |
+
|
860 |
+
else:
|
861 |
+
speech_chunks = None
|
862 |
+
if audio.shape[0] == 0:
|
863 |
+
return None, None
|
864 |
+
features = self.feature_extractor(audio, chunk_length=chunk_length)
|
865 |
+
|
866 |
+
encoder_output = None
|
867 |
+
all_language_probs = None
|
868 |
+
|
869 |
+
# detecting the language if not provided
|
870 |
+
if language is None:
|
871 |
+
if not self.model.is_multilingual:
|
872 |
+
language = "en"
|
873 |
+
language_probability = 1
|
874 |
+
else:
|
875 |
+
start_timestamp = (
|
876 |
+
float(clip_timestamps.split(",")[0])
|
877 |
+
if isinstance(clip_timestamps, str)
|
878 |
+
else clip_timestamps[0]
|
879 |
+
)
|
880 |
+
content_frames = features.shape[-1] - 1
|
881 |
+
seek = (
|
882 |
+
int(start_timestamp * self.frames_per_second)
|
883 |
+
if start_timestamp * self.frames_per_second < content_frames
|
884 |
+
else 0
|
885 |
+
)
|
886 |
+
(
|
887 |
+
language,
|
888 |
+
language_probability,
|
889 |
+
all_language_probs,
|
890 |
+
) = self.detect_language(
|
891 |
+
features=features[..., seek:],
|
892 |
+
language_detection_segments=language_detection_segments,
|
893 |
+
language_detection_threshold=language_detection_threshold,
|
894 |
+
)
|
895 |
+
|
896 |
+
self.logger.info(
|
897 |
+
"Detected language '%s' with probability %.2f",
|
898 |
+
language,
|
899 |
+
language_probability,
|
900 |
+
)
|
901 |
+
else:
|
902 |
+
if not self.model.is_multilingual and language != "en":
|
903 |
+
self.logger.warning(
|
904 |
+
"The current model is English-only but the language parameter is set to '%s'; "
|
905 |
+
"using 'en' instead." % language
|
906 |
+
)
|
907 |
+
language = "en"
|
908 |
+
|
909 |
+
language_probability = 1
|
910 |
+
|
911 |
+
tokenizer = Tokenizer(
|
912 |
+
self.hf_tokenizer,
|
913 |
+
self.model.is_multilingual,
|
914 |
+
task=task,
|
915 |
+
language=language,
|
916 |
+
)
|
917 |
+
|
918 |
+
options = TranscriptionOptions(
|
919 |
+
beam_size=beam_size,
|
920 |
+
best_of=best_of,
|
921 |
+
patience=patience,
|
922 |
+
length_penalty=length_penalty,
|
923 |
+
repetition_penalty=repetition_penalty,
|
924 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
925 |
+
log_prob_threshold=log_prob_threshold,
|
926 |
+
no_speech_threshold=no_speech_threshold,
|
927 |
+
compression_ratio_threshold=compression_ratio_threshold,
|
928 |
+
condition_on_previous_text=condition_on_previous_text,
|
929 |
+
prompt_reset_on_temperature=prompt_reset_on_temperature,
|
930 |
+
temperatures=(
|
931 |
+
temperature if isinstance(temperature, (list, tuple)) else [temperature]
|
932 |
+
),
|
933 |
+
initial_prompt=initial_prompt,
|
934 |
+
prefix=prefix,
|
935 |
+
suppress_blank=suppress_blank,
|
936 |
+
suppress_tokens=(
|
937 |
+
get_suppressed_tokens(tokenizer, suppress_tokens)
|
938 |
+
if suppress_tokens
|
939 |
+
else suppress_tokens
|
940 |
+
),
|
941 |
+
without_timestamps=without_timestamps,
|
942 |
+
max_initial_timestamp=max_initial_timestamp,
|
943 |
+
word_timestamps=word_timestamps,
|
944 |
+
prepend_punctuations=prepend_punctuations,
|
945 |
+
append_punctuations=append_punctuations,
|
946 |
+
multilingual=multilingual,
|
947 |
+
max_new_tokens=max_new_tokens,
|
948 |
+
clip_timestamps=clip_timestamps,
|
949 |
+
hallucination_silence_threshold=hallucination_silence_threshold,
|
950 |
+
hotwords=hotwords,
|
951 |
+
)
|
952 |
+
|
953 |
+
segments = self.generate_segments(
|
954 |
+
features, tokenizer, options, log_progress, encoder_output
|
955 |
+
)
|
956 |
+
|
957 |
+
if speech_chunks:
|
958 |
+
segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
|
959 |
+
|
960 |
+
info = TranscriptionInfo(
|
961 |
+
language=language,
|
962 |
+
language_probability=language_probability,
|
963 |
+
duration=duration,
|
964 |
+
duration_after_vad=duration_after_vad,
|
965 |
+
transcription_options=options,
|
966 |
+
vad_options=vad_parameters,
|
967 |
+
all_language_probs=all_language_probs,
|
968 |
+
)
|
969 |
+
|
970 |
+
return segments, info
|
971 |
+
|
972 |
+
def _split_segments_by_timestamps(
|
973 |
+
self,
|
974 |
+
tokenizer: Tokenizer,
|
975 |
+
tokens: List[int],
|
976 |
+
time_offset: float,
|
977 |
+
segment_size: int,
|
978 |
+
segment_duration: float,
|
979 |
+
seek: int,
|
980 |
+
) -> List[List[int]]:
|
981 |
+
current_segments = []
|
982 |
+
single_timestamp_ending = (
|
983 |
+
len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
|
984 |
+
)
|
985 |
+
|
986 |
+
consecutive_timestamps = [
|
987 |
+
i
|
988 |
+
for i in range(len(tokens))
|
989 |
+
if i > 0
|
990 |
+
and tokens[i] >= tokenizer.timestamp_begin
|
991 |
+
and tokens[i - 1] >= tokenizer.timestamp_begin
|
992 |
+
]
|
993 |
+
|
994 |
+
if len(consecutive_timestamps) > 0:
|
995 |
+
slices = list(consecutive_timestamps)
|
996 |
+
if single_timestamp_ending:
|
997 |
+
slices.append(len(tokens))
|
998 |
+
|
999 |
+
last_slice = 0
|
1000 |
+
for current_slice in slices:
|
1001 |
+
sliced_tokens = tokens[last_slice:current_slice]
|
1002 |
+
start_timestamp_position = sliced_tokens[0] - tokenizer.timestamp_begin
|
1003 |
+
end_timestamp_position = sliced_tokens[-1] - tokenizer.timestamp_begin
|
1004 |
+
start_time = (
|
1005 |
+
time_offset + start_timestamp_position * self.time_precision
|
1006 |
+
)
|
1007 |
+
end_time = time_offset + end_timestamp_position * self.time_precision
|
1008 |
+
|
1009 |
+
current_segments.append(
|
1010 |
+
dict(
|
1011 |
+
seek=seek,
|
1012 |
+
start=start_time,
|
1013 |
+
end=end_time,
|
1014 |
+
tokens=sliced_tokens,
|
1015 |
+
)
|
1016 |
+
)
|
1017 |
+
last_slice = current_slice
|
1018 |
+
|
1019 |
+
if single_timestamp_ending:
|
1020 |
+
# single timestamp at the end means no speech after the last timestamp.
|
1021 |
+
seek += segment_size
|
1022 |
+
else:
|
1023 |
+
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
1024 |
+
last_timestamp_position = (
|
1025 |
+
tokens[last_slice - 1] - tokenizer.timestamp_begin
|
1026 |
+
)
|
1027 |
+
seek += last_timestamp_position * self.input_stride
|
1028 |
+
|
1029 |
+
else:
|
1030 |
+
duration = segment_duration
|
1031 |
+
timestamps = [
|
1032 |
+
token for token in tokens if token >= tokenizer.timestamp_begin
|
1033 |
+
]
|
1034 |
+
if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
|
1035 |
+
last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
|
1036 |
+
duration = last_timestamp_position * self.time_precision
|
1037 |
+
|
1038 |
+
current_segments.append(
|
1039 |
+
dict(
|
1040 |
+
seek=seek,
|
1041 |
+
start=time_offset,
|
1042 |
+
end=time_offset + duration,
|
1043 |
+
tokens=tokens,
|
1044 |
+
)
|
1045 |
+
)
|
1046 |
+
|
1047 |
+
seek += segment_size
|
1048 |
+
|
1049 |
+
return current_segments, seek, single_timestamp_ending
|
1050 |
+
|
1051 |
+
def generate_segments(
|
1052 |
+
self,
|
1053 |
+
features: np.ndarray,
|
1054 |
+
tokenizer: Tokenizer,
|
1055 |
+
options: TranscriptionOptions,
|
1056 |
+
log_progress,
|
1057 |
+
encoder_output: Optional[ctranslate2.StorageView] = None,
|
1058 |
+
) -> Iterable[Segment]:
|
1059 |
+
content_frames = features.shape[-1] - 1
|
1060 |
+
content_duration = float(content_frames * self.feature_extractor.time_per_frame)
|
1061 |
+
|
1062 |
+
if isinstance(options.clip_timestamps, str):
|
1063 |
+
options.clip_timestamps = [
|
1064 |
+
float(ts)
|
1065 |
+
for ts in (
|
1066 |
+
options.clip_timestamps.split(",")
|
1067 |
+
if options.clip_timestamps
|
1068 |
+
else []
|
1069 |
+
)
|
1070 |
+
]
|
1071 |
+
|
1072 |
+
seek_points: List[int] = [
|
1073 |
+
round(ts * self.frames_per_second) for ts in options.clip_timestamps
|
1074 |
+
]
|
1075 |
+
if len(seek_points) == 0:
|
1076 |
+
seek_points.append(0)
|
1077 |
+
if len(seek_points) % 2 == 1:
|
1078 |
+
seek_points.append(content_frames)
|
1079 |
+
seek_clips: List[Tuple[int, int]] = list(
|
1080 |
+
zip(seek_points[::2], seek_points[1::2])
|
1081 |
+
)
|
1082 |
+
|
1083 |
+
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
1084 |
+
|
1085 |
+
idx = 0
|
1086 |
+
clip_idx = 0
|
1087 |
+
seek = seek_clips[clip_idx][0]
|
1088 |
+
all_tokens = []
|
1089 |
+
prompt_reset_since = 0
|
1090 |
+
|
1091 |
+
if options.initial_prompt is not None:
|
1092 |
+
if isinstance(options.initial_prompt, str):
|
1093 |
+
initial_prompt = " " + options.initial_prompt.strip()
|
1094 |
+
initial_prompt_tokens = tokenizer.encode(initial_prompt)
|
1095 |
+
all_tokens.extend(initial_prompt_tokens)
|
1096 |
+
else:
|
1097 |
+
all_tokens.extend(options.initial_prompt)
|
1098 |
+
|
1099 |
+
pbar = tqdm(total=content_duration, unit="seconds", disable=not log_progress)
|
1100 |
+
last_speech_timestamp = 0.0
|
1101 |
+
all_segments = []
|
1102 |
+
# NOTE: This loop is obscurely flattened to make the diff readable.
|
1103 |
+
# A later commit should turn this into a simpler nested loop.
|
1104 |
+
# for seek_clip_start, seek_clip_end in seek_clips:
|
1105 |
+
# while seek < seek_clip_end
|
1106 |
+
while clip_idx < len(seek_clips):
|
1107 |
+
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
1108 |
+
if seek_clip_end > content_frames:
|
1109 |
+
seek_clip_end = content_frames
|
1110 |
+
if seek < seek_clip_start:
|
1111 |
+
seek = seek_clip_start
|
1112 |
+
if seek >= seek_clip_end:
|
1113 |
+
clip_idx += 1
|
1114 |
+
if clip_idx < len(seek_clips):
|
1115 |
+
seek = seek_clips[clip_idx][0]
|
1116 |
+
continue
|
1117 |
+
time_offset = seek * self.feature_extractor.time_per_frame
|
1118 |
+
window_end_time = float(
|
1119 |
+
(seek + self.feature_extractor.nb_max_frames)
|
1120 |
+
* self.feature_extractor.time_per_frame
|
1121 |
+
)
|
1122 |
+
segment_size = min(
|
1123 |
+
self.feature_extractor.nb_max_frames,
|
1124 |
+
content_frames - seek,
|
1125 |
+
seek_clip_end - seek,
|
1126 |
+
)
|
1127 |
+
segment = features[:, seek : seek + segment_size]
|
1128 |
+
segment_duration = segment_size * self.feature_extractor.time_per_frame
|
1129 |
+
segment = pad_or_trim(segment)
|
1130 |
+
|
1131 |
+
if self.logger.isEnabledFor(logging.DEBUG):
|
1132 |
+
self.logger.debug(
|
1133 |
+
"Processing segment at %s", format_timestamp(time_offset)
|
1134 |
+
)
|
1135 |
+
|
1136 |
+
previous_tokens = all_tokens[prompt_reset_since:]
|
1137 |
+
|
1138 |
+
if seek > 0 or encoder_output is None:
|
1139 |
+
encoder_output = self.encode(segment)
|
1140 |
+
|
1141 |
+
if options.multilingual:
|
1142 |
+
results = self.model.detect_language(encoder_output)
|
1143 |
+
language_token, language_probability = results[0][0]
|
1144 |
+
language = language_token[2:-2]
|
1145 |
+
|
1146 |
+
tokenizer.language = tokenizer.tokenizer.token_to_id(language_token)
|
1147 |
+
tokenizer.language_code = language
|
1148 |
+
|
1149 |
+
prompt = self.get_prompt(
|
1150 |
+
tokenizer,
|
1151 |
+
previous_tokens,
|
1152 |
+
without_timestamps=options.without_timestamps,
|
1153 |
+
prefix=options.prefix if seek == 0 else None,
|
1154 |
+
hotwords=options.hotwords,
|
1155 |
+
)
|
1156 |
+
|
1157 |
+
(
|
1158 |
+
result,
|
1159 |
+
avg_logprob,
|
1160 |
+
temperature,
|
1161 |
+
compression_ratio,
|
1162 |
+
) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)
|
1163 |
+
|
1164 |
+
if options.no_speech_threshold is not None:
|
1165 |
+
# no voice activity check
|
1166 |
+
should_skip = result.no_speech_prob > options.no_speech_threshold
|
1167 |
+
|
1168 |
+
if (
|
1169 |
+
options.log_prob_threshold is not None
|
1170 |
+
and avg_logprob > options.log_prob_threshold
|
1171 |
+
):
|
1172 |
+
# don't skip if the logprob is high enough, despite the no_speech_prob
|
1173 |
+
should_skip = False
|
1174 |
+
|
1175 |
+
if should_skip:
|
1176 |
+
self.logger.debug(
|
1177 |
+
"No speech threshold is met (%f > %f)",
|
1178 |
+
result.no_speech_prob,
|
1179 |
+
options.no_speech_threshold,
|
1180 |
+
)
|
1181 |
+
|
1182 |
+
# fast-forward to the next segment boundary
|
1183 |
+
seek += segment_size
|
1184 |
+
continue
|
1185 |
+
|
1186 |
+
tokens = result.sequences_ids[0]
|
1187 |
+
|
1188 |
+
previous_seek = seek
|
1189 |
+
|
1190 |
+
# anomalous words are very long/short/improbable
|
1191 |
+
def word_anomaly_score(word: dict) -> float:
|
1192 |
+
probability = word.get("probability", 0.0)
|
1193 |
+
duration = word["end"] - word["start"]
|
1194 |
+
score = 0.0
|
1195 |
+
if probability < 0.15:
|
1196 |
+
score += 1.0
|
1197 |
+
if duration < 0.133:
|
1198 |
+
score += (0.133 - duration) * 15
|
1199 |
+
if duration > 2.0:
|
1200 |
+
score += duration - 2.0
|
1201 |
+
return score
|
1202 |
+
|
1203 |
+
def is_segment_anomaly(segment: Optional[dict]) -> bool:
|
1204 |
+
if segment is None or not segment["words"]:
|
1205 |
+
return False
|
1206 |
+
words = [w for w in segment["words"] if w["word"] not in punctuation]
|
1207 |
+
words = words[:8]
|
1208 |
+
score = sum(word_anomaly_score(w) for w in words)
|
1209 |
+
return score >= 3 or score + 0.01 >= len(words)
|
1210 |
+
|
1211 |
+
def next_words_segment(segments: List[dict]) -> Optional[dict]:
|
1212 |
+
return next((s for s in segments if s["words"]), None)
|
1213 |
+
|
1214 |
+
(
|
1215 |
+
current_segments,
|
1216 |
+
seek,
|
1217 |
+
single_timestamp_ending,
|
1218 |
+
) = self._split_segments_by_timestamps(
|
1219 |
+
tokenizer=tokenizer,
|
1220 |
+
tokens=tokens,
|
1221 |
+
time_offset=time_offset,
|
1222 |
+
segment_size=segment_size,
|
1223 |
+
segment_duration=segment_duration,
|
1224 |
+
seek=seek,
|
1225 |
+
)
|
1226 |
+
|
1227 |
+
if options.word_timestamps:
|
1228 |
+
self.add_word_timestamps(
|
1229 |
+
[current_segments],
|
1230 |
+
tokenizer,
|
1231 |
+
encoder_output,
|
1232 |
+
segment_size,
|
1233 |
+
options.prepend_punctuations,
|
1234 |
+
options.append_punctuations,
|
1235 |
+
last_speech_timestamp=last_speech_timestamp,
|
1236 |
+
)
|
1237 |
+
if not single_timestamp_ending:
|
1238 |
+
last_word_end = get_end(current_segments)
|
1239 |
+
if last_word_end is not None and last_word_end > time_offset:
|
1240 |
+
seek = round(last_word_end * self.frames_per_second)
|
1241 |
+
|
1242 |
+
# skip silence before possible hallucinations
|
1243 |
+
if options.hallucination_silence_threshold is not None:
|
1244 |
+
threshold = options.hallucination_silence_threshold
|
1245 |
+
|
1246 |
+
# if first segment might be a hallucination, skip leading silence
|
1247 |
+
first_segment = next_words_segment(current_segments)
|
1248 |
+
if first_segment is not None and is_segment_anomaly(first_segment):
|
1249 |
+
gap = first_segment["start"] - time_offset
|
1250 |
+
if gap > threshold:
|
1251 |
+
seek = previous_seek + round(gap * self.frames_per_second)
|
1252 |
+
continue
|
1253 |
+
|
1254 |
+
# skip silence before any possible hallucination that is surrounded
|
1255 |
+
# by silence or more hallucinations
|
1256 |
+
hal_last_end = last_speech_timestamp
|
1257 |
+
for si in range(len(current_segments)):
|
1258 |
+
segment = current_segments[si]
|
1259 |
+
if not segment["words"]:
|
1260 |
+
continue
|
1261 |
+
if is_segment_anomaly(segment):
|
1262 |
+
next_segment = next_words_segment(
|
1263 |
+
current_segments[si + 1 :]
|
1264 |
+
)
|
1265 |
+
if next_segment is not None:
|
1266 |
+
hal_next_start = next_segment["words"][0]["start"]
|
1267 |
+
else:
|
1268 |
+
hal_next_start = time_offset + segment_duration
|
1269 |
+
silence_before = (
|
1270 |
+
segment["start"] - hal_last_end > threshold
|
1271 |
+
or segment["start"] < threshold
|
1272 |
+
or segment["start"] - time_offset < 2.0
|
1273 |
+
)
|
1274 |
+
silence_after = (
|
1275 |
+
hal_next_start - segment["end"] > threshold
|
1276 |
+
or is_segment_anomaly(next_segment)
|
1277 |
+
or window_end_time - segment["end"] < 2.0
|
1278 |
+
)
|
1279 |
+
if silence_before and silence_after:
|
1280 |
+
seek = round(
|
1281 |
+
max(time_offset + 1, segment["start"])
|
1282 |
+
* self.frames_per_second
|
1283 |
+
)
|
1284 |
+
if content_duration - segment["end"] < threshold:
|
1285 |
+
seek = content_frames
|
1286 |
+
current_segments[si:] = []
|
1287 |
+
break
|
1288 |
+
hal_last_end = segment["end"]
|
1289 |
+
|
1290 |
+
last_word_end = get_end(current_segments)
|
1291 |
+
if last_word_end is not None:
|
1292 |
+
last_speech_timestamp = last_word_end
|
1293 |
+
for segment in current_segments:
|
1294 |
+
tokens = segment["tokens"]
|
1295 |
+
text = tokenizer.decode(tokens)
|
1296 |
+
|
1297 |
+
if segment["start"] == segment["end"] or not text.strip():
|
1298 |
+
continue
|
1299 |
+
|
1300 |
+
all_tokens.extend(tokens)
|
1301 |
+
idx += 1
|
1302 |
+
|
1303 |
+
all_segments.append(Segment(
|
1304 |
+
id=idx,
|
1305 |
+
seek=previous_seek,
|
1306 |
+
start=segment["start"],
|
1307 |
+
end=segment["end"],
|
1308 |
+
text=text,
|
1309 |
+
tokens=tokens,
|
1310 |
+
temperature=temperature,
|
1311 |
+
avg_logprob=avg_logprob,
|
1312 |
+
compression_ratio=compression_ratio,
|
1313 |
+
no_speech_prob=result.no_speech_prob,
|
1314 |
+
words=(
|
1315 |
+
[Word(**word) for word in segment["words"]]
|
1316 |
+
if options.word_timestamps
|
1317 |
+
else None
|
1318 |
+
),
|
1319 |
+
))
|
1320 |
+
|
1321 |
+
if (
|
1322 |
+
not options.condition_on_previous_text
|
1323 |
+
or temperature > options.prompt_reset_on_temperature
|
1324 |
+
):
|
1325 |
+
if options.condition_on_previous_text:
|
1326 |
+
self.logger.debug(
|
1327 |
+
"Reset prompt. prompt_reset_on_temperature threshold is met %f > %f",
|
1328 |
+
temperature,
|
1329 |
+
options.prompt_reset_on_temperature,
|
1330 |
+
)
|
1331 |
+
|
1332 |
+
prompt_reset_since = len(all_tokens)
|
1333 |
+
|
1334 |
+
pbar.update(
|
1335 |
+
(min(content_frames, seek) - previous_seek)
|
1336 |
+
* self.feature_extractor.time_per_frame,
|
1337 |
+
)
|
1338 |
+
pbar.close()
|
1339 |
+
return all_segments
|
1340 |
+
|
1341 |
+
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
|
1342 |
+
# When the model is running on multiple GPUs, the encoder output should be moved
|
1343 |
+
# to the CPU since we don't know which GPU will handle the next job.
|
1344 |
+
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
|
1345 |
+
|
1346 |
+
if features.ndim == 2:
|
1347 |
+
features = np.expand_dims(features, 0)
|
1348 |
+
features = get_ctranslate2_storage(features)
|
1349 |
+
|
1350 |
+
return self.model.encode(features, to_cpu=to_cpu)
|
1351 |
+
|
1352 |
+
def generate_with_fallback(
|
1353 |
+
self,
|
1354 |
+
encoder_output: ctranslate2.StorageView,
|
1355 |
+
prompt: List[int],
|
1356 |
+
tokenizer: Tokenizer,
|
1357 |
+
options: TranscriptionOptions,
|
1358 |
+
) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
|
1359 |
+
decode_result = None
|
1360 |
+
all_results = []
|
1361 |
+
below_cr_threshold_results = []
|
1362 |
+
|
1363 |
+
max_initial_timestamp_index = int(
|
1364 |
+
round(options.max_initial_timestamp / self.time_precision)
|
1365 |
+
)
|
1366 |
+
if options.max_new_tokens is not None:
|
1367 |
+
max_length = len(prompt) + options.max_new_tokens
|
1368 |
+
else:
|
1369 |
+
max_length = self.max_length
|
1370 |
+
|
1371 |
+
if max_length > self.max_length:
|
1372 |
+
raise ValueError(
|
1373 |
+
f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
|
1374 |
+
f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
|
1375 |
+
f"and `max_new_tokens` is: {max_length}. This exceeds the "
|
1376 |
+
f"`max_length` of the Whisper model: {self.max_length}. "
|
1377 |
+
"You should either reduce the length of your prompt, or "
|
1378 |
+
"reduce the value of `max_new_tokens`, "
|
1379 |
+
f"so that their combined length is less that {self.max_length}."
|
1380 |
+
)
|
1381 |
+
|
1382 |
+
for temperature in options.temperatures:
|
1383 |
+
if temperature > 0:
|
1384 |
+
kwargs = {
|
1385 |
+
"beam_size": 1,
|
1386 |
+
"num_hypotheses": options.best_of,
|
1387 |
+
"sampling_topk": 0,
|
1388 |
+
"sampling_temperature": temperature,
|
1389 |
+
}
|
1390 |
+
else:
|
1391 |
+
kwargs = {
|
1392 |
+
"beam_size": options.beam_size,
|
1393 |
+
"patience": options.patience,
|
1394 |
+
}
|
1395 |
+
|
1396 |
+
result = self.model.generate(
|
1397 |
+
encoder_output,
|
1398 |
+
[prompt],
|
1399 |
+
length_penalty=options.length_penalty,
|
1400 |
+
repetition_penalty=options.repetition_penalty,
|
1401 |
+
no_repeat_ngram_size=options.no_repeat_ngram_size,
|
1402 |
+
max_length=max_length,
|
1403 |
+
return_scores=True,
|
1404 |
+
return_no_speech_prob=True,
|
1405 |
+
suppress_blank=options.suppress_blank,
|
1406 |
+
suppress_tokens=options.suppress_tokens,
|
1407 |
+
max_initial_timestamp_index=max_initial_timestamp_index,
|
1408 |
+
**kwargs,
|
1409 |
+
)[0]
|
1410 |
+
|
1411 |
+
tokens = result.sequences_ids[0]
|
1412 |
+
|
1413 |
+
# Recover the average log prob from the returned score.
|
1414 |
+
seq_len = len(tokens)
|
1415 |
+
cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
|
1416 |
+
avg_logprob = cum_logprob / (seq_len + 1)
|
1417 |
+
|
1418 |
+
text = tokenizer.decode(tokens).strip()
|
1419 |
+
compression_ratio = get_compression_ratio(text)
|
1420 |
+
|
1421 |
+
decode_result = (
|
1422 |
+
result,
|
1423 |
+
avg_logprob,
|
1424 |
+
temperature,
|
1425 |
+
compression_ratio,
|
1426 |
+
)
|
1427 |
+
all_results.append(decode_result)
|
1428 |
+
|
1429 |
+
needs_fallback = False
|
1430 |
+
|
1431 |
+
if options.compression_ratio_threshold is not None:
|
1432 |
+
if compression_ratio > options.compression_ratio_threshold:
|
1433 |
+
needs_fallback = True # too repetitive
|
1434 |
+
|
1435 |
+
self.logger.debug(
|
1436 |
+
"Compression ratio threshold is not met with temperature %.1f (%f > %f)",
|
1437 |
+
temperature,
|
1438 |
+
compression_ratio,
|
1439 |
+
options.compression_ratio_threshold,
|
1440 |
+
)
|
1441 |
+
else:
|
1442 |
+
below_cr_threshold_results.append(decode_result)
|
1443 |
+
|
1444 |
+
if (
|
1445 |
+
options.log_prob_threshold is not None
|
1446 |
+
and avg_logprob < options.log_prob_threshold
|
1447 |
+
):
|
1448 |
+
needs_fallback = True # average log probability is too low
|
1449 |
+
|
1450 |
+
self.logger.debug(
|
1451 |
+
"Log probability threshold is not met with temperature %.1f (%f < %f)",
|
1452 |
+
temperature,
|
1453 |
+
avg_logprob,
|
1454 |
+
options.log_prob_threshold,
|
1455 |
+
)
|
1456 |
+
|
1457 |
+
if (
|
1458 |
+
options.no_speech_threshold is not None
|
1459 |
+
and result.no_speech_prob > options.no_speech_threshold
|
1460 |
+
and options.log_prob_threshold is not None
|
1461 |
+
and avg_logprob < options.log_prob_threshold
|
1462 |
+
):
|
1463 |
+
needs_fallback = False # silence
|
1464 |
+
|
1465 |
+
if not needs_fallback:
|
1466 |
+
break
|
1467 |
+
else:
|
1468 |
+
# all failed, select the result with the highest average log probability
|
1469 |
+
decode_result = max(
|
1470 |
+
below_cr_threshold_results or all_results, key=lambda x: x[1]
|
1471 |
+
)
|
1472 |
+
# to pass final temperature for prompt_reset_on_temperature
|
1473 |
+
decode_result = (
|
1474 |
+
decode_result[0],
|
1475 |
+
decode_result[1],
|
1476 |
+
temperature,
|
1477 |
+
decode_result[3],
|
1478 |
+
)
|
1479 |
+
|
1480 |
+
return decode_result
|
1481 |
+
|
1482 |
+
def get_prompt(
|
1483 |
+
self,
|
1484 |
+
tokenizer: Tokenizer,
|
1485 |
+
previous_tokens: List[int],
|
1486 |
+
without_timestamps: bool = False,
|
1487 |
+
prefix: Optional[str] = None,
|
1488 |
+
hotwords: Optional[str] = None,
|
1489 |
+
) -> List[int]:
|
1490 |
+
prompt = []
|
1491 |
+
|
1492 |
+
if previous_tokens or (hotwords and not prefix):
|
1493 |
+
prompt.append(tokenizer.sot_prev)
|
1494 |
+
if hotwords and not prefix:
|
1495 |
+
hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
|
1496 |
+
if len(hotwords_tokens) >= self.max_length // 2:
|
1497 |
+
hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
|
1498 |
+
prompt.extend(hotwords_tokens)
|
1499 |
+
if previous_tokens:
|
1500 |
+
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
|
1501 |
+
|
1502 |
+
prompt.extend(tokenizer.sot_sequence)
|
1503 |
+
|
1504 |
+
if without_timestamps:
|
1505 |
+
prompt.append(tokenizer.no_timestamps)
|
1506 |
+
|
1507 |
+
if prefix:
|
1508 |
+
prefix_tokens = tokenizer.encode(" " + prefix.strip())
|
1509 |
+
if len(prefix_tokens) >= self.max_length // 2:
|
1510 |
+
prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
|
1511 |
+
if not without_timestamps:
|
1512 |
+
prompt.append(tokenizer.timestamp_begin)
|
1513 |
+
prompt.extend(prefix_tokens)
|
1514 |
+
|
1515 |
+
return prompt
|
1516 |
+
|
1517 |
+
def add_word_timestamps(
|
1518 |
+
self,
|
1519 |
+
segments: List[dict],
|
1520 |
+
tokenizer: Tokenizer,
|
1521 |
+
encoder_output: ctranslate2.StorageView,
|
1522 |
+
num_frames: int,
|
1523 |
+
prepend_punctuations: str,
|
1524 |
+
append_punctuations: str,
|
1525 |
+
last_speech_timestamp: float,
|
1526 |
+
) -> float:
|
1527 |
+
if len(segments) == 0:
|
1528 |
+
return
|
1529 |
+
|
1530 |
+
text_tokens = []
|
1531 |
+
text_tokens_per_segment = []
|
1532 |
+
for segment in segments:
|
1533 |
+
segment_tokens = [
|
1534 |
+
[token for token in subsegment["tokens"] if token < tokenizer.eot]
|
1535 |
+
for subsegment in segment
|
1536 |
+
]
|
1537 |
+
text_tokens.append(list(itertools.chain.from_iterable(segment_tokens)))
|
1538 |
+
text_tokens_per_segment.append(segment_tokens)
|
1539 |
+
|
1540 |
+
alignments = self.find_alignment(
|
1541 |
+
tokenizer, text_tokens, encoder_output, num_frames
|
1542 |
+
)
|
1543 |
+
median_max_durations = []
|
1544 |
+
for alignment in alignments:
|
1545 |
+
word_durations = np.array(
|
1546 |
+
[word["end"] - word["start"] for word in alignment]
|
1547 |
+
)
|
1548 |
+
word_durations = word_durations[word_durations.nonzero()]
|
1549 |
+
median_duration = (
|
1550 |
+
np.median(word_durations) if len(word_durations) > 0 else 0.0
|
1551 |
+
)
|
1552 |
+
median_duration = min(0.7, float(median_duration))
|
1553 |
+
max_duration = median_duration * 2
|
1554 |
+
|
1555 |
+
# hack: truncate long words at sentence boundaries.
|
1556 |
+
# a better segmentation algorithm based on VAD should be able to replace this.
|
1557 |
+
if len(word_durations) > 0:
|
1558 |
+
sentence_end_marks = ".。!!??"
|
1559 |
+
# ensure words at sentence boundaries
|
1560 |
+
# are not longer than twice the median word duration.
|
1561 |
+
for i in range(1, len(alignment)):
|
1562 |
+
if alignment[i]["end"] - alignment[i]["start"] > max_duration:
|
1563 |
+
if alignment[i]["word"] in sentence_end_marks:
|
1564 |
+
alignment[i]["end"] = alignment[i]["start"] + max_duration
|
1565 |
+
elif alignment[i - 1]["word"] in sentence_end_marks:
|
1566 |
+
alignment[i]["start"] = alignment[i]["end"] - max_duration
|
1567 |
+
|
1568 |
+
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
1569 |
+
median_max_durations.append((median_duration, max_duration))
|
1570 |
+
|
1571 |
+
for segment_idx, segment in enumerate(segments):
|
1572 |
+
word_index = 0
|
1573 |
+
time_offset = segment[0]["seek"] / self.frames_per_second
|
1574 |
+
median_duration, max_duration = median_max_durations[segment_idx]
|
1575 |
+
for subsegment_idx, subsegment in enumerate(segment):
|
1576 |
+
saved_tokens = 0
|
1577 |
+
words = []
|
1578 |
+
|
1579 |
+
while word_index < len(alignments[segment_idx]) and saved_tokens < len(
|
1580 |
+
text_tokens_per_segment[segment_idx][subsegment_idx]
|
1581 |
+
):
|
1582 |
+
timing = alignments[segment_idx][word_index]
|
1583 |
+
|
1584 |
+
if timing["word"]:
|
1585 |
+
words.append(
|
1586 |
+
dict(
|
1587 |
+
word=timing["word"],
|
1588 |
+
start=round(time_offset + timing["start"], 2),
|
1589 |
+
end=round(time_offset + timing["end"], 2),
|
1590 |
+
probability=timing["probability"],
|
1591 |
+
)
|
1592 |
+
)
|
1593 |
+
|
1594 |
+
saved_tokens += len(timing["tokens"])
|
1595 |
+
word_index += 1
|
1596 |
+
|
1597 |
+
# hack: truncate long words at segment boundaries.
|
1598 |
+
# a better segmentation algorithm based on VAD should be able to replace this.
|
1599 |
+
if len(words) > 0:
|
1600 |
+
# ensure the first and second word after a pause is not longer than
|
1601 |
+
# twice the median word duration.
|
1602 |
+
if words[0][
|
1603 |
+
"end"
|
1604 |
+
] - last_speech_timestamp > median_duration * 4 and (
|
1605 |
+
words[0]["end"] - words[0]["start"] > max_duration
|
1606 |
+
or (
|
1607 |
+
len(words) > 1
|
1608 |
+
and words[1]["end"] - words[0]["start"] > max_duration * 2
|
1609 |
+
)
|
1610 |
+
):
|
1611 |
+
if (
|
1612 |
+
len(words) > 1
|
1613 |
+
and words[1]["end"] - words[1]["start"] > max_duration
|
1614 |
+
):
|
1615 |
+
boundary = max(
|
1616 |
+
words[1]["end"] / 2, words[1]["end"] - max_duration
|
1617 |
+
)
|
1618 |
+
words[0]["end"] = words[1]["start"] = boundary
|
1619 |
+
words[0]["start"] = max(0, words[0]["end"] - max_duration)
|
1620 |
+
|
1621 |
+
# prefer the segment-level start timestamp if the first word is too long.
|
1622 |
+
if (
|
1623 |
+
subsegment["start"] < words[0]["end"]
|
1624 |
+
and subsegment["start"] - 0.5 > words[0]["start"]
|
1625 |
+
):
|
1626 |
+
words[0]["start"] = max(
|
1627 |
+
0,
|
1628 |
+
min(words[0]["end"] - median_duration, subsegment["start"]),
|
1629 |
+
)
|
1630 |
+
else:
|
1631 |
+
subsegment["start"] = words[0]["start"]
|
1632 |
+
|
1633 |
+
# prefer the segment-level end timestamp if the last word is too long.
|
1634 |
+
if (
|
1635 |
+
subsegment["end"] > words[-1]["start"]
|
1636 |
+
and subsegment["end"] + 0.5 < words[-1]["end"]
|
1637 |
+
):
|
1638 |
+
words[-1]["end"] = max(
|
1639 |
+
words[-1]["start"] + median_duration, subsegment["end"]
|
1640 |
+
)
|
1641 |
+
else:
|
1642 |
+
subsegment["end"] = words[-1]["end"]
|
1643 |
+
|
1644 |
+
last_speech_timestamp = subsegment["end"]
|
1645 |
+
segments[segment_idx][subsegment_idx]["words"] = words
|
1646 |
+
return last_speech_timestamp
|
1647 |
+
|
1648 |
+
def find_alignment(
|
1649 |
+
self,
|
1650 |
+
tokenizer: Tokenizer,
|
1651 |
+
text_tokens: List[int],
|
1652 |
+
encoder_output: ctranslate2.StorageView,
|
1653 |
+
num_frames: int,
|
1654 |
+
median_filter_width: int = 7,
|
1655 |
+
) -> List[dict]:
|
1656 |
+
if len(text_tokens) == 0:
|
1657 |
+
return []
|
1658 |
+
|
1659 |
+
results = self.model.align(
|
1660 |
+
encoder_output,
|
1661 |
+
tokenizer.sot_sequence,
|
1662 |
+
text_tokens,
|
1663 |
+
num_frames,
|
1664 |
+
median_filter_width=median_filter_width,
|
1665 |
+
)
|
1666 |
+
return_list = []
|
1667 |
+
for result, text_token in zip(results, text_tokens):
|
1668 |
+
text_token_probs = result.text_token_probs
|
1669 |
+
alignments = result.alignments
|
1670 |
+
text_indices = np.array([pair[0] for pair in alignments])
|
1671 |
+
time_indices = np.array([pair[1] for pair in alignments])
|
1672 |
+
|
1673 |
+
words, word_tokens = tokenizer.split_to_word_tokens(
|
1674 |
+
text_token + [tokenizer.eot]
|
1675 |
+
)
|
1676 |
+
if len(word_tokens) <= 1:
|
1677 |
+
# return on eot only
|
1678 |
+
# >>> np.pad([], (1, 0))
|
1679 |
+
# array([0.])
|
1680 |
+
# This results in crashes when we lookup jump_times with float, like
|
1681 |
+
# IndexError: arrays used as indices must be of integer (or boolean) type
|
1682 |
+
return_list.append([])
|
1683 |
+
continue
|
1684 |
+
word_boundaries = np.pad(
|
1685 |
+
np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)
|
1686 |
+
)
|
1687 |
+
if len(word_boundaries) <= 1:
|
1688 |
+
return_list.append([])
|
1689 |
+
continue
|
1690 |
+
|
1691 |
+
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(
|
1692 |
+
bool
|
1693 |
+
)
|
1694 |
+
jump_times = time_indices[jumps] / self.tokens_per_second
|
1695 |
+
start_times = jump_times[word_boundaries[:-1]]
|
1696 |
+
end_times = jump_times[word_boundaries[1:]]
|
1697 |
+
word_probabilities = [
|
1698 |
+
np.mean(text_token_probs[i:j])
|
1699 |
+
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
|
1700 |
+
]
|
1701 |
+
|
1702 |
+
return_list.append(
|
1703 |
+
[
|
1704 |
+
dict(
|
1705 |
+
word=word,
|
1706 |
+
tokens=tokens,
|
1707 |
+
start=start,
|
1708 |
+
end=end,
|
1709 |
+
probability=probability,
|
1710 |
+
)
|
1711 |
+
for word, tokens, start, end, probability in zip(
|
1712 |
+
words, word_tokens, start_times, end_times, word_probabilities
|
1713 |
+
)
|
1714 |
+
]
|
1715 |
+
)
|
1716 |
+
return return_list
|
1717 |
+
|
1718 |
+
def detect_language(
|
1719 |
+
self,
|
1720 |
+
audio: Optional[np.ndarray] = None,
|
1721 |
+
features: Optional[np.ndarray] = None,
|
1722 |
+
vad_filter: bool = False,
|
1723 |
+
vad_parameters: Union[dict, VadOptions] = None,
|
1724 |
+
language_detection_segments: int = 1,
|
1725 |
+
language_detection_threshold: float = 0.5,
|
1726 |
+
) -> Tuple[str, float, List[Tuple[str, float]]]:
|
1727 |
+
"""
|
1728 |
+
Use Whisper to detect the language of the input audio or features.
|
1729 |
+
|
1730 |
+
Arguments:
|
1731 |
+
audio: Input audio signal, must be a 1D float array sampled at 16khz.
|
1732 |
+
features: Input Mel spectrogram features, must be a float array with
|
1733 |
+
shape (n_mels, n_frames), if `audio` is provided, the features will be ignored.
|
1734 |
+
Either `audio` or `features` must be provided.
|
1735 |
+
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
|
1736 |
+
without speech. This step is using the Silero VAD model.
|
1737 |
+
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
|
1738 |
+
parameters and default values in the class `VadOptions`).
|
1739 |
+
language_detection_threshold: If the maximum probability of the language tokens is
|
1740 |
+
higher than this value, the language is detected.
|
1741 |
+
language_detection_segments: Number of segments to consider for the language detection.
|
1742 |
+
|
1743 |
+
Returns:
|
1744 |
+
language: Detected language.
|
1745 |
+
languege_probability: Probability of the detected language.
|
1746 |
+
all_language_probs: List of tuples with all language names and probabilities.
|
1747 |
+
"""
|
1748 |
+
assert (
|
1749 |
+
audio is not None or features is not None
|
1750 |
+
), "Either `audio` or `features` must be provided."
|
1751 |
+
|
1752 |
+
if audio is not None:
|
1753 |
+
if vad_filter:
|
1754 |
+
speech_chunks = get_speech_timestamps(audio, vad_parameters)
|
1755 |
+
audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
|
1756 |
+
audio = np.concatenate(audio_chunks, axis=0)
|
1757 |
+
|
1758 |
+
audio = audio[
|
1759 |
+
: language_detection_segments * self.feature_extractor.n_samples
|
1760 |
+
]
|
1761 |
+
features = self.feature_extractor(audio)
|
1762 |
+
|
1763 |
+
features = features[
|
1764 |
+
..., : language_detection_segments * self.feature_extractor.nb_max_frames
|
1765 |
+
]
|
1766 |
+
|
1767 |
+
detected_language_info = {}
|
1768 |
+
for i in range(0, features.shape[-1], self.feature_extractor.nb_max_frames):
|
1769 |
+
encoder_output = self.encode(
|
1770 |
+
pad_or_trim(features[..., i : i + self.feature_extractor.nb_max_frames])
|
1771 |
+
)
|
1772 |
+
# results is a list of tuple[str, float] with language names and probabilities.
|
1773 |
+
results = self.model.detect_language(encoder_output)[0]
|
1774 |
+
|
1775 |
+
# Parse language names to strip out markers
|
1776 |
+
all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
|
1777 |
+
# Get top language token and probability
|
1778 |
+
language, language_probability = all_language_probs[0]
|
1779 |
+
if language_probability > language_detection_threshold:
|
1780 |
+
break
|
1781 |
+
detected_language_info.setdefault(language, []).append(language_probability)
|
1782 |
+
else:
|
1783 |
+
# If no language detected for all segments, the majority vote of the highest
|
1784 |
+
# projected languages for all segments is used to determine the language.
|
1785 |
+
language = max(
|
1786 |
+
detected_language_info,
|
1787 |
+
key=lambda lang: len(detected_language_info[lang]),
|
1788 |
+
)
|
1789 |
+
language_probability = max(detected_language_info[language])
|
1790 |
+
|
1791 |
+
return language, language_probability, all_language_probs
|
1792 |
+
|
1793 |
+
|
1794 |
+
def restore_speech_timestamps(
|
1795 |
+
segments: Iterable[Segment],
|
1796 |
+
speech_chunks: List[dict],
|
1797 |
+
sampling_rate: int,
|
1798 |
+
) -> Iterable[Segment]:
|
1799 |
+
ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
|
1800 |
+
|
1801 |
+
for segment in segments:
|
1802 |
+
if segment.words:
|
1803 |
+
words = []
|
1804 |
+
for word in segment.words:
|
1805 |
+
# Ensure the word start and end times are resolved to the same chunk.
|
1806 |
+
middle = (word.start + word.end) / 2
|
1807 |
+
chunk_index = ts_map.get_chunk_index(middle)
|
1808 |
+
word.start = ts_map.get_original_time(word.start, chunk_index)
|
1809 |
+
word.end = ts_map.get_original_time(word.end, chunk_index)
|
1810 |
+
words.append(word)
|
1811 |
+
|
1812 |
+
segment.start = words[0].start
|
1813 |
+
segment.end = words[-1].end
|
1814 |
+
segment.words = words
|
1815 |
+
|
1816 |
+
else:
|
1817 |
+
segment.start = ts_map.get_original_time(segment.start)
|
1818 |
+
segment.end = ts_map.get_original_time(segment.end)
|
1819 |
+
return segments
|
1820 |
+
|
1821 |
+
|
1822 |
+
def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
|
1823 |
+
segment = np.ascontiguousarray(segment)
|
1824 |
+
segment = ctranslate2.StorageView.from_array(segment)
|
1825 |
+
return segment
|
1826 |
+
|
1827 |
+
|
1828 |
+
def get_compression_ratio(text: str) -> float:
|
1829 |
+
text_bytes = text.encode("utf-8")
|
1830 |
+
return len(text_bytes) / len(zlib.compress(text_bytes))
|
1831 |
+
|
1832 |
+
|
1833 |
+
def get_suppressed_tokens(
|
1834 |
+
tokenizer: Tokenizer,
|
1835 |
+
suppress_tokens: Tuple[int],
|
1836 |
+
) -> Optional[List[int]]:
|
1837 |
+
if -1 in suppress_tokens:
|
1838 |
+
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
1839 |
+
suppress_tokens.extend(tokenizer.non_speech_tokens)
|
1840 |
+
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
1841 |
+
suppress_tokens = [] # interpret empty string as an empty list
|
1842 |
+
else:
|
1843 |
+
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
1844 |
+
|
1845 |
+
suppress_tokens.extend(
|
1846 |
+
[
|
1847 |
+
tokenizer.transcribe,
|
1848 |
+
tokenizer.translate,
|
1849 |
+
tokenizer.sot,
|
1850 |
+
tokenizer.sot_prev,
|
1851 |
+
tokenizer.sot_lm,
|
1852 |
+
]
|
1853 |
+
)
|
1854 |
+
|
1855 |
+
return tuple(sorted(set(suppress_tokens)))
|
1856 |
+
|
1857 |
+
|
1858 |
+
def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
|
1859 |
+
# merge prepended punctuations
|
1860 |
+
i = len(alignment) - 2
|
1861 |
+
j = len(alignment) - 1
|
1862 |
+
while i >= 0:
|
1863 |
+
previous = alignment[i]
|
1864 |
+
following = alignment[j]
|
1865 |
+
if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
|
1866 |
+
# prepend it to the following word
|
1867 |
+
following["word"] = previous["word"] + following["word"]
|
1868 |
+
following["tokens"] = previous["tokens"] + following["tokens"]
|
1869 |
+
previous["word"] = ""
|
1870 |
+
previous["tokens"] = []
|
1871 |
+
else:
|
1872 |
+
j = i
|
1873 |
+
i -= 1
|
1874 |
+
|
1875 |
+
# merge appended punctuations
|
1876 |
+
i = 0
|
1877 |
+
j = 1
|
1878 |
+
while j < len(alignment):
|
1879 |
+
previous = alignment[i]
|
1880 |
+
following = alignment[j]
|
1881 |
+
if not previous["word"].endswith(" ") and following["word"] in appended:
|
1882 |
+
# append it to the previous word
|
1883 |
+
previous["word"] = previous["word"] + following["word"]
|
1884 |
+
previous["tokens"] = previous["tokens"] + following["tokens"]
|
1885 |
+
following["word"] = ""
|
1886 |
+
following["tokens"] = []
|
1887 |
+
else:
|
1888 |
+
i = j
|
1889 |
+
j += 1
|
whisper_live/transcriber/transcriber_openvino.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import librosa
|
2 |
+
import os
|
3 |
+
|
4 |
+
import openvino_genai as ov_genai
|
5 |
+
import huggingface_hub as hf_hub
|
6 |
+
|
7 |
+
|
8 |
+
class WhisperOpenVINO(object):
|
9 |
+
def __init__(self, model_id="OpenVINO/whisper-tiny-fp16-ov", device="CPU", language="en", task="transcribe"):
|
10 |
+
model_path = model_id.split('/')[-1]
|
11 |
+
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "openvino_whisper_models")
|
12 |
+
os.makedirs(cache_dir, exist_ok=True)
|
13 |
+
model_path = os.path.join(cache_dir, model_path)
|
14 |
+
if not os.path.exists(model_path):
|
15 |
+
hf_hub.snapshot_download(model_id, local_dir=model_path)
|
16 |
+
self.model = ov_genai.WhisperPipeline(str(model_path), device=device)
|
17 |
+
self.language = language
|
18 |
+
self.task = task
|
19 |
+
|
20 |
+
def transcribe(self, input_audio):
|
21 |
+
outputs = self.model.generate(input_audio, return_timestamps=True, language=self.language, task=self.task)
|
22 |
+
outputs = [seg for seg in outputs.chunks]
|
23 |
+
return outputs
|
whisper_live/transcriber/transcriber_tensorrt.py
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
import math
|
4 |
+
from collections import OrderedDict
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from whisper.tokenizer import get_tokenizer
|
12 |
+
from whisper_live.transcriber.tensorrt_utils import (
|
13 |
+
mel_filters,
|
14 |
+
load_audio_wav_format,
|
15 |
+
pad_or_trim,
|
16 |
+
load_audio
|
17 |
+
)
|
18 |
+
|
19 |
+
import tensorrt_llm
|
20 |
+
import tensorrt_llm.logger as logger
|
21 |
+
from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
|
22 |
+
trt_dtype_to_torch)
|
23 |
+
from tensorrt_llm.bindings import GptJsonConfig, KVCacheType
|
24 |
+
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelConfig, SamplingConfig
|
25 |
+
from tensorrt_llm.runtime.session import Session, TensorInfo
|
26 |
+
if PYTHON_BINDINGS:
|
27 |
+
from tensorrt_llm.runtime import ModelRunnerCpp
|
28 |
+
|
29 |
+
SAMPLE_RATE = 16000
|
30 |
+
N_FFT = 400
|
31 |
+
HOP_LENGTH = 160
|
32 |
+
CHUNK_LENGTH = 30
|
33 |
+
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
34 |
+
|
35 |
+
def read_config(component, engine_dir):
|
36 |
+
config_path = engine_dir / component / 'config.json'
|
37 |
+
with open(config_path, 'r') as f:
|
38 |
+
config = json.load(f)
|
39 |
+
model_config = OrderedDict()
|
40 |
+
model_config.update(config['pretrained_config'])
|
41 |
+
model_config.update(config['build_config'])
|
42 |
+
return model_config
|
43 |
+
|
44 |
+
|
45 |
+
def remove_tensor_padding(input_tensor,
|
46 |
+
input_tensor_lengths=None,
|
47 |
+
pad_value=None):
|
48 |
+
if pad_value:
|
49 |
+
assert input_tensor_lengths is None, "input_tensor_lengths should be None when pad_value is provided"
|
50 |
+
# Text tensor case: batch, seq_len
|
51 |
+
assert torch.all(
|
52 |
+
input_tensor[:, 0] != pad_value
|
53 |
+
), "First token in each sequence should not be pad_value"
|
54 |
+
assert input_tensor_lengths is None
|
55 |
+
|
56 |
+
# Create a mask for all non-pad tokens
|
57 |
+
mask = input_tensor != pad_value
|
58 |
+
|
59 |
+
# Apply the mask to input_tensor to remove pad tokens
|
60 |
+
output_tensor = input_tensor[mask].view(1, -1)
|
61 |
+
|
62 |
+
else:
|
63 |
+
# Audio tensor case: batch, seq_len, feature_len
|
64 |
+
# position_ids case: batch, seq_len
|
65 |
+
assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
|
66 |
+
|
67 |
+
# Initialize a list to collect valid sequences
|
68 |
+
valid_sequences = []
|
69 |
+
|
70 |
+
for i in range(input_tensor.shape[0]):
|
71 |
+
valid_length = input_tensor_lengths[i]
|
72 |
+
valid_sequences.append(input_tensor[i, :valid_length])
|
73 |
+
|
74 |
+
# Concatenate all valid sequences along the batch dimension
|
75 |
+
output_tensor = torch.cat(valid_sequences, dim=0)
|
76 |
+
return output_tensor
|
77 |
+
|
78 |
+
|
79 |
+
class WhisperEncoding:
|
80 |
+
|
81 |
+
def __init__(self, engine_dir):
|
82 |
+
self.session = self.get_session(engine_dir)
|
83 |
+
config = read_config('encoder', engine_dir)
|
84 |
+
self.n_mels = config['n_mels']
|
85 |
+
self.dtype = config['dtype']
|
86 |
+
self.num_languages = config['num_languages']
|
87 |
+
self.encoder_config = config
|
88 |
+
|
89 |
+
def get_session(self, engine_dir):
|
90 |
+
serialize_path = engine_dir / 'encoder' / 'rank0.engine'
|
91 |
+
with open(serialize_path, 'rb') as f:
|
92 |
+
session = Session.from_serialized_engine(f.read())
|
93 |
+
return session
|
94 |
+
|
95 |
+
def get_audio_features(self,
|
96 |
+
mel,
|
97 |
+
mel_input_lengths,
|
98 |
+
encoder_downsampling_factor=2):
|
99 |
+
if isinstance(mel, list):
|
100 |
+
longest_mel = max([f.shape[-1] for f in mel])
|
101 |
+
mel = [
|
102 |
+
torch.nn.functional.pad(f, (0, longest_mel - f.shape[-1]),
|
103 |
+
mode='constant') for f in mel
|
104 |
+
]
|
105 |
+
mel = torch.cat(mel, dim=0).type(
|
106 |
+
str_dtype_to_torch("float16")).contiguous()
|
107 |
+
bsz, seq_len = mel.shape[0], mel.shape[2]
|
108 |
+
position_ids = torch.arange(
|
109 |
+
math.ceil(seq_len / encoder_downsampling_factor),
|
110 |
+
dtype=torch.int32,
|
111 |
+
device=mel.device).expand(bsz, -1).contiguous()
|
112 |
+
if self.encoder_config['plugin_config']['remove_input_padding']:
|
113 |
+
# mel B,D,T -> B,T,D -> BxT, D
|
114 |
+
mel = mel.transpose(1, 2)
|
115 |
+
mel = remove_tensor_padding(mel, mel_input_lengths)
|
116 |
+
position_ids = remove_tensor_padding(
|
117 |
+
position_ids, mel_input_lengths // encoder_downsampling_factor)
|
118 |
+
inputs = OrderedDict()
|
119 |
+
inputs['input_features'] = mel
|
120 |
+
inputs['input_lengths'] = mel_input_lengths
|
121 |
+
inputs['position_ids'] = position_ids
|
122 |
+
|
123 |
+
output_list = [
|
124 |
+
TensorInfo('input_features', str_dtype_to_trt(self.dtype),
|
125 |
+
mel.shape),
|
126 |
+
TensorInfo('input_lengths', str_dtype_to_trt('int32'),
|
127 |
+
mel_input_lengths.shape),
|
128 |
+
TensorInfo('position_ids', str_dtype_to_trt('int32'),
|
129 |
+
inputs['position_ids'].shape)
|
130 |
+
]
|
131 |
+
|
132 |
+
output_info = (self.session).infer_shapes(output_list)
|
133 |
+
|
134 |
+
logger.debug(f'output info {output_info}')
|
135 |
+
outputs = {
|
136 |
+
t.name: torch.empty(tuple(t.shape),
|
137 |
+
dtype=trt_dtype_to_torch(t.dtype),
|
138 |
+
device='cuda')
|
139 |
+
for t in output_info
|
140 |
+
}
|
141 |
+
stream = torch.cuda.current_stream()
|
142 |
+
ok = self.session.run(inputs=inputs,
|
143 |
+
outputs=outputs,
|
144 |
+
stream=stream.cuda_stream)
|
145 |
+
assert ok, 'Engine execution failed'
|
146 |
+
stream.synchronize()
|
147 |
+
encoder_output = outputs['encoder_output']
|
148 |
+
encoder_output_lengths = mel_input_lengths // encoder_downsampling_factor
|
149 |
+
return encoder_output, encoder_output_lengths
|
150 |
+
|
151 |
+
|
152 |
+
class WhisperDecoding:
|
153 |
+
|
154 |
+
def __init__(self, engine_dir, runtime_mapping, debug_mode=False):
|
155 |
+
|
156 |
+
self.decoder_config = read_config('decoder', engine_dir)
|
157 |
+
self.decoder_generation_session = self.get_session(
|
158 |
+
engine_dir, runtime_mapping, debug_mode)
|
159 |
+
|
160 |
+
def get_session(self, engine_dir, runtime_mapping, debug_mode=False):
|
161 |
+
serialize_path = engine_dir / 'decoder' / 'rank0.engine'
|
162 |
+
with open(serialize_path, "rb") as f:
|
163 |
+
decoder_engine_buffer = f.read()
|
164 |
+
|
165 |
+
decoder_model_config = ModelConfig(
|
166 |
+
max_batch_size=self.decoder_config['max_batch_size'],
|
167 |
+
max_beam_width=self.decoder_config['max_beam_width'],
|
168 |
+
num_heads=self.decoder_config['num_attention_heads'],
|
169 |
+
num_kv_heads=self.decoder_config['num_attention_heads'],
|
170 |
+
hidden_size=self.decoder_config['hidden_size'],
|
171 |
+
vocab_size=self.decoder_config['vocab_size'],
|
172 |
+
cross_attention=True,
|
173 |
+
num_layers=self.decoder_config['num_hidden_layers'],
|
174 |
+
gpt_attention_plugin=self.decoder_config['plugin_config']
|
175 |
+
['gpt_attention_plugin'],
|
176 |
+
remove_input_padding=self.decoder_config['plugin_config']
|
177 |
+
['remove_input_padding'],
|
178 |
+
kv_cache_type=KVCacheType.PAGED
|
179 |
+
if self.decoder_config['plugin_config']['paged_kv_cache'] == True
|
180 |
+
else KVCacheType.CONTINUOUS,
|
181 |
+
has_position_embedding=self.
|
182 |
+
decoder_config['has_position_embedding'],
|
183 |
+
dtype=self.decoder_config['dtype'],
|
184 |
+
has_token_type_embedding=False,
|
185 |
+
)
|
186 |
+
decoder_generation_session = tensorrt_llm.runtime.GenerationSession(
|
187 |
+
decoder_model_config,
|
188 |
+
decoder_engine_buffer,
|
189 |
+
runtime_mapping,
|
190 |
+
debug_mode=debug_mode)
|
191 |
+
|
192 |
+
return decoder_generation_session
|
193 |
+
|
194 |
+
def generate(self,
|
195 |
+
decoder_input_ids,
|
196 |
+
encoder_outputs,
|
197 |
+
encoder_max_input_length,
|
198 |
+
encoder_input_lengths,
|
199 |
+
eot_id,
|
200 |
+
max_new_tokens=40,
|
201 |
+
num_beams=1):
|
202 |
+
batch_size = decoder_input_ids.shape[0]
|
203 |
+
decoder_input_lengths = torch.tensor([
|
204 |
+
decoder_input_ids.shape[-1]
|
205 |
+
for _ in range(decoder_input_ids.shape[0])
|
206 |
+
],
|
207 |
+
dtype=torch.int32,
|
208 |
+
device='cuda')
|
209 |
+
decoder_max_input_length = torch.max(decoder_input_lengths).item()
|
210 |
+
|
211 |
+
cross_attention_mask = torch.ones([
|
212 |
+
batch_size, decoder_max_input_length + max_new_tokens,
|
213 |
+
encoder_max_input_length
|
214 |
+
]).int().cuda()
|
215 |
+
# generation config
|
216 |
+
sampling_config = SamplingConfig(end_id=eot_id,
|
217 |
+
pad_id=eot_id,
|
218 |
+
num_beams=num_beams)
|
219 |
+
self.decoder_generation_session.setup(
|
220 |
+
decoder_input_lengths.size(0),
|
221 |
+
decoder_max_input_length,
|
222 |
+
max_new_tokens,
|
223 |
+
beam_width=num_beams,
|
224 |
+
encoder_max_input_length=encoder_max_input_length)
|
225 |
+
|
226 |
+
torch.cuda.synchronize()
|
227 |
+
|
228 |
+
decoder_input_ids = decoder_input_ids.type(torch.int32).cuda()
|
229 |
+
if self.decoder_config['plugin_config']['remove_input_padding']:
|
230 |
+
# 50256 is the index of <pad> for all whisper models' decoder
|
231 |
+
WHISPER_PAD_TOKEN_ID = 50256
|
232 |
+
decoder_input_ids = remove_tensor_padding(
|
233 |
+
decoder_input_ids, pad_value=WHISPER_PAD_TOKEN_ID)
|
234 |
+
if encoder_outputs.dim() == 3:
|
235 |
+
encoder_output_lens = torch.full((encoder_outputs.shape[0], ),
|
236 |
+
encoder_outputs.shape[1],
|
237 |
+
dtype=torch.int32,
|
238 |
+
device='cuda')
|
239 |
+
|
240 |
+
encoder_outputs = remove_tensor_padding(encoder_outputs,
|
241 |
+
encoder_output_lens)
|
242 |
+
output_ids = self.decoder_generation_session.decode(
|
243 |
+
decoder_input_ids,
|
244 |
+
decoder_input_lengths,
|
245 |
+
sampling_config,
|
246 |
+
encoder_output=encoder_outputs,
|
247 |
+
encoder_input_lengths=encoder_input_lengths,
|
248 |
+
cross_attention_mask=cross_attention_mask,
|
249 |
+
)
|
250 |
+
torch.cuda.synchronize()
|
251 |
+
|
252 |
+
# get the list of int from output_ids tensor
|
253 |
+
output_ids = output_ids.cpu().numpy().tolist()
|
254 |
+
return output_ids
|
255 |
+
|
256 |
+
|
257 |
+
class WhisperTRTLLM(object):
|
258 |
+
|
259 |
+
def __init__(self,
|
260 |
+
engine_dir,
|
261 |
+
assets_dir=None,
|
262 |
+
device=None,
|
263 |
+
is_multilingual=False,
|
264 |
+
language="en",
|
265 |
+
task="transcribe",
|
266 |
+
use_py_session=False,
|
267 |
+
num_beams=1,
|
268 |
+
debug_mode=False,
|
269 |
+
max_output_len=96):
|
270 |
+
world_size = 1
|
271 |
+
runtime_rank = tensorrt_llm.mpi_rank()
|
272 |
+
runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
|
273 |
+
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
|
274 |
+
engine_dir = Path(engine_dir)
|
275 |
+
encoder_config = read_config('encoder', engine_dir)
|
276 |
+
decoder_config = read_config('decoder', engine_dir)
|
277 |
+
self.n_mels = encoder_config['n_mels']
|
278 |
+
self.num_languages = encoder_config['num_languages']
|
279 |
+
is_multilingual = (decoder_config['vocab_size'] >= 51865)
|
280 |
+
|
281 |
+
self.device = device
|
282 |
+
self.tokenizer = get_tokenizer(
|
283 |
+
is_multilingual,
|
284 |
+
num_languages=self.num_languages,
|
285 |
+
language=language,
|
286 |
+
task=task,
|
287 |
+
)
|
288 |
+
|
289 |
+
if use_py_session:
|
290 |
+
self.encoder = WhisperEncoding(engine_dir)
|
291 |
+
self.decoder = WhisperDecoding(engine_dir,
|
292 |
+
runtime_mapping,
|
293 |
+
debug_mode=False)
|
294 |
+
else:
|
295 |
+
json_config = GptJsonConfig.parse_file(engine_dir / 'decoder' /
|
296 |
+
'config.json')
|
297 |
+
assert json_config.model_config.supports_inflight_batching
|
298 |
+
runner_kwargs = dict(engine_dir=engine_dir,
|
299 |
+
is_enc_dec=True,
|
300 |
+
max_batch_size=1,
|
301 |
+
max_input_len=3000,
|
302 |
+
max_output_len=max_output_len,
|
303 |
+
max_beam_width=num_beams,
|
304 |
+
debug_mode=debug_mode,
|
305 |
+
kv_cache_free_gpu_memory_fraction=0.9,
|
306 |
+
cross_kv_cache_fraction=0.5)
|
307 |
+
self.model_runner_cpp = ModelRunnerCpp.from_dir(**runner_kwargs)
|
308 |
+
self.filters = mel_filters(self.device, self.n_mels, assets_dir)
|
309 |
+
self.use_py_session = use_py_session
|
310 |
+
|
311 |
+
def log_mel_spectrogram(
|
312 |
+
self,
|
313 |
+
audio: Union[str, np.ndarray, torch.Tensor],
|
314 |
+
padding: int = 0,
|
315 |
+
return_duration=True
|
316 |
+
):
|
317 |
+
"""
|
318 |
+
Compute the log-Mel spectrogram of
|
319 |
+
|
320 |
+
Parameters
|
321 |
+
----------
|
322 |
+
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
|
323 |
+
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
324 |
+
|
325 |
+
n_mels: int
|
326 |
+
The number of Mel-frequency filters, only 80 and 128 are supported
|
327 |
+
|
328 |
+
padding: int
|
329 |
+
Number of zero samples to pad to the right
|
330 |
+
|
331 |
+
device: Optional[Union[str, torch.device]]
|
332 |
+
If given, the audio tensor is moved to this device before STFT
|
333 |
+
|
334 |
+
Returns
|
335 |
+
-------
|
336 |
+
torch.Tensor, shape = (80 or 128, n_frames)
|
337 |
+
A Tensor that contains the Mel spectrogram
|
338 |
+
"""
|
339 |
+
if not torch.is_tensor(audio):
|
340 |
+
if isinstance(audio, str):
|
341 |
+
if audio.endswith('.wav'):
|
342 |
+
audio, _ = load_audio_wav_format(audio)
|
343 |
+
else:
|
344 |
+
audio = load_audio(audio)
|
345 |
+
assert isinstance(audio, np.ndarray), f"Unsupported audio type: {type(audio)}"
|
346 |
+
duration = audio.shape[-1] / SAMPLE_RATE
|
347 |
+
audio = pad_or_trim(audio, N_SAMPLES)
|
348 |
+
audio = audio.astype(np.float32)
|
349 |
+
audio = torch.from_numpy(audio)
|
350 |
+
|
351 |
+
if self.device is not None:
|
352 |
+
audio = audio.to(self.device)
|
353 |
+
if padding > 0:
|
354 |
+
audio = F.pad(audio, (0, padding))
|
355 |
+
window = torch.hann_window(N_FFT).to(audio.device)
|
356 |
+
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
357 |
+
magnitudes = stft[..., :-1].abs()**2
|
358 |
+
|
359 |
+
mel_spec = self.filters @ magnitudes
|
360 |
+
|
361 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
362 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
363 |
+
log_spec = (log_spec + 4.0) / 4.0
|
364 |
+
if return_duration:
|
365 |
+
return log_spec, duration
|
366 |
+
else:
|
367 |
+
return log_spec
|
368 |
+
|
369 |
+
def process_batch(
|
370 |
+
self,
|
371 |
+
mel,
|
372 |
+
mel_input_lengths,
|
373 |
+
text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
|
374 |
+
num_beams=1,
|
375 |
+
max_new_tokens=96):
|
376 |
+
prompt_id = self.tokenizer.encode(
|
377 |
+
text_prefix, allowed_special=set(self.tokenizer.special_tokens.keys()))
|
378 |
+
|
379 |
+
prompt_id = torch.tensor(prompt_id)
|
380 |
+
batch_size = mel.shape[0]
|
381 |
+
decoder_input_ids = prompt_id.repeat(batch_size, 1)
|
382 |
+
if self.use_py_session:
|
383 |
+
encoder_output, encoder_output_lengths = self.encoder.get_audio_features(mel, mel_input_lengths)
|
384 |
+
encoder_max_input_length = torch.max(encoder_output_lengths).item()
|
385 |
+
output_ids = self.decoder.generate(decoder_input_ids,
|
386 |
+
encoder_output,
|
387 |
+
encoder_max_input_length,
|
388 |
+
encoder_output_lengths,
|
389 |
+
self.tokenizer.eot,
|
390 |
+
max_new_tokens=max_new_tokens,
|
391 |
+
num_beams=num_beams)
|
392 |
+
else:
|
393 |
+
with torch.no_grad():
|
394 |
+
if isinstance(mel, list):
|
395 |
+
mel = [
|
396 |
+
m.transpose(1, 2).type(
|
397 |
+
str_dtype_to_torch("float16")).squeeze(0)
|
398 |
+
for m in mel
|
399 |
+
]
|
400 |
+
else:
|
401 |
+
mel = mel.transpose(1, 2)
|
402 |
+
outputs = self.model_runner_cpp.generate(
|
403 |
+
batch_input_ids=decoder_input_ids,
|
404 |
+
encoder_input_features=mel,
|
405 |
+
encoder_output_lengths=mel_input_lengths // 2,
|
406 |
+
max_new_tokens=max_new_tokens,
|
407 |
+
end_id=self.tokenizer.eot,
|
408 |
+
pad_id=self.tokenizer.eot,
|
409 |
+
num_beams=num_beams,
|
410 |
+
output_sequence_lengths=True,
|
411 |
+
return_dict=True)
|
412 |
+
torch.cuda.synchronize()
|
413 |
+
output_ids = outputs['output_ids'].cpu().numpy().tolist()
|
414 |
+
texts = []
|
415 |
+
for i in range(len(output_ids)):
|
416 |
+
text = self.tokenizer.decode(output_ids[i][0]).strip()
|
417 |
+
texts.append(text)
|
418 |
+
return texts
|
419 |
+
|
420 |
+
def transcribe(
|
421 |
+
self,
|
422 |
+
mel,
|
423 |
+
text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
|
424 |
+
dtype='float16',
|
425 |
+
batch_size=1,
|
426 |
+
num_beams=1,
|
427 |
+
padding_strategy="max",
|
428 |
+
max_new_tokens=96,
|
429 |
+
):
|
430 |
+
mel = mel.type(str_dtype_to_torch(dtype))
|
431 |
+
mel = mel.unsqueeze(0)
|
432 |
+
# repeat the mel spectrogram to match the batch size
|
433 |
+
mel = mel.repeat(batch_size, 1, 1)
|
434 |
+
if padding_strategy == "longest":
|
435 |
+
pass
|
436 |
+
else:
|
437 |
+
mel = torch.nn.functional.pad(mel, (0, 3000 - mel.shape[2]))
|
438 |
+
features_input_lengths = torch.full((mel.shape[0], ),
|
439 |
+
mel.shape[2],
|
440 |
+
dtype=torch.int32,
|
441 |
+
device=mel.device)
|
442 |
+
|
443 |
+
predictions = self.process_batch(
|
444 |
+
mel,
|
445 |
+
features_input_lengths,
|
446 |
+
text_prefix,
|
447 |
+
num_beams,
|
448 |
+
max_new_tokens=max_new_tokens
|
449 |
+
)
|
450 |
+
prediction = predictions[0]
|
451 |
+
|
452 |
+
# remove all special tokens in the prediction
|
453 |
+
prediction = re.sub(r'<\|.*?\|>', '', prediction)
|
454 |
+
return prediction.strip()
|
455 |
+
|
456 |
+
|
457 |
+
def decode_wav_file(
|
458 |
+
model,
|
459 |
+
mel,
|
460 |
+
text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
|
461 |
+
dtype='float16',
|
462 |
+
batch_size=1,
|
463 |
+
num_beams=1,
|
464 |
+
normalizer=None,
|
465 |
+
mel_filters_dir=None):
|
466 |
+
|
467 |
+
mel = mel.type(str_dtype_to_torch(dtype))
|
468 |
+
mel = mel.unsqueeze(0)
|
469 |
+
# repeat the mel spectrogram to match the batch size
|
470 |
+
mel = mel.repeat(batch_size, 1, 1)
|
471 |
+
predictions = model.process_batch(mel, text_prefix, num_beams)
|
472 |
+
prediction = predictions[0]
|
473 |
+
|
474 |
+
# remove all special tokens in the prediction
|
475 |
+
prediction = re.sub(r'<\|.*?\|>', '', prediction)
|
476 |
+
if normalizer:
|
477 |
+
prediction = normalizer(prediction)
|
478 |
+
|
479 |
+
return prediction.strip()
|
whisper_live/utils.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import textwrap
|
3 |
+
import scipy
|
4 |
+
import numpy as np
|
5 |
+
import av
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
|
9 |
+
def clear_screen():
|
10 |
+
"""Clears the console screen."""
|
11 |
+
os.system("cls" if os.name == "nt" else "clear")
|
12 |
+
|
13 |
+
|
14 |
+
def print_transcript(text):
|
15 |
+
"""Prints formatted transcript text."""
|
16 |
+
wrapper = textwrap.TextWrapper(width=60)
|
17 |
+
for line in wrapper.wrap(text="".join(text)):
|
18 |
+
print(line)
|
19 |
+
|
20 |
+
|
21 |
+
def format_time(s):
|
22 |
+
"""Convert seconds (float) to SRT time format."""
|
23 |
+
hours = int(s // 3600)
|
24 |
+
minutes = int((s % 3600) // 60)
|
25 |
+
seconds = int(s % 60)
|
26 |
+
milliseconds = int((s - int(s)) * 1000)
|
27 |
+
return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"
|
28 |
+
|
29 |
+
|
30 |
+
def create_srt_file(segments, resampled_file):
|
31 |
+
with open(resampled_file, 'w', encoding='utf-8') as srt_file:
|
32 |
+
segment_number = 1
|
33 |
+
for segment in segments:
|
34 |
+
start_time = format_time(float(segment['start']))
|
35 |
+
end_time = format_time(float(segment['end']))
|
36 |
+
text = segment['text']
|
37 |
+
|
38 |
+
srt_file.write(f"{segment_number}\n")
|
39 |
+
srt_file.write(f"{start_time} --> {end_time}\n")
|
40 |
+
srt_file.write(f"{text}\n\n")
|
41 |
+
|
42 |
+
segment_number += 1
|
43 |
+
|
44 |
+
|
45 |
+
def resample(file: str, sr: int = 16000):
|
46 |
+
"""
|
47 |
+
Resample the audio file to 16kHz.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
file (str): The audio file to open
|
51 |
+
sr (int): The sample rate to resample the audio if necessary
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
resampled_file (str): The resampled audio file
|
55 |
+
"""
|
56 |
+
container = av.open(file)
|
57 |
+
stream = next(s for s in container.streams if s.type == 'audio')
|
58 |
+
|
59 |
+
resampler = av.AudioResampler(
|
60 |
+
format='s16',
|
61 |
+
layout='mono',
|
62 |
+
rate=sr,
|
63 |
+
)
|
64 |
+
|
65 |
+
resampled_file = Path(file).stem + "_resampled.wav"
|
66 |
+
output_container = av.open(resampled_file, mode='w')
|
67 |
+
output_stream = output_container.add_stream('pcm_s16le', rate=sr)
|
68 |
+
output_stream.layout = 'mono'
|
69 |
+
|
70 |
+
for frame in container.decode(audio=0):
|
71 |
+
frame.pts = None
|
72 |
+
resampled_frames = resampler.resample(frame)
|
73 |
+
if resampled_frames is not None:
|
74 |
+
for resampled_frame in resampled_frames:
|
75 |
+
for packet in output_stream.encode(resampled_frame):
|
76 |
+
output_container.mux(packet)
|
77 |
+
|
78 |
+
for packet in output_stream.encode(None):
|
79 |
+
output_container.mux(packet)
|
80 |
+
|
81 |
+
output_container.close()
|
82 |
+
return resampled_file
|
whisper_live/vad.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import onnxruntime
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
|
9 |
+
class VoiceActivityDetection():
|
10 |
+
|
11 |
+
def __init__(self, force_onnx_cpu=True):
|
12 |
+
path = self.download()
|
13 |
+
|
14 |
+
opts = onnxruntime.SessionOptions()
|
15 |
+
opts.log_severity_level = 3
|
16 |
+
|
17 |
+
opts.inter_op_num_threads = 1
|
18 |
+
opts.intra_op_num_threads = 1
|
19 |
+
|
20 |
+
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
|
21 |
+
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
|
22 |
+
else:
|
23 |
+
self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts)
|
24 |
+
|
25 |
+
self.reset_states()
|
26 |
+
if '16k' in path:
|
27 |
+
warnings.warn('This model support only 16000 sampling rate!')
|
28 |
+
self.sample_rates = [16000]
|
29 |
+
else:
|
30 |
+
self.sample_rates = [8000, 16000]
|
31 |
+
|
32 |
+
def _validate_input(self, x, sr: int):
|
33 |
+
if x.dim() == 1:
|
34 |
+
x = x.unsqueeze(0)
|
35 |
+
if x.dim() > 2:
|
36 |
+
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
|
37 |
+
|
38 |
+
if sr != 16000 and (sr % 16000 == 0):
|
39 |
+
step = sr // 16000
|
40 |
+
x = x[:,::step]
|
41 |
+
sr = 16000
|
42 |
+
|
43 |
+
if sr not in self.sample_rates:
|
44 |
+
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
|
45 |
+
if sr / x.shape[1] > 31.25:
|
46 |
+
raise ValueError("Input audio chunk is too short")
|
47 |
+
|
48 |
+
return x, sr
|
49 |
+
|
50 |
+
def reset_states(self, batch_size=1):
|
51 |
+
self._state = torch.zeros((2, batch_size, 128)).float()
|
52 |
+
self._context = torch.zeros(0)
|
53 |
+
self._last_sr = 0
|
54 |
+
self._last_batch_size = 0
|
55 |
+
|
56 |
+
def __call__(self, x, sr: int):
|
57 |
+
|
58 |
+
x, sr = self._validate_input(x, sr)
|
59 |
+
num_samples = 512 if sr == 16000 else 256
|
60 |
+
|
61 |
+
if x.shape[-1] != num_samples:
|
62 |
+
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
|
63 |
+
|
64 |
+
batch_size = x.shape[0]
|
65 |
+
context_size = 64 if sr == 16000 else 32
|
66 |
+
|
67 |
+
if not self._last_batch_size:
|
68 |
+
self.reset_states(batch_size)
|
69 |
+
if (self._last_sr) and (self._last_sr != sr):
|
70 |
+
self.reset_states(batch_size)
|
71 |
+
if (self._last_batch_size) and (self._last_batch_size != batch_size):
|
72 |
+
self.reset_states(batch_size)
|
73 |
+
|
74 |
+
if not len(self._context):
|
75 |
+
self._context = torch.zeros(batch_size, context_size)
|
76 |
+
|
77 |
+
x = torch.cat([self._context, x], dim=1)
|
78 |
+
if sr in [8000, 16000]:
|
79 |
+
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
|
80 |
+
ort_outs = self.session.run(None, ort_inputs)
|
81 |
+
out, state = ort_outs
|
82 |
+
self._state = torch.from_numpy(state)
|
83 |
+
else:
|
84 |
+
raise ValueError()
|
85 |
+
|
86 |
+
self._context = x[..., -context_size:]
|
87 |
+
self._last_sr = sr
|
88 |
+
self._last_batch_size = batch_size
|
89 |
+
|
90 |
+
out = torch.from_numpy(out)
|
91 |
+
return out
|
92 |
+
|
93 |
+
def audio_forward(self, x, sr: int):
|
94 |
+
outs = []
|
95 |
+
x, sr = self._validate_input(x, sr)
|
96 |
+
self.reset_states()
|
97 |
+
num_samples = 512 if sr == 16000 else 256
|
98 |
+
|
99 |
+
if x.shape[1] % num_samples:
|
100 |
+
pad_num = num_samples - (x.shape[1] % num_samples)
|
101 |
+
x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
|
102 |
+
|
103 |
+
for i in range(0, x.shape[1], num_samples):
|
104 |
+
wavs_batch = x[:, i:i+num_samples]
|
105 |
+
out_chunk = self.__call__(wavs_batch, sr)
|
106 |
+
outs.append(out_chunk)
|
107 |
+
|
108 |
+
stacked = torch.cat(outs, dim=1)
|
109 |
+
return stacked.cpu()
|
110 |
+
|
111 |
+
@staticmethod
|
112 |
+
def download(model_url="https://github.com/snakers4/silero-vad/raw/v5.0/files/silero_vad.onnx"):
|
113 |
+
target_dir = os.path.expanduser("~/.cache/whisper-live/")
|
114 |
+
|
115 |
+
# Ensure the target directory exists
|
116 |
+
os.makedirs(target_dir, exist_ok=True)
|
117 |
+
|
118 |
+
# Define the target file path
|
119 |
+
model_filename = os.path.join(target_dir, "silero_vad.onnx")
|
120 |
+
|
121 |
+
# Check if the model file already exists
|
122 |
+
if not os.path.exists(model_filename):
|
123 |
+
# If it doesn't exist, download the model using wget
|
124 |
+
try:
|
125 |
+
subprocess.run(["wget", "-O", model_filename, model_url], check=True)
|
126 |
+
except subprocess.CalledProcessError:
|
127 |
+
print("Failed to download the model using wget.")
|
128 |
+
return model_filename
|
129 |
+
|
130 |
+
|
131 |
+
class VoiceActivityDetector:
|
132 |
+
def __init__(self, threshold=0.5, frame_rate=16000):
|
133 |
+
"""
|
134 |
+
Initializes the VoiceActivityDetector with a voice activity detection model and a threshold.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
threshold (float, optional): The probability threshold for detecting voice activity. Defaults to 0.5.
|
138 |
+
"""
|
139 |
+
self.model = VoiceActivityDetection()
|
140 |
+
self.threshold = threshold
|
141 |
+
self.frame_rate = frame_rate
|
142 |
+
|
143 |
+
def __call__(self, audio_frame):
|
144 |
+
"""
|
145 |
+
Determines if the given audio frame contains speech by comparing the detected speech probability against
|
146 |
+
the threshold.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
audio_frame (np.ndarray): The audio frame to be analyzed for voice activity. It is expected to be a
|
150 |
+
NumPy array of audio samples.
|
151 |
+
|
152 |
+
Returns:
|
153 |
+
bool: True if the speech probability exceeds the threshold, indicating the presence of voice activity;
|
154 |
+
False otherwise.
|
155 |
+
"""
|
156 |
+
speech_probs = self.model.audio_forward(torch.from_numpy(audio_frame.copy()), self.frame_rate)[0]
|
157 |
+
return torch.any(speech_probs > self.threshold).item()
|