Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +15 -0
- .gitignore +174 -0
- LICENSE +201 -0
- README.md +342 -6
- cli/SparkTTS.py +236 -0
- cli/inference.py +116 -0
- example/infer.sh +47 -0
- example/prompt_audio.wav +3 -0
- requirements.txt +11 -0
- runtime/triton_trtllm/Dockerfile.server +5 -0
- runtime/triton_trtllm/README.md +94 -0
- runtime/triton_trtllm/client_grpc.py +831 -0
- runtime/triton_trtllm/client_http.py +165 -0
- runtime/triton_trtllm/docker-compose.yml +20 -0
- runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py +137 -0
- runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt +58 -0
- runtime/triton_trtllm/model_repo/spark_tts/1/model.py +404 -0
- runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt +86 -0
- runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep +0 -0
- runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt +857 -0
- runtime/triton_trtllm/model_repo/vocoder/1/model.py +106 -0
- runtime/triton_trtllm/model_repo/vocoder/config.pbtxt +53 -0
- runtime/triton_trtllm/run.sh +109 -0
- runtime/triton_trtllm/scripts/convert_checkpoint.py +335 -0
- runtime/triton_trtllm/scripts/fill_template.py +70 -0
- sparktts/models/audio_tokenizer.py +163 -0
- sparktts/models/bicodec.py +247 -0
- sparktts/modules/blocks/layers.py +73 -0
- sparktts/modules/blocks/samper.py +115 -0
- sparktts/modules/blocks/vocos.py +373 -0
- sparktts/modules/encoder_decoder/feat_decoder.py +115 -0
- sparktts/modules/encoder_decoder/feat_encoder.py +105 -0
- sparktts/modules/encoder_decoder/wave_generator.py +88 -0
- sparktts/modules/fsq/finite_scalar_quantization.py +251 -0
- sparktts/modules/fsq/residual_fsq.py +355 -0
- sparktts/modules/speaker/ecapa_tdnn.py +267 -0
- sparktts/modules/speaker/perceiver_encoder.py +360 -0
- sparktts/modules/speaker/pooling_layers.py +298 -0
- sparktts/modules/speaker/speaker_encoder.py +136 -0
- sparktts/modules/vq/factorized_vector_quantize.py +187 -0
- sparktts/utils/__init__.py +0 -0
- sparktts/utils/audio.py +271 -0
- sparktts/utils/file.py +221 -0
- sparktts/utils/parse_options.sh +97 -0
- sparktts/utils/token_parser.py +187 -0
- src/demos/trump/trump_en.wav +3 -0
- src/demos/zhongli/zhongli_en.wav +3 -0
- src/demos/余承东/yuchengdong_zh.wav +3 -0
- src/demos/刘德华/dehua_zh.wav +3 -0
- src/demos/哪吒/nezha_zh.wav +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
example/prompt_audio.wav filter=lfs diff=lfs merge=lfs -text
|
37 |
+
src/demos/trump/trump_en.wav filter=lfs diff=lfs merge=lfs -text
|
38 |
+
src/demos/zhongli/zhongli_en.wav filter=lfs diff=lfs merge=lfs -text
|
39 |
+
src/demos/余承东/yuchengdong_zh.wav filter=lfs diff=lfs merge=lfs -text
|
40 |
+
src/demos/刘德华/dehua_zh.wav filter=lfs diff=lfs merge=lfs -text
|
41 |
+
src/demos/哪吒/nezha_zh.wav filter=lfs diff=lfs merge=lfs -text
|
42 |
+
src/demos/徐志胜/zhisheng_zh.wav filter=lfs diff=lfs merge=lfs -text
|
43 |
+
src/demos/李靖/lijing_zh.wav filter=lfs diff=lfs merge=lfs -text
|
44 |
+
src/demos/杨澜/yanglan_zh.wav filter=lfs diff=lfs merge=lfs -text
|
45 |
+
src/demos/马云/mayun_zh.wav filter=lfs diff=lfs merge=lfs -text
|
46 |
+
src/demos/鲁豫/luyu_zh.wav filter=lfs diff=lfs merge=lfs -text
|
47 |
+
src/figures/infer_control.png filter=lfs diff=lfs merge=lfs -text
|
48 |
+
src/figures/infer_voice_cloning.png filter=lfs diff=lfs merge=lfs -text
|
49 |
+
src/logo/mobvoi.png filter=lfs diff=lfs merge=lfs -text
|
50 |
+
src/logo/SparkTTS.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
pretrained_models/
|
6 |
+
results/
|
7 |
+
demo/
|
8 |
+
# C extensions
|
9 |
+
*.so
|
10 |
+
.gradio/
|
11 |
+
# Distribution / packaging
|
12 |
+
.Python
|
13 |
+
build/
|
14 |
+
develop-eggs/
|
15 |
+
dist/
|
16 |
+
downloads/
|
17 |
+
eggs/
|
18 |
+
.eggs/
|
19 |
+
lib/
|
20 |
+
lib64/
|
21 |
+
parts/
|
22 |
+
sdist/
|
23 |
+
var/
|
24 |
+
wheels/
|
25 |
+
share/python-wheels/
|
26 |
+
*.egg-info/
|
27 |
+
.installed.cfg
|
28 |
+
*.egg
|
29 |
+
MANIFEST
|
30 |
+
webui_test.py
|
31 |
+
|
32 |
+
# PyInstaller
|
33 |
+
# Usually these files are written by a python script from a template
|
34 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
35 |
+
*.manifest
|
36 |
+
*.spec
|
37 |
+
|
38 |
+
# Installer logs
|
39 |
+
pip-log.txt
|
40 |
+
pip-delete-this-directory.txt
|
41 |
+
|
42 |
+
# Unit test / coverage reports
|
43 |
+
htmlcov/
|
44 |
+
.tox/
|
45 |
+
.nox/
|
46 |
+
.coverage
|
47 |
+
.coverage.*
|
48 |
+
.cache
|
49 |
+
nosetests.xml
|
50 |
+
coverage.xml
|
51 |
+
*.cover
|
52 |
+
*.py,cover
|
53 |
+
.hypothesis/
|
54 |
+
.pytest_cache/
|
55 |
+
cover/
|
56 |
+
|
57 |
+
# Translations
|
58 |
+
*.mo
|
59 |
+
*.pot
|
60 |
+
|
61 |
+
# Django stuff:
|
62 |
+
*.log
|
63 |
+
local_settings.py
|
64 |
+
db.sqlite3
|
65 |
+
db.sqlite3-journal
|
66 |
+
|
67 |
+
# Flask stuff:
|
68 |
+
instance/
|
69 |
+
.webassets-cache
|
70 |
+
|
71 |
+
# Scrapy stuff:
|
72 |
+
.scrapy
|
73 |
+
|
74 |
+
# Sphinx documentation
|
75 |
+
docs/_build/
|
76 |
+
|
77 |
+
# PyBuilder
|
78 |
+
.pybuilder/
|
79 |
+
target/
|
80 |
+
|
81 |
+
# Jupyter Notebook
|
82 |
+
.ipynb_checkpoints
|
83 |
+
|
84 |
+
# IPython
|
85 |
+
profile_default/
|
86 |
+
ipython_config.py
|
87 |
+
|
88 |
+
# pyenv
|
89 |
+
# For a library or package, you might want to ignore these files since the code is
|
90 |
+
# intended to run in multiple environments; otherwise, check them in:
|
91 |
+
# .python-version
|
92 |
+
|
93 |
+
# pipenv
|
94 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
95 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
96 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
97 |
+
# install all needed dependencies.
|
98 |
+
#Pipfile.lock
|
99 |
+
|
100 |
+
# UV
|
101 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
102 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
103 |
+
# commonly ignored for libraries.
|
104 |
+
#uv.lock
|
105 |
+
|
106 |
+
# poetry
|
107 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
108 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
109 |
+
# commonly ignored for libraries.
|
110 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
111 |
+
#poetry.lock
|
112 |
+
|
113 |
+
# pdm
|
114 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
115 |
+
#pdm.lock
|
116 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
117 |
+
# in version control.
|
118 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
119 |
+
.pdm.toml
|
120 |
+
.pdm-python
|
121 |
+
.pdm-build/
|
122 |
+
|
123 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
124 |
+
__pypackages__/
|
125 |
+
|
126 |
+
# Celery stuff
|
127 |
+
celerybeat-schedule
|
128 |
+
celerybeat.pid
|
129 |
+
|
130 |
+
# SageMath parsed files
|
131 |
+
*.sage.py
|
132 |
+
|
133 |
+
# Environments
|
134 |
+
.env
|
135 |
+
.venv
|
136 |
+
env/
|
137 |
+
venv/
|
138 |
+
ENV/
|
139 |
+
env.bak/
|
140 |
+
venv.bak/
|
141 |
+
|
142 |
+
# Spyder project settings
|
143 |
+
.spyderproject
|
144 |
+
.spyproject
|
145 |
+
|
146 |
+
# Rope project settings
|
147 |
+
.ropeproject
|
148 |
+
|
149 |
+
# mkdocs documentation
|
150 |
+
/site
|
151 |
+
|
152 |
+
# mypy
|
153 |
+
.mypy_cache/
|
154 |
+
.dmypy.json
|
155 |
+
dmypy.json
|
156 |
+
|
157 |
+
# Pyre type checker
|
158 |
+
.pyre/
|
159 |
+
|
160 |
+
# pytype static type analyzer
|
161 |
+
.pytype/
|
162 |
+
|
163 |
+
# Cython debug symbols
|
164 |
+
cython_debug/
|
165 |
+
|
166 |
+
# PyCharm
|
167 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
168 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
169 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
170 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
171 |
+
#.idea/
|
172 |
+
|
173 |
+
# PyPI configuration file
|
174 |
+
.pypirc
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,12 +1,348 @@
|
|
1 |
---
|
2 |
title: SPKTTS
|
3 |
-
|
4 |
-
colorFrom: yellow
|
5 |
-
colorTo: red
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.25.2
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: SPKTTS
|
3 |
+
app_file: webui.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 5.25.2
|
|
|
|
|
6 |
---
|
7 |
+
<div align="center">
|
8 |
+
<h1>
|
9 |
+
Spark-TTS
|
10 |
+
</h1>
|
11 |
+
<p>
|
12 |
+
Official PyTorch code for inference of <br>
|
13 |
+
<b><em>Spark-TTS: An Efficient LLM-Based Text-to-Speech Model with Single-Stream Decoupled Speech Tokens</em></b>
|
14 |
+
</p>
|
15 |
+
<p>
|
16 |
+
<img src="src/logo/SparkTTS.jpg" alt="Spark-TTS Logo" style="width: 200px; height: 200px;">
|
17 |
+
</p>
|
18 |
+
<p>
|
19 |
+
<img src="src/logo/HKUST.jpg" alt="Institution 1" style="width: 200px; height: 60px;">
|
20 |
+
<img src="src/logo/mobvoi.jpg" alt="Institution 2" style="width: 200px; height: 60px;">
|
21 |
+
<img src="src/logo/SJU.jpg" alt="Institution 3" style="width: 200px; height: 60px;">
|
22 |
+
</p>
|
23 |
+
<p>
|
24 |
+
<img src="src/logo/NTU.jpg" alt="Institution 4" style="width: 200px; height: 60px;">
|
25 |
+
<img src="src/logo/NPU.jpg" alt="Institution 5" style="width: 200px; height: 60px;">
|
26 |
+
<img src="src/logo/SparkAudio2.jpg" alt="Institution 6" style="width: 200px; height: 60px;">
|
27 |
+
</p>
|
28 |
+
<p>
|
29 |
+
</p>
|
30 |
+
<a href="https://arxiv.org/pdf/2503.01710"><img src="https://img.shields.io/badge/Paper-ArXiv-red" alt="paper"></a>
|
31 |
+
<a href="https://sparkaudio.github.io/spark-tts/"><img src="https://img.shields.io/badge/Demo-Page-lightgrey" alt="version"></a>
|
32 |
+
<a href="https://huggingface.co/SparkAudio/Spark-TTS-0.5B"><img src="https://img.shields.io/badge/Hugging%20Face-Model%20Page-yellow" alt="Hugging Face"></a>
|
33 |
+
<a href="https://github.com/SparkAudio/Spark-TTS"><img src="https://img.shields.io/badge/Platform-linux-lightgrey" alt="version"></a>
|
34 |
+
<a href="https://github.com/SparkAudio/Spark-TTS"><img src="https://img.shields.io/badge/Python-3.12+-orange" alt="version"></a>
|
35 |
+
<a href="https://github.com/SparkAudio/Spark-TTS"><img src="https://img.shields.io/badge/PyTorch-2.5+-brightgreen" alt="python"></a>
|
36 |
+
<a href="https://github.com/SparkAudio/Spark-TTS"><img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="mit"></a>
|
37 |
+
</div>
|
38 |
|
39 |
+
|
40 |
+
## Spark-TTS 🔥
|
41 |
+
|
42 |
+
### Overview
|
43 |
+
|
44 |
+
Spark-TTS is an advanced text-to-speech system that uses the power of large language models (LLM) for highly accurate and natural-sounding voice synthesis. It is designed to be efficient, flexible, and powerful for both research and production use.
|
45 |
+
|
46 |
+
### Key Features
|
47 |
+
|
48 |
+
- **Simplicity and Efficiency**: Built entirely on Qwen2.5, Spark-TTS eliminates the need for additional generation models like flow matching. Instead of relying on separate models to generate acoustic features, it directly reconstructs audio from the code predicted by the LLM. This approach streamlines the process, improving efficiency and reducing complexity.
|
49 |
+
- **High-Quality Voice Cloning**: Supports zero-shot voice cloning, which means it can replicate a speaker's voice even without specific training data for that voice. This is ideal for cross-lingual and code-switching scenarios, allowing for seamless transitions between languages and voices without requiring separate training for each one.
|
50 |
+
- **Bilingual Support**: Supports both Chinese and English, and is capable of zero-shot voice cloning for cross-lingual and code-switching scenarios, enabling the model to synthesize speech in multiple languages with high naturalness and accuracy.
|
51 |
+
- **Controllable Speech Generation**: Supports creating virtual speakers by adjusting parameters such as gender, pitch, and speaking rate.
|
52 |
+
|
53 |
+
---
|
54 |
+
|
55 |
+
<table align="center">
|
56 |
+
<tr>
|
57 |
+
<td align="center"><b>Inference Overview of Voice Cloning</b><br><img src="src/figures/infer_voice_cloning.png" width="80%" /></td>
|
58 |
+
</tr>
|
59 |
+
<tr>
|
60 |
+
<td align="center"><b>Inference Overview of Controlled Generation</b><br><img src="src/figures/infer_control.png" width="80%" /></td>
|
61 |
+
</tr>
|
62 |
+
</table>
|
63 |
+
|
64 |
+
|
65 |
+
## 🚀 News
|
66 |
+
|
67 |
+
- **[2025-03-04]** Our paper on this project has been published! You can read it here: [Spark-TTS](https://arxiv.org/pdf/2503.01710).
|
68 |
+
|
69 |
+
- **[2025-03-12]** Nvidia Triton Inference Serving is now supported. See the Runtime section below for more details.
|
70 |
+
|
71 |
+
|
72 |
+
## Install
|
73 |
+
**Clone and Install**
|
74 |
+
|
75 |
+
Here are instructions for installing on Linux. If you're on Windows, please refer to the [Windows Installation Guide](https://github.com/SparkAudio/Spark-TTS/issues/5).
|
76 |
+
*(Thanks to [@AcTePuKc](https://github.com/AcTePuKc) for the detailed Windows instructions!)*
|
77 |
+
|
78 |
+
|
79 |
+
- Clone the repo
|
80 |
+
``` sh
|
81 |
+
git clone https://github.com/SparkAudio/Spark-TTS.git
|
82 |
+
cd Spark-TTS
|
83 |
+
```
|
84 |
+
|
85 |
+
- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
|
86 |
+
- Create Conda env:
|
87 |
+
|
88 |
+
``` sh
|
89 |
+
conda create -n sparktts -y python=3.12
|
90 |
+
conda activate sparktts
|
91 |
+
pip install -r requirements.txt
|
92 |
+
# If you are in mainland China, you can set the mirror as follows:
|
93 |
+
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
94 |
+
```
|
95 |
+
|
96 |
+
**Model Download**
|
97 |
+
|
98 |
+
Download via python:
|
99 |
+
```python
|
100 |
+
from huggingface_hub import snapshot_download
|
101 |
+
|
102 |
+
snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B")
|
103 |
+
```
|
104 |
+
|
105 |
+
Download via git clone:
|
106 |
+
```sh
|
107 |
+
mkdir -p pretrained_models
|
108 |
+
|
109 |
+
# Make sure you have git-lfs installed (https://git-lfs.com)
|
110 |
+
git lfs install
|
111 |
+
|
112 |
+
git clone https://huggingface.co/SparkAudio/Spark-TTS-0.5B pretrained_models/Spark-TTS-0.5B
|
113 |
+
```
|
114 |
+
|
115 |
+
**Basic Usage**
|
116 |
+
|
117 |
+
You can simply run the demo with the following commands:
|
118 |
+
``` sh
|
119 |
+
cd example
|
120 |
+
bash infer.sh
|
121 |
+
```
|
122 |
+
|
123 |
+
Alternatively, you can directly execute the following command in the command line to perform inference:
|
124 |
+
|
125 |
+
``` sh
|
126 |
+
python -m cli.inference \
|
127 |
+
--text "text to synthesis." \
|
128 |
+
--device 0 \
|
129 |
+
--save_dir "path/to/save/audio" \
|
130 |
+
--model_dir pretrained_models/Spark-TTS-0.5B \
|
131 |
+
--prompt_text "transcript of the prompt audio" \
|
132 |
+
--prompt_speech_path "path/to/prompt_audio"
|
133 |
+
```
|
134 |
+
|
135 |
+
**Web UI Usage**
|
136 |
+
|
137 |
+
You can start the UI interface by running `python webui.py --device 0`, which allows you to perform Voice Cloning and Voice Creation. Voice Cloning supports uploading reference audio or directly recording the audio.
|
138 |
+
|
139 |
+
|
140 |
+
| **Voice Cloning** | **Voice Creation** |
|
141 |
+
|:-------------------:|:-------------------:|
|
142 |
+
|  |  |
|
143 |
+
|
144 |
+
|
145 |
+
**Optional Methods**
|
146 |
+
|
147 |
+
For additional CLI and Web UI methods, including alternative implementations and extended functionalities, you can refer to:
|
148 |
+
|
149 |
+
- [CLI and UI by AcTePuKc](https://github.com/SparkAudio/Spark-TTS/issues/10)
|
150 |
+
|
151 |
+
|
152 |
+
## Runtime
|
153 |
+
|
154 |
+
**Nvidia Triton Inference Serving**
|
155 |
+
|
156 |
+
We now provide a reference for deploying Spark-TTS with Nvidia Triton and TensorRT-LLM. The table below presents benchmark results on a single L20 GPU, using 26 different prompt_audio/target_text pairs (totalling 169 seconds of audio):
|
157 |
+
|
158 |
+
| Model | Note | Concurrency | Avg Latency | RTF |
|
159 |
+
|-------|-----------|-----------------------|---------|--|
|
160 |
+
| Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 876.24 ms | 0.1362|
|
161 |
+
| Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 920.97 ms | 0.0737|
|
162 |
+
| Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1611.51 ms | 0.0704|
|
163 |
+
|
164 |
+
|
165 |
+
Please see the detailed instructions in [runtime/triton_trtllm/README.md](runtime/triton_trtllm/README.md ) for more information.
|
166 |
+
|
167 |
+
|
168 |
+
## **Demos**
|
169 |
+
|
170 |
+
Here are some demos generated by Spark-TTS using zero-shot voice cloning. For more demos, visit our [demo page](https://sparkaudio.github.io/spark-tts/).
|
171 |
+
|
172 |
+
---
|
173 |
+
|
174 |
+
<table>
|
175 |
+
<tr>
|
176 |
+
<td align="center">
|
177 |
+
|
178 |
+
**Donald Trump**
|
179 |
+
</td>
|
180 |
+
<td align="center">
|
181 |
+
|
182 |
+
**Zhongli (Genshin Impact)**
|
183 |
+
</td>
|
184 |
+
</tr>
|
185 |
+
|
186 |
+
<tr>
|
187 |
+
<td align="center">
|
188 |
+
|
189 |
+
[Donald Trump](https://github.com/user-attachments/assets/fb225780-d9fe-44b2-9b2e-54390cb3d8fd)
|
190 |
+
|
191 |
+
</td>
|
192 |
+
<td align="center">
|
193 |
+
|
194 |
+
[Zhongli](https://github.com/user-attachments/assets/80eeb9c7-0443-4758-a1ce-55ac59e64bd6)
|
195 |
+
|
196 |
+
</td>
|
197 |
+
</tr>
|
198 |
+
</table>
|
199 |
+
|
200 |
+
---
|
201 |
+
|
202 |
+
<table>
|
203 |
+
|
204 |
+
<tr>
|
205 |
+
<td align="center">
|
206 |
+
|
207 |
+
**陈鲁豫 Chen Luyu**
|
208 |
+
</td>
|
209 |
+
<td align="center">
|
210 |
+
|
211 |
+
**杨澜 Yang Lan**
|
212 |
+
</td>
|
213 |
+
</tr>
|
214 |
+
|
215 |
+
<tr>
|
216 |
+
<td align="center">
|
217 |
+
|
218 |
+
[陈鲁豫Chen_Luyu.webm](https://github.com/user-attachments/assets/5c6585ae-830d-47b1-992d-ee3691f48cf4)
|
219 |
+
</td>
|
220 |
+
<td align="center">
|
221 |
+
|
222 |
+
[Yang_Lan.webm](https://github.com/user-attachments/assets/2fb3d00c-abc3-410e-932f-46ba204fb1d7)
|
223 |
+
</td>
|
224 |
+
</tr>
|
225 |
+
</table>
|
226 |
+
|
227 |
+
---
|
228 |
+
|
229 |
+
|
230 |
+
<table>
|
231 |
+
<tr>
|
232 |
+
<td align="center">
|
233 |
+
|
234 |
+
**余承东 Richard Yu**
|
235 |
+
</td>
|
236 |
+
<td align="center">
|
237 |
+
|
238 |
+
**马云 Jack Ma**
|
239 |
+
</td>
|
240 |
+
</tr>
|
241 |
+
|
242 |
+
<tr>
|
243 |
+
<td align="center">
|
244 |
+
|
245 |
+
[Yu_Chengdong.webm](https://github.com/user-attachments/assets/78feca02-84bb-4d3a-a770-0cfd02f1a8da)
|
246 |
+
|
247 |
+
</td>
|
248 |
+
<td align="center">
|
249 |
+
|
250 |
+
[Ma_Yun.webm](https://github.com/user-attachments/assets/2d54e2eb-cec4-4c2f-8c84-8fe587da321b)
|
251 |
+
|
252 |
+
</td>
|
253 |
+
</tr>
|
254 |
+
</table>
|
255 |
+
|
256 |
+
---
|
257 |
+
|
258 |
+
|
259 |
+
<table>
|
260 |
+
<tr>
|
261 |
+
<td align="center">
|
262 |
+
|
263 |
+
**刘德华 Andy Lau**
|
264 |
+
</td>
|
265 |
+
<td align="center">
|
266 |
+
|
267 |
+
**徐志胜 Xu Zhisheng**
|
268 |
+
</td>
|
269 |
+
</tr>
|
270 |
+
|
271 |
+
<tr>
|
272 |
+
<td align="center">
|
273 |
+
|
274 |
+
[Liu_Dehua.webm](https://github.com/user-attachments/assets/195b5e97-1fee-4955-b954-6d10fa04f1d7)
|
275 |
+
|
276 |
+
</td>
|
277 |
+
<td align="center">
|
278 |
+
|
279 |
+
[Xu_Zhisheng.webm](https://github.com/user-attachments/assets/dd812af9-76bd-4e26-9988-9cdb9ccbb87b)
|
280 |
+
|
281 |
+
</td>
|
282 |
+
</tr>
|
283 |
+
</table>
|
284 |
+
|
285 |
+
|
286 |
+
---
|
287 |
+
|
288 |
+
<table>
|
289 |
+
<tr>
|
290 |
+
<td align="center">
|
291 |
+
|
292 |
+
**哪吒 Nezha**
|
293 |
+
</td>
|
294 |
+
<td align="center">
|
295 |
+
|
296 |
+
**李靖 Li Jing**
|
297 |
+
</td>
|
298 |
+
</tr>
|
299 |
+
|
300 |
+
<tr>
|
301 |
+
<td align="center">
|
302 |
+
|
303 |
+
[Ne_Zha.webm](https://github.com/user-attachments/assets/8c608037-a17a-46d4-8588-4db34b49ed1d)
|
304 |
+
</td>
|
305 |
+
<td align="center">
|
306 |
+
|
307 |
+
[Li_Jing.webm](https://github.com/user-attachments/assets/aa8ba091-097c-4156-b4e3-6445da5ea101)
|
308 |
+
|
309 |
+
</td>
|
310 |
+
</tr>
|
311 |
+
</table>
|
312 |
+
|
313 |
+
|
314 |
+
## To-Do List
|
315 |
+
|
316 |
+
- [x] Release the Spark-TTS paper.
|
317 |
+
- [ ] Release the training code.
|
318 |
+
- [ ] Release the training dataset, VoxBox.
|
319 |
+
|
320 |
+
|
321 |
+
## Citation
|
322 |
+
|
323 |
+
```
|
324 |
+
@misc{wang2025sparktts,
|
325 |
+
title={Spark-TTS: An Efficient LLM-Based Text-to-Speech Model with Single-Stream Decoupled Speech Tokens},
|
326 |
+
author={Xinsheng Wang and Mingqi Jiang and Ziyang Ma and Ziyu Zhang and Songxiang Liu and Linqin Li and Zheng Liang and Qixi Zheng and Rui Wang and Xiaoqin Feng and Weizhen Bian and Zhen Ye and Sitong Cheng and Ruibin Yuan and Zhixian Zhao and Xinfa Zhu and Jiahao Pan and Liumeng Xue and Pengcheng Zhu and Yunlin Chen and Zhifei Li and Xie Chen and Lei Xie and Yike Guo and Wei Xue},
|
327 |
+
year={2025},
|
328 |
+
eprint={2503.01710},
|
329 |
+
archivePrefix={arXiv},
|
330 |
+
primaryClass={cs.SD},
|
331 |
+
url={https://arxiv.org/abs/2503.01710},
|
332 |
+
}
|
333 |
+
```
|
334 |
+
|
335 |
+
|
336 |
+
## ⚠️ Usage Disclaimer
|
337 |
+
|
338 |
+
This project provides a zero-shot voice cloning TTS model intended for academic research, educational purposes, and legitimate applications, such as personalized speech synthesis, assistive technologies, and linguistic research.
|
339 |
+
|
340 |
+
Please note:
|
341 |
+
|
342 |
+
- Do not use this model for unauthorized voice cloning, impersonation, fraud, scams, deepfakes, or any illegal activities.
|
343 |
+
|
344 |
+
- Ensure compliance with local laws and regulations when using this model and uphold ethical standards.
|
345 |
+
|
346 |
+
- The developers assume no liability for any misuse of this model.
|
347 |
+
|
348 |
+
We advocate for the responsible development and use of AI and encourage the community to uphold safety and ethical principles in AI research and applications. If you have any concerns regarding ethics or misuse, please contact us.
|
cli/SparkTTS.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import re
|
17 |
+
import torch
|
18 |
+
from typing import Tuple
|
19 |
+
from pathlib import Path
|
20 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
21 |
+
|
22 |
+
from sparktts.utils.file import load_config
|
23 |
+
from sparktts.models.audio_tokenizer import BiCodecTokenizer
|
24 |
+
from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP
|
25 |
+
|
26 |
+
|
27 |
+
class SparkTTS:
|
28 |
+
"""
|
29 |
+
Spark-TTS for text-to-speech generation.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")):
|
33 |
+
"""
|
34 |
+
Initializes the SparkTTS model with the provided configurations and device.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
model_dir (Path): Directory containing the model and config files.
|
38 |
+
device (torch.device): The device (CPU/GPU) to run the model on.
|
39 |
+
"""
|
40 |
+
self.device = device
|
41 |
+
self.model_dir = model_dir
|
42 |
+
self.configs = load_config(f"{model_dir}/config.yaml")
|
43 |
+
self.sample_rate = self.configs["sample_rate"]
|
44 |
+
self._initialize_inference()
|
45 |
+
|
46 |
+
def _initialize_inference(self):
|
47 |
+
"""Initializes the tokenizer, model, and audio tokenizer for inference."""
|
48 |
+
self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM")
|
49 |
+
self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
|
50 |
+
self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
|
51 |
+
self.model.to(self.device)
|
52 |
+
|
53 |
+
def process_prompt(
|
54 |
+
self,
|
55 |
+
text: str,
|
56 |
+
prompt_speech_path: Path,
|
57 |
+
prompt_text: str = None,
|
58 |
+
) -> Tuple[str, torch.Tensor]:
|
59 |
+
"""
|
60 |
+
Process input for voice cloning.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
text (str): The text input to be converted to speech.
|
64 |
+
prompt_speech_path (Path): Path to the audio file used as a prompt.
|
65 |
+
prompt_text (str, optional): Transcript of the prompt audio.
|
66 |
+
|
67 |
+
Return:
|
68 |
+
Tuple[str, torch.Tensor]: Input prompt; global tokens
|
69 |
+
"""
|
70 |
+
|
71 |
+
global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(
|
72 |
+
prompt_speech_path
|
73 |
+
)
|
74 |
+
global_tokens = "".join(
|
75 |
+
[f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
|
76 |
+
)
|
77 |
+
|
78 |
+
# Prepare the input tokens for the model
|
79 |
+
if prompt_text is not None:
|
80 |
+
semantic_tokens = "".join(
|
81 |
+
[f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
|
82 |
+
)
|
83 |
+
inputs = [
|
84 |
+
TASK_TOKEN_MAP["tts"],
|
85 |
+
"<|start_content|>",
|
86 |
+
prompt_text,
|
87 |
+
text,
|
88 |
+
"<|end_content|>",
|
89 |
+
"<|start_global_token|>",
|
90 |
+
global_tokens,
|
91 |
+
"<|end_global_token|>",
|
92 |
+
"<|start_semantic_token|>",
|
93 |
+
semantic_tokens,
|
94 |
+
]
|
95 |
+
else:
|
96 |
+
inputs = [
|
97 |
+
TASK_TOKEN_MAP["tts"],
|
98 |
+
"<|start_content|>",
|
99 |
+
text,
|
100 |
+
"<|end_content|>",
|
101 |
+
"<|start_global_token|>",
|
102 |
+
global_tokens,
|
103 |
+
"<|end_global_token|>",
|
104 |
+
]
|
105 |
+
|
106 |
+
inputs = "".join(inputs)
|
107 |
+
|
108 |
+
return inputs, global_token_ids
|
109 |
+
|
110 |
+
def process_prompt_control(
|
111 |
+
self,
|
112 |
+
gender: str,
|
113 |
+
pitch: str,
|
114 |
+
speed: str,
|
115 |
+
text: str,
|
116 |
+
):
|
117 |
+
"""
|
118 |
+
Process input for voice creation.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
gender (str): female | male.
|
122 |
+
pitch (str): very_low | low | moderate | high | very_high
|
123 |
+
speed (str): very_low | low | moderate | high | very_high
|
124 |
+
text (str): The text input to be converted to speech.
|
125 |
+
|
126 |
+
Return:
|
127 |
+
str: Input prompt
|
128 |
+
"""
|
129 |
+
assert gender in GENDER_MAP.keys()
|
130 |
+
assert pitch in LEVELS_MAP.keys()
|
131 |
+
assert speed in LEVELS_MAP.keys()
|
132 |
+
|
133 |
+
gender_id = GENDER_MAP[gender]
|
134 |
+
pitch_level_id = LEVELS_MAP[pitch]
|
135 |
+
speed_level_id = LEVELS_MAP[speed]
|
136 |
+
|
137 |
+
pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
|
138 |
+
speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
|
139 |
+
gender_tokens = f"<|gender_{gender_id}|>"
|
140 |
+
|
141 |
+
attribte_tokens = "".join(
|
142 |
+
[gender_tokens, pitch_label_tokens, speed_label_tokens]
|
143 |
+
)
|
144 |
+
|
145 |
+
control_tts_inputs = [
|
146 |
+
TASK_TOKEN_MAP["controllable_tts"],
|
147 |
+
"<|start_content|>",
|
148 |
+
text,
|
149 |
+
"<|end_content|>",
|
150 |
+
"<|start_style_label|>",
|
151 |
+
attribte_tokens,
|
152 |
+
"<|end_style_label|>",
|
153 |
+
]
|
154 |
+
|
155 |
+
return "".join(control_tts_inputs)
|
156 |
+
|
157 |
+
@torch.no_grad()
|
158 |
+
def inference(
|
159 |
+
self,
|
160 |
+
text: str,
|
161 |
+
prompt_speech_path: Path = None,
|
162 |
+
prompt_text: str = None,
|
163 |
+
gender: str = None,
|
164 |
+
pitch: str = None,
|
165 |
+
speed: str = None,
|
166 |
+
temperature: float = 0.8,
|
167 |
+
top_k: float = 50,
|
168 |
+
top_p: float = 0.95,
|
169 |
+
) -> torch.Tensor:
|
170 |
+
"""
|
171 |
+
Performs inference to generate speech from text, incorporating prompt audio and/or text.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
text (str): The text input to be converted to speech.
|
175 |
+
prompt_speech_path (Path): Path to the audio file used as a prompt.
|
176 |
+
prompt_text (str, optional): Transcript of the prompt audio.
|
177 |
+
gender (str): female | male.
|
178 |
+
pitch (str): very_low | low | moderate | high | very_high
|
179 |
+
speed (str): very_low | low | moderate | high | very_high
|
180 |
+
temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
|
181 |
+
top_k (float, optional): Top-k sampling parameter. Default is 50.
|
182 |
+
top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
torch.Tensor: Generated waveform as a tensor.
|
186 |
+
"""
|
187 |
+
if gender is not None:
|
188 |
+
prompt = self.process_prompt_control(gender, pitch, speed, text)
|
189 |
+
|
190 |
+
else:
|
191 |
+
prompt, global_token_ids = self.process_prompt(
|
192 |
+
text, prompt_speech_path, prompt_text
|
193 |
+
)
|
194 |
+
model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
|
195 |
+
|
196 |
+
# Generate speech using the model
|
197 |
+
generated_ids = self.model.generate(
|
198 |
+
**model_inputs,
|
199 |
+
max_new_tokens=3000,
|
200 |
+
do_sample=True,
|
201 |
+
top_k=top_k,
|
202 |
+
top_p=top_p,
|
203 |
+
temperature=temperature,
|
204 |
+
)
|
205 |
+
|
206 |
+
# Trim the output tokens to remove the input tokens
|
207 |
+
generated_ids = [
|
208 |
+
output_ids[len(input_ids) :]
|
209 |
+
for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
210 |
+
]
|
211 |
+
|
212 |
+
# Decode the generated tokens into text
|
213 |
+
predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
214 |
+
|
215 |
+
# Extract semantic token IDs from the generated text
|
216 |
+
pred_semantic_ids = (
|
217 |
+
torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)])
|
218 |
+
.long()
|
219 |
+
.unsqueeze(0)
|
220 |
+
)
|
221 |
+
|
222 |
+
if gender is not None:
|
223 |
+
global_token_ids = (
|
224 |
+
torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)])
|
225 |
+
.long()
|
226 |
+
.unsqueeze(0)
|
227 |
+
.unsqueeze(0)
|
228 |
+
)
|
229 |
+
|
230 |
+
# Convert semantic tokens back to waveform
|
231 |
+
wav = self.audio_tokenizer.detokenize(
|
232 |
+
global_token_ids.to(self.device).squeeze(0),
|
233 |
+
pred_semantic_ids.to(self.device),
|
234 |
+
)
|
235 |
+
|
236 |
+
return wav
|
cli/inference.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import os
|
18 |
+
import argparse
|
19 |
+
import torch
|
20 |
+
import soundfile as sf
|
21 |
+
import logging
|
22 |
+
from datetime import datetime
|
23 |
+
import platform
|
24 |
+
|
25 |
+
from cli.SparkTTS import SparkTTS
|
26 |
+
|
27 |
+
|
28 |
+
def parse_args():
|
29 |
+
"""Parse command-line arguments."""
|
30 |
+
parser = argparse.ArgumentParser(description="Run TTS inference.")
|
31 |
+
|
32 |
+
parser.add_argument(
|
33 |
+
"--model_dir",
|
34 |
+
type=str,
|
35 |
+
default="pretrained_models/Spark-TTS-0.5B",
|
36 |
+
help="Path to the model directory",
|
37 |
+
)
|
38 |
+
parser.add_argument(
|
39 |
+
"--save_dir",
|
40 |
+
type=str,
|
41 |
+
default="example/results",
|
42 |
+
help="Directory to save generated audio files",
|
43 |
+
)
|
44 |
+
parser.add_argument("--device", type=int, default=0, help="CUDA device number")
|
45 |
+
parser.add_argument(
|
46 |
+
"--text", type=str, required=True, help="Text for TTS generation"
|
47 |
+
)
|
48 |
+
parser.add_argument("--prompt_text", type=str, help="Transcript of prompt audio")
|
49 |
+
parser.add_argument(
|
50 |
+
"--prompt_speech_path",
|
51 |
+
type=str,
|
52 |
+
help="Path to the prompt audio file",
|
53 |
+
)
|
54 |
+
parser.add_argument("--gender", choices=["male", "female"])
|
55 |
+
parser.add_argument(
|
56 |
+
"--pitch", choices=["very_low", "low", "moderate", "high", "very_high"]
|
57 |
+
)
|
58 |
+
parser.add_argument(
|
59 |
+
"--speed", choices=["very_low", "low", "moderate", "high", "very_high"]
|
60 |
+
)
|
61 |
+
return parser.parse_args()
|
62 |
+
|
63 |
+
|
64 |
+
def run_tts(args):
|
65 |
+
"""Perform TTS inference and save the generated audio."""
|
66 |
+
logging.info(f"Using model from: {args.model_dir}")
|
67 |
+
logging.info(f"Saving audio to: {args.save_dir}")
|
68 |
+
|
69 |
+
# Ensure the save directory exists
|
70 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
71 |
+
|
72 |
+
# Convert device argument to torch.device
|
73 |
+
if platform.system() == "Darwin" and torch.backends.mps.is_available():
|
74 |
+
# macOS with MPS support (Apple Silicon)
|
75 |
+
device = torch.device(f"mps:{args.device}")
|
76 |
+
logging.info(f"Using MPS device: {device}")
|
77 |
+
elif torch.cuda.is_available():
|
78 |
+
# System with CUDA support
|
79 |
+
device = torch.device(f"cuda:{args.device}")
|
80 |
+
logging.info(f"Using CUDA device: {device}")
|
81 |
+
else:
|
82 |
+
# Fall back to CPU
|
83 |
+
device = torch.device("cpu")
|
84 |
+
logging.info("GPU acceleration not available, using CPU")
|
85 |
+
|
86 |
+
# Initialize the model
|
87 |
+
model = SparkTTS(args.model_dir, device)
|
88 |
+
|
89 |
+
# Generate unique filename using timestamp
|
90 |
+
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
91 |
+
save_path = os.path.join(args.save_dir, f"{timestamp}.wav")
|
92 |
+
|
93 |
+
logging.info("Starting inference...")
|
94 |
+
|
95 |
+
# Perform inference and save the output audio
|
96 |
+
with torch.no_grad():
|
97 |
+
wav = model.inference(
|
98 |
+
args.text,
|
99 |
+
args.prompt_speech_path,
|
100 |
+
prompt_text=args.prompt_text,
|
101 |
+
gender=args.gender,
|
102 |
+
pitch=args.pitch,
|
103 |
+
speed=args.speed,
|
104 |
+
)
|
105 |
+
sf.write(save_path, wav, samplerate=16000)
|
106 |
+
|
107 |
+
logging.info(f"Audio saved at: {save_path}")
|
108 |
+
|
109 |
+
|
110 |
+
if __name__ == "__main__":
|
111 |
+
logging.basicConfig(
|
112 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
113 |
+
)
|
114 |
+
|
115 |
+
args = parse_args()
|
116 |
+
run_tts(args)
|
example/infer.sh
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Copyright (c) 2025 SparkAudio
|
4 |
+
# 2025 Xinsheng Wang ([email protected])
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
|
19 |
+
# Get the absolute path of the script's directory
|
20 |
+
script_dir=$(dirname "$(realpath "$0")")
|
21 |
+
|
22 |
+
# Get the root directory
|
23 |
+
root_dir=$(dirname "$script_dir")
|
24 |
+
|
25 |
+
# Set default parameters
|
26 |
+
device=0
|
27 |
+
save_dir='example/results'
|
28 |
+
model_dir="pretrained_models/Spark-TTS-0.5B"
|
29 |
+
text="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。"
|
30 |
+
prompt_text="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。"
|
31 |
+
prompt_speech_path="example/prompt_audio.wav"
|
32 |
+
|
33 |
+
# Change directory to the root directory
|
34 |
+
cd "$root_dir" || exit
|
35 |
+
|
36 |
+
source sparktts/utils/parse_options.sh
|
37 |
+
|
38 |
+
# Run inference
|
39 |
+
python -m cli.inference \
|
40 |
+
--text "${text}" \
|
41 |
+
--device "${device}" \
|
42 |
+
--save_dir "${save_dir}" \
|
43 |
+
--model_dir "${model_dir}" \
|
44 |
+
--prompt_text "${prompt_text}" \
|
45 |
+
--prompt_speech_path "${prompt_speech_path}"
|
46 |
+
|
47 |
+
|
example/prompt_audio.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:335e7f7789b231cd90d9670292d561ecfe6a6bdd5e737a7bc6c29730741852de
|
3 |
+
size 318550
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
einops==0.8.1
|
2 |
+
einx==0.3.0
|
3 |
+
numpy==2.2.3
|
4 |
+
omegaconf==2.3.0
|
5 |
+
packaging==24.2
|
6 |
+
safetensors==0.5.2
|
7 |
+
soundfile==0.12.1
|
8 |
+
soxr==0.5.0.post1
|
9 |
+
tqdm==4.66.5
|
10 |
+
transformers==4.46.2
|
11 |
+
gradio==5.18.0
|
runtime/triton_trtllm/Dockerfile.server
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvcr.io/nvidia/tritonserver:25.02-trtllm-python-py3
|
2 |
+
RUN apt-get update && apt-get install -y cmake
|
3 |
+
RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
|
4 |
+
RUN pip install einx==0.3.0 omegaconf==2.3.0 soundfile==0.12.1 soxr==0.5.0.post1 gradio tritonclient librosa
|
5 |
+
WORKDIR /workspace
|
runtime/triton_trtllm/README.md
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Nvidia Triton Inference Serving Best Practice for Spark TTS
|
2 |
+
|
3 |
+
### Quick Start
|
4 |
+
Directly launch the service using docker compose.
|
5 |
+
```sh
|
6 |
+
docker compose up
|
7 |
+
```
|
8 |
+
|
9 |
+
### Build Image
|
10 |
+
Build the docker image from scratch.
|
11 |
+
```sh
|
12 |
+
docker build . -f Dockerfile.server -t soar97/triton-spark-tts:25.02
|
13 |
+
```
|
14 |
+
|
15 |
+
### Create Docker Container
|
16 |
+
```sh
|
17 |
+
your_mount_dir=/mnt:/mnt
|
18 |
+
docker run -it --name "spark-tts-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-spark-tts:25.02
|
19 |
+
```
|
20 |
+
|
21 |
+
### Understanding `run.sh`
|
22 |
+
|
23 |
+
The `run.sh` script automates various steps using stages. You can run specific stages using:
|
24 |
+
```sh
|
25 |
+
bash run.sh <start_stage> <stop_stage> [service_type]
|
26 |
+
```
|
27 |
+
- `<start_stage>`: The stage to begin execution from (0-5).
|
28 |
+
- `<stop_stage>`: The stage to end execution at (0-5).
|
29 |
+
- `[service_type]`: Optional, specifies the service type ('streaming' or 'offline', defaults may apply based on script logic). Required for stages 4 and 5.
|
30 |
+
|
31 |
+
Stages:
|
32 |
+
- **Stage 0**: Download Spark-TTS-0.5B model from HuggingFace.
|
33 |
+
- **Stage 1**: Convert HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines.
|
34 |
+
- **Stage 2**: Create the Triton model repository structure and configure model files (adjusts for streaming/offline).
|
35 |
+
- **Stage 3**: Launch the Triton Inference Server.
|
36 |
+
- **Stage 4**: Run the gRPC benchmark client.
|
37 |
+
- **Stage 5**: Run the single utterance client (gRPC for streaming, HTTP for offline).
|
38 |
+
|
39 |
+
### Export Models to TensorRT-LLM and Launch Server
|
40 |
+
Inside the docker container, you can prepare the models and launch the Triton server by running stages 0 through 3. This involves downloading the original model, converting it to TensorRT-LLM format, building the optimized TensorRT engines, creating the necessary model repository structure for Triton, and finally starting the server.
|
41 |
+
```sh
|
42 |
+
# This runs stages 0, 1, 2, and 3
|
43 |
+
bash run.sh 0 3
|
44 |
+
```
|
45 |
+
*Note: Stage 2 prepares the model repository differently based on whether you intend to run streaming or offline inference later. You might need to re-run stage 2 if switching service types.*
|
46 |
+
|
47 |
+
|
48 |
+
### Single Utterance Client
|
49 |
+
Run a single inference request. Specify `streaming` or `offline` as the third argument.
|
50 |
+
|
51 |
+
**Streaming Mode (gRPC):**
|
52 |
+
```sh
|
53 |
+
bash run.sh 5 5 streaming
|
54 |
+
```
|
55 |
+
This executes the `client_grpc.py` script with predefined example text and prompt audio in streaming mode.
|
56 |
+
|
57 |
+
**Offline Mode (HTTP):**
|
58 |
+
```sh
|
59 |
+
bash run.sh 5 5 offline
|
60 |
+
```
|
61 |
+
|
62 |
+
### Benchmark using Dataset
|
63 |
+
Run the benchmark client against the running Triton server. Specify `streaming` or `offline` as the third argument.
|
64 |
+
```sh
|
65 |
+
# Run benchmark in streaming mode
|
66 |
+
bash run.sh 4 4 streaming
|
67 |
+
|
68 |
+
# Run benchmark in offline mode
|
69 |
+
bash run.sh 4 4 offline
|
70 |
+
|
71 |
+
# You can also customize parameters like num_task directly in client_grpc.py or via args if supported
|
72 |
+
# Example from run.sh (streaming):
|
73 |
+
# python3 client_grpc.py \
|
74 |
+
# --server-addr localhost \
|
75 |
+
# --model-name spark_tts \
|
76 |
+
# --num-tasks 2 \
|
77 |
+
# --mode streaming \
|
78 |
+
# --log-dir ./log_concurrent_tasks_2_streaming_new
|
79 |
+
|
80 |
+
# Example customizing dataset (requires modifying client_grpc.py or adding args):
|
81 |
+
# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --mode [streaming|offline]
|
82 |
+
```
|
83 |
+
|
84 |
+
### Benchmark Results
|
85 |
+
Decoding on a single L20 GPU, using 26 different prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts), total audio duration 169 secs.
|
86 |
+
|
87 |
+
| Mode | Note | Concurrency | Avg Latency | First Chunk Latency (P50) | RTF |
|
88 |
+
|-------|-----------|-----------------------|---------|----------------|-|
|
89 |
+
| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 876.24 ms |-| 0.1362|
|
90 |
+
| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 920.97 ms |-|0.0737|
|
91 |
+
| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1611.51 ms |-| 0.0704|
|
92 |
+
| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 1 | 913.28 ms |210.42 ms| 0.1501 |
|
93 |
+
| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 2 | 1009.23 ms |226.08 ms |0.0862 |
|
94 |
+
| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 4 | 1793.86 ms |1017.70 ms| 0.0824 |
|
runtime/triton_trtllm/client_grpc.py
ADDED
@@ -0,0 +1,831 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
3 |
+
# 2023 Nvidia (authors: Yuekai Zhang)
|
4 |
+
# 2023 Recurrent.ai (authors: Songtao Shi)
|
5 |
+
# See LICENSE for clarification regarding multiple authors
|
6 |
+
#
|
7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
8 |
+
# you may not use this file except in compliance with the License.
|
9 |
+
# You may obtain a copy of the License at
|
10 |
+
#
|
11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12 |
+
#
|
13 |
+
# Unless required by applicable law or agreed to in writing, software
|
14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
16 |
+
# See the License for the specific language governing permissions and
|
17 |
+
# limitations under the License.
|
18 |
+
"""
|
19 |
+
This script supports to load dataset from huggingface and sends it to the server
|
20 |
+
for decoding, in parallel.
|
21 |
+
|
22 |
+
Usage:
|
23 |
+
num_task=2
|
24 |
+
|
25 |
+
# For offline F5-TTS
|
26 |
+
python3 client_grpc.py \
|
27 |
+
--server-addr localhost \
|
28 |
+
--model-name f5_tts \
|
29 |
+
--num-tasks $num_task \
|
30 |
+
--huggingface-dataset yuekai/seed_tts \
|
31 |
+
--split-name test_zh \
|
32 |
+
--log-dir ./log_concurrent_tasks_${num_task}
|
33 |
+
|
34 |
+
# For offline Spark-TTS-0.5B
|
35 |
+
python3 client_grpc.py \
|
36 |
+
--server-addr localhost \
|
37 |
+
--model-name spark_tts \
|
38 |
+
--num-tasks $num_task \
|
39 |
+
--huggingface-dataset yuekai/seed_tts \
|
40 |
+
--split-name wenetspeech4tts \
|
41 |
+
--log-dir ./log_concurrent_tasks_${num_task}
|
42 |
+
"""
|
43 |
+
|
44 |
+
import argparse
|
45 |
+
import asyncio
|
46 |
+
import json
|
47 |
+
import queue # Added
|
48 |
+
import uuid # Added
|
49 |
+
import functools # Added
|
50 |
+
|
51 |
+
import os
|
52 |
+
import time
|
53 |
+
import types
|
54 |
+
from pathlib import Path
|
55 |
+
|
56 |
+
import numpy as np
|
57 |
+
import soundfile as sf
|
58 |
+
import tritonclient
|
59 |
+
import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
|
60 |
+
import tritonclient.grpc as grpcclient_sync # Added sync client import
|
61 |
+
from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
|
62 |
+
|
63 |
+
|
64 |
+
# --- Added UserData and callback ---
|
65 |
+
class UserData:
|
66 |
+
def __init__(self):
|
67 |
+
self._completed_requests = queue.Queue()
|
68 |
+
self._first_chunk_time = None
|
69 |
+
self._start_time = None
|
70 |
+
|
71 |
+
def record_start_time(self):
|
72 |
+
self._start_time = time.time()
|
73 |
+
|
74 |
+
def get_first_chunk_latency(self):
|
75 |
+
if self._first_chunk_time and self._start_time:
|
76 |
+
return self._first_chunk_time - self._start_time
|
77 |
+
return None
|
78 |
+
|
79 |
+
def callback(user_data, result, error):
|
80 |
+
if user_data._first_chunk_time is None and not error:
|
81 |
+
user_data._first_chunk_time = time.time() # Record time of first successful chunk
|
82 |
+
if error:
|
83 |
+
user_data._completed_requests.put(error)
|
84 |
+
else:
|
85 |
+
user_data._completed_requests.put(result)
|
86 |
+
# --- End Added UserData and callback ---
|
87 |
+
|
88 |
+
|
89 |
+
def write_triton_stats(stats, summary_file):
|
90 |
+
with open(summary_file, "w") as summary_f:
|
91 |
+
model_stats = stats["model_stats"]
|
92 |
+
# write a note, the log is from triton_client.get_inference_statistics(), to better human readability
|
93 |
+
summary_f.write(
|
94 |
+
"The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
|
95 |
+
)
|
96 |
+
summary_f.write("To learn more about the log, please refer to: \n")
|
97 |
+
summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
|
98 |
+
summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
|
99 |
+
summary_f.write(
|
100 |
+
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
|
101 |
+
)
|
102 |
+
summary_f.write(
|
103 |
+
"However, there is a trade-off between the increased queue time and the increased batch size. \n"
|
104 |
+
)
|
105 |
+
summary_f.write(
|
106 |
+
"You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
|
107 |
+
)
|
108 |
+
summary_f.write(
|
109 |
+
"See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
|
110 |
+
)
|
111 |
+
for model_state in model_stats:
|
112 |
+
if "last_inference" not in model_state:
|
113 |
+
continue
|
114 |
+
summary_f.write(f"model name is {model_state['name']} \n")
|
115 |
+
model_inference_stats = model_state["inference_stats"]
|
116 |
+
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
|
117 |
+
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
|
118 |
+
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
|
119 |
+
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
|
120 |
+
summary_f.write(
|
121 |
+
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
|
122 |
+
)
|
123 |
+
model_batch_stats = model_state["batch_stats"]
|
124 |
+
for batch in model_batch_stats:
|
125 |
+
batch_size = int(batch["batch_size"])
|
126 |
+
compute_input = batch["compute_input"]
|
127 |
+
compute_output = batch["compute_output"]
|
128 |
+
compute_infer = batch["compute_infer"]
|
129 |
+
batch_count = int(compute_infer["count"])
|
130 |
+
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
|
131 |
+
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
|
132 |
+
compute_input_time_ms = int(compute_input["ns"]) / 1e6
|
133 |
+
compute_output_time_ms = int(compute_output["ns"]) / 1e6
|
134 |
+
summary_f.write(
|
135 |
+
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
|
136 |
+
)
|
137 |
+
summary_f.write(
|
138 |
+
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
|
139 |
+
)
|
140 |
+
summary_f.write(
|
141 |
+
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
|
142 |
+
)
|
143 |
+
|
144 |
+
|
145 |
+
def get_args():
|
146 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
147 |
+
|
148 |
+
parser.add_argument(
|
149 |
+
"--server-addr",
|
150 |
+
type=str,
|
151 |
+
default="localhost",
|
152 |
+
help="Address of the server",
|
153 |
+
)
|
154 |
+
|
155 |
+
parser.add_argument(
|
156 |
+
"--server-port",
|
157 |
+
type=int,
|
158 |
+
default=8001,
|
159 |
+
help="Grpc port of the triton server, default is 8001",
|
160 |
+
)
|
161 |
+
|
162 |
+
parser.add_argument(
|
163 |
+
"--reference-audio",
|
164 |
+
type=str,
|
165 |
+
default=None,
|
166 |
+
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
|
167 |
+
)
|
168 |
+
|
169 |
+
parser.add_argument(
|
170 |
+
"--reference-text",
|
171 |
+
type=str,
|
172 |
+
default="",
|
173 |
+
help="",
|
174 |
+
)
|
175 |
+
|
176 |
+
parser.add_argument(
|
177 |
+
"--target-text",
|
178 |
+
type=str,
|
179 |
+
default="",
|
180 |
+
help="",
|
181 |
+
)
|
182 |
+
|
183 |
+
parser.add_argument(
|
184 |
+
"--huggingface-dataset",
|
185 |
+
type=str,
|
186 |
+
default="yuekai/seed_tts",
|
187 |
+
help="dataset name in huggingface dataset hub",
|
188 |
+
)
|
189 |
+
|
190 |
+
parser.add_argument(
|
191 |
+
"--split-name",
|
192 |
+
type=str,
|
193 |
+
default="wenetspeech4tts",
|
194 |
+
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
195 |
+
help="dataset split name, default is 'test'",
|
196 |
+
)
|
197 |
+
|
198 |
+
parser.add_argument(
|
199 |
+
"--manifest-path",
|
200 |
+
type=str,
|
201 |
+
default=None,
|
202 |
+
help="Path to the manifest dir which includes wav.scp trans.txt files.",
|
203 |
+
)
|
204 |
+
|
205 |
+
parser.add_argument(
|
206 |
+
"--model-name",
|
207 |
+
type=str,
|
208 |
+
default="f5_tts",
|
209 |
+
choices=["f5_tts", "spark_tts"],
|
210 |
+
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
|
211 |
+
)
|
212 |
+
|
213 |
+
parser.add_argument(
|
214 |
+
"--num-tasks",
|
215 |
+
type=int,
|
216 |
+
default=1,
|
217 |
+
help="Number of concurrent tasks for sending",
|
218 |
+
)
|
219 |
+
|
220 |
+
parser.add_argument(
|
221 |
+
"--log-interval",
|
222 |
+
type=int,
|
223 |
+
default=5,
|
224 |
+
help="Controls how frequently we print the log.",
|
225 |
+
)
|
226 |
+
|
227 |
+
parser.add_argument(
|
228 |
+
"--compute-wer",
|
229 |
+
action="store_true",
|
230 |
+
default=False,
|
231 |
+
help="""True to compute WER.
|
232 |
+
""",
|
233 |
+
)
|
234 |
+
|
235 |
+
parser.add_argument(
|
236 |
+
"--log-dir",
|
237 |
+
type=str,
|
238 |
+
required=False,
|
239 |
+
default="./tmp",
|
240 |
+
help="log directory",
|
241 |
+
)
|
242 |
+
|
243 |
+
# --- Added arguments ---
|
244 |
+
parser.add_argument(
|
245 |
+
"--mode",
|
246 |
+
type=str,
|
247 |
+
default="offline",
|
248 |
+
choices=["offline", "streaming"],
|
249 |
+
help="Select offline or streaming benchmark mode."
|
250 |
+
)
|
251 |
+
parser.add_argument(
|
252 |
+
"--chunk-overlap-duration",
|
253 |
+
type=float,
|
254 |
+
default=0.1,
|
255 |
+
help="Chunk overlap duration for streaming reconstruction (in seconds)."
|
256 |
+
)
|
257 |
+
# --- End Added arguments ---
|
258 |
+
|
259 |
+
return parser.parse_args()
|
260 |
+
|
261 |
+
|
262 |
+
def load_audio(wav_path, target_sample_rate=16000):
|
263 |
+
assert target_sample_rate == 16000, "hard coding in server"
|
264 |
+
if isinstance(wav_path, dict):
|
265 |
+
waveform = wav_path["array"]
|
266 |
+
sample_rate = wav_path["sampling_rate"]
|
267 |
+
else:
|
268 |
+
waveform, sample_rate = sf.read(wav_path)
|
269 |
+
if sample_rate != target_sample_rate:
|
270 |
+
from scipy.signal import resample
|
271 |
+
|
272 |
+
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
|
273 |
+
waveform = resample(waveform, num_samples)
|
274 |
+
return waveform, target_sample_rate
|
275 |
+
|
276 |
+
def prepare_request_input_output(
|
277 |
+
protocol_client, # Can be grpcclient_aio or grpcclient_sync
|
278 |
+
waveform,
|
279 |
+
reference_text,
|
280 |
+
target_text,
|
281 |
+
sample_rate=16000,
|
282 |
+
padding_duration: int = None # Optional padding for offline mode
|
283 |
+
):
|
284 |
+
"""Prepares inputs for Triton inference (offline or streaming)."""
|
285 |
+
assert len(waveform.shape) == 1, "waveform should be 1D"
|
286 |
+
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
287 |
+
|
288 |
+
# Apply padding only if padding_duration is provided (for offline)
|
289 |
+
if padding_duration:
|
290 |
+
duration = len(waveform) / sample_rate
|
291 |
+
# Estimate target duration based on text length ratio (crude estimation)
|
292 |
+
# Avoid division by zero if reference_text is empty
|
293 |
+
if reference_text:
|
294 |
+
estimated_target_duration = duration / len(reference_text) * len(target_text)
|
295 |
+
else:
|
296 |
+
estimated_target_duration = duration # Assume target duration similar to reference if no text
|
297 |
+
|
298 |
+
# Calculate required samples based on estimated total duration
|
299 |
+
required_total_samples = padding_duration * sample_rate * (
|
300 |
+
(int(estimated_target_duration + duration) // padding_duration) + 1
|
301 |
+
)
|
302 |
+
samples = np.zeros((1, required_total_samples), dtype=np.float32)
|
303 |
+
samples[0, : len(waveform)] = waveform
|
304 |
+
else:
|
305 |
+
# No padding for streaming or if padding_duration is None
|
306 |
+
samples = waveform.reshape(1, -1).astype(np.float32)
|
307 |
+
|
308 |
+
# Common input creation logic
|
309 |
+
inputs = [
|
310 |
+
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
|
311 |
+
protocol_client.InferInput(
|
312 |
+
"reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
|
313 |
+
),
|
314 |
+
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
|
315 |
+
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
|
316 |
+
]
|
317 |
+
inputs[0].set_data_from_numpy(samples)
|
318 |
+
inputs[1].set_data_from_numpy(lengths)
|
319 |
+
|
320 |
+
input_data_numpy = np.array([reference_text], dtype=object)
|
321 |
+
input_data_numpy = input_data_numpy.reshape((1, 1))
|
322 |
+
inputs[2].set_data_from_numpy(input_data_numpy)
|
323 |
+
|
324 |
+
input_data_numpy = np.array([target_text], dtype=object)
|
325 |
+
input_data_numpy = input_data_numpy.reshape((1, 1))
|
326 |
+
inputs[3].set_data_from_numpy(input_data_numpy)
|
327 |
+
|
328 |
+
outputs = [protocol_client.InferRequestedOutput("waveform")]
|
329 |
+
|
330 |
+
return inputs, outputs
|
331 |
+
|
332 |
+
def run_sync_streaming_inference(
|
333 |
+
sync_triton_client: tritonclient.grpc.InferenceServerClient,
|
334 |
+
model_name: str,
|
335 |
+
inputs: list,
|
336 |
+
outputs: list,
|
337 |
+
request_id: str,
|
338 |
+
user_data: UserData,
|
339 |
+
chunk_overlap_duration: float,
|
340 |
+
save_sample_rate: int,
|
341 |
+
audio_save_path: str,
|
342 |
+
):
|
343 |
+
"""Helper function to run the blocking sync streaming call."""
|
344 |
+
start_time_total = time.time()
|
345 |
+
user_data.record_start_time() # Record start time for first chunk latency calculation
|
346 |
+
|
347 |
+
# Establish stream
|
348 |
+
sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
|
349 |
+
|
350 |
+
# Send request
|
351 |
+
sync_triton_client.async_stream_infer(
|
352 |
+
model_name,
|
353 |
+
inputs,
|
354 |
+
request_id=request_id,
|
355 |
+
outputs=outputs,
|
356 |
+
enable_empty_final_response=True,
|
357 |
+
)
|
358 |
+
|
359 |
+
# Process results
|
360 |
+
audios = []
|
361 |
+
while True:
|
362 |
+
try:
|
363 |
+
result = user_data._completed_requests.get() # Add timeout
|
364 |
+
if isinstance(result, InferenceServerException):
|
365 |
+
print(f"Received InferenceServerException: {result}")
|
366 |
+
sync_triton_client.stop_stream()
|
367 |
+
return None, None, None # Indicate error
|
368 |
+
# Get response metadata
|
369 |
+
response = result.get_response()
|
370 |
+
final = response.parameters["triton_final_response"].bool_param
|
371 |
+
if final is True:
|
372 |
+
break
|
373 |
+
|
374 |
+
audio_chunk = result.as_numpy("waveform").reshape(-1)
|
375 |
+
if audio_chunk.size > 0: # Only append non-empty chunks
|
376 |
+
audios.append(audio_chunk)
|
377 |
+
else:
|
378 |
+
print("Warning: received empty audio chunk.")
|
379 |
+
|
380 |
+
except queue.Empty:
|
381 |
+
print(f"Timeout waiting for response for request id {request_id}")
|
382 |
+
sync_triton_client.stop_stream()
|
383 |
+
return None, None, None # Indicate error
|
384 |
+
|
385 |
+
sync_triton_client.stop_stream()
|
386 |
+
end_time_total = time.time()
|
387 |
+
total_request_latency = end_time_total - start_time_total
|
388 |
+
first_chunk_latency = user_data.get_first_chunk_latency()
|
389 |
+
|
390 |
+
# Reconstruct audio using cross-fade (from client_grpc_streaming.py)
|
391 |
+
actual_duration = 0
|
392 |
+
if audios:
|
393 |
+
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
|
394 |
+
fade_out = np.linspace(1, 0, cross_fade_samples)
|
395 |
+
fade_in = np.linspace(0, 1, cross_fade_samples)
|
396 |
+
reconstructed_audio = None
|
397 |
+
|
398 |
+
# Simplified reconstruction based on client_grpc_streaming.py
|
399 |
+
if not audios:
|
400 |
+
print("Warning: No audio chunks received.")
|
401 |
+
reconstructed_audio = np.array([], dtype=np.float32) # Empty array
|
402 |
+
elif len(audios) == 1:
|
403 |
+
reconstructed_audio = audios[0]
|
404 |
+
else:
|
405 |
+
reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
|
406 |
+
for i in range(1, len(audios)):
|
407 |
+
# Cross-fade section
|
408 |
+
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
|
409 |
+
audios[i - 1][-cross_fade_samples:] * fade_out)
|
410 |
+
# Middle section of the current chunk
|
411 |
+
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
|
412 |
+
# Concatenate
|
413 |
+
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
|
414 |
+
# Add the last part of the final chunk
|
415 |
+
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
|
416 |
+
|
417 |
+
if reconstructed_audio is not None and reconstructed_audio.size > 0:
|
418 |
+
actual_duration = len(reconstructed_audio) / save_sample_rate
|
419 |
+
# Save reconstructed audio
|
420 |
+
os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
|
421 |
+
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
|
422 |
+
else:
|
423 |
+
print("Warning: No audio chunks received or reconstructed.")
|
424 |
+
actual_duration = 0 # Set duration to 0 if no audio
|
425 |
+
|
426 |
+
else:
|
427 |
+
print("Warning: No audio chunks received.")
|
428 |
+
actual_duration = 0
|
429 |
+
|
430 |
+
return total_request_latency, first_chunk_latency, actual_duration
|
431 |
+
|
432 |
+
|
433 |
+
async def send_streaming(
|
434 |
+
manifest_item_list: list,
|
435 |
+
name: str,
|
436 |
+
server_url: str, # Changed from sync_triton_client
|
437 |
+
protocol_client: types.ModuleType,
|
438 |
+
log_interval: int,
|
439 |
+
model_name: str,
|
440 |
+
audio_save_dir: str = "./",
|
441 |
+
save_sample_rate: int = 16000,
|
442 |
+
chunk_overlap_duration: float = 0.1,
|
443 |
+
padding_duration: int = None,
|
444 |
+
):
|
445 |
+
total_duration = 0.0
|
446 |
+
latency_data = []
|
447 |
+
task_id = int(name[5:])
|
448 |
+
sync_triton_client = None # Initialize client variable
|
449 |
+
|
450 |
+
try: # Wrap in try...finally to ensure client closing
|
451 |
+
print(f"{name}: Initializing sync client for streaming...")
|
452 |
+
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
|
453 |
+
|
454 |
+
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
|
455 |
+
for i, item in enumerate(manifest_item_list):
|
456 |
+
if i % log_interval == 0:
|
457 |
+
print(f"{name}: Processing item {i}/{len(manifest_item_list)}")
|
458 |
+
|
459 |
+
try:
|
460 |
+
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
461 |
+
reference_text, target_text = item["reference_text"], item["target_text"]
|
462 |
+
|
463 |
+
inputs, outputs = prepare_request_input_output(
|
464 |
+
protocol_client,
|
465 |
+
waveform,
|
466 |
+
reference_text,
|
467 |
+
target_text,
|
468 |
+
sample_rate,
|
469 |
+
padding_duration=padding_duration
|
470 |
+
)
|
471 |
+
request_id = str(uuid.uuid4())
|
472 |
+
user_data = UserData()
|
473 |
+
|
474 |
+
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
475 |
+
|
476 |
+
total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
|
477 |
+
run_sync_streaming_inference,
|
478 |
+
sync_triton_client,
|
479 |
+
model_name,
|
480 |
+
inputs,
|
481 |
+
outputs,
|
482 |
+
request_id,
|
483 |
+
user_data,
|
484 |
+
chunk_overlap_duration,
|
485 |
+
save_sample_rate,
|
486 |
+
audio_save_path
|
487 |
+
)
|
488 |
+
|
489 |
+
if total_request_latency is not None:
|
490 |
+
print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
|
491 |
+
latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
|
492 |
+
total_duration += actual_duration
|
493 |
+
else:
|
494 |
+
print(f"{name}: Item {i} failed.")
|
495 |
+
|
496 |
+
|
497 |
+
except FileNotFoundError:
|
498 |
+
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
|
499 |
+
except Exception as e:
|
500 |
+
print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
|
501 |
+
import traceback
|
502 |
+
traceback.print_exc()
|
503 |
+
|
504 |
+
|
505 |
+
finally: # Ensure client is closed
|
506 |
+
if sync_triton_client:
|
507 |
+
try:
|
508 |
+
print(f"{name}: Closing sync client...")
|
509 |
+
sync_triton_client.close()
|
510 |
+
except Exception as e:
|
511 |
+
print(f"{name}: Error closing sync client: {e}")
|
512 |
+
|
513 |
+
|
514 |
+
print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
|
515 |
+
return total_duration, latency_data
|
516 |
+
|
517 |
+
async def send(
|
518 |
+
manifest_item_list: list,
|
519 |
+
name: str,
|
520 |
+
triton_client: tritonclient.grpc.aio.InferenceServerClient,
|
521 |
+
protocol_client: types.ModuleType,
|
522 |
+
log_interval: int,
|
523 |
+
model_name: str,
|
524 |
+
padding_duration: int = None,
|
525 |
+
audio_save_dir: str = "./",
|
526 |
+
save_sample_rate: int = 16000,
|
527 |
+
):
|
528 |
+
total_duration = 0.0
|
529 |
+
latency_data = []
|
530 |
+
task_id = int(name[5:])
|
531 |
+
|
532 |
+
print(f"manifest_item_list: {manifest_item_list}")
|
533 |
+
for i, item in enumerate(manifest_item_list):
|
534 |
+
if i % log_interval == 0:
|
535 |
+
print(f"{name}: {i}/{len(manifest_item_list)}")
|
536 |
+
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
537 |
+
reference_text, target_text = item["reference_text"], item["target_text"]
|
538 |
+
|
539 |
+
inputs, outputs = prepare_request_input_output(
|
540 |
+
protocol_client,
|
541 |
+
waveform,
|
542 |
+
reference_text,
|
543 |
+
target_text,
|
544 |
+
sample_rate,
|
545 |
+
padding_duration=padding_duration
|
546 |
+
)
|
547 |
+
sequence_id = 100000000 + i + task_id * 10
|
548 |
+
start = time.time()
|
549 |
+
response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
|
550 |
+
|
551 |
+
audio = response.as_numpy("waveform").reshape(-1)
|
552 |
+
actual_duration = len(audio) / save_sample_rate
|
553 |
+
|
554 |
+
end = time.time() - start
|
555 |
+
|
556 |
+
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
557 |
+
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
|
558 |
+
|
559 |
+
latency_data.append((end, actual_duration))
|
560 |
+
total_duration += actual_duration
|
561 |
+
|
562 |
+
return total_duration, latency_data
|
563 |
+
|
564 |
+
|
565 |
+
def load_manifests(manifest_path):
|
566 |
+
with open(manifest_path, "r") as f:
|
567 |
+
manifest_list = []
|
568 |
+
for line in f:
|
569 |
+
assert len(line.strip().split("|")) == 4
|
570 |
+
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
571 |
+
utt = Path(utt).stem
|
572 |
+
# gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
|
573 |
+
if not os.path.isabs(prompt_wav):
|
574 |
+
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
|
575 |
+
manifest_list.append(
|
576 |
+
{
|
577 |
+
"audio_filepath": prompt_wav,
|
578 |
+
"reference_text": prompt_text,
|
579 |
+
"target_text": gt_text,
|
580 |
+
"target_audio_path": utt,
|
581 |
+
}
|
582 |
+
)
|
583 |
+
return manifest_list
|
584 |
+
|
585 |
+
|
586 |
+
def split_data(data, k):
|
587 |
+
n = len(data)
|
588 |
+
if n < k:
|
589 |
+
print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
|
590 |
+
k = n
|
591 |
+
|
592 |
+
quotient = n // k
|
593 |
+
remainder = n % k
|
594 |
+
|
595 |
+
result = []
|
596 |
+
start = 0
|
597 |
+
for i in range(k):
|
598 |
+
if i < remainder:
|
599 |
+
end = start + quotient + 1
|
600 |
+
else:
|
601 |
+
end = start + quotient
|
602 |
+
|
603 |
+
result.append(data[start:end])
|
604 |
+
start = end
|
605 |
+
|
606 |
+
return result
|
607 |
+
|
608 |
+
async def main():
|
609 |
+
args = get_args()
|
610 |
+
url = f"{args.server_addr}:{args.server_port}"
|
611 |
+
|
612 |
+
# --- Client Initialization based on mode ---
|
613 |
+
triton_client = None
|
614 |
+
protocol_client = None
|
615 |
+
if args.mode == "offline":
|
616 |
+
print("Initializing gRPC client for offline mode...")
|
617 |
+
# Use the async client for offline tasks
|
618 |
+
triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
619 |
+
protocol_client = grpcclient_aio
|
620 |
+
elif args.mode == "streaming":
|
621 |
+
print("Initializing gRPC client for streaming mode...")
|
622 |
+
# Use the sync client for streaming tasks, handled via asyncio.to_thread
|
623 |
+
# We will create one sync client instance PER TASK inside send_streaming.
|
624 |
+
# triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
|
625 |
+
protocol_client = grpcclient_sync # protocol client for input prep
|
626 |
+
else:
|
627 |
+
raise ValueError(f"Invalid mode: {args.mode}")
|
628 |
+
# --- End Client Initialization ---
|
629 |
+
|
630 |
+
if args.reference_audio:
|
631 |
+
args.num_tasks = 1
|
632 |
+
args.log_interval = 1
|
633 |
+
manifest_item_list = [
|
634 |
+
{
|
635 |
+
"reference_text": args.reference_text,
|
636 |
+
"target_text": args.target_text,
|
637 |
+
"audio_filepath": args.reference_audio,
|
638 |
+
"target_audio_path": "test",
|
639 |
+
}
|
640 |
+
]
|
641 |
+
elif args.huggingface_dataset:
|
642 |
+
import datasets
|
643 |
+
|
644 |
+
dataset = datasets.load_dataset(
|
645 |
+
args.huggingface_dataset,
|
646 |
+
split=args.split_name,
|
647 |
+
trust_remote_code=True,
|
648 |
+
)
|
649 |
+
manifest_item_list = []
|
650 |
+
for i in range(len(dataset)):
|
651 |
+
manifest_item_list.append(
|
652 |
+
{
|
653 |
+
"audio_filepath": dataset[i]["prompt_audio"],
|
654 |
+
"reference_text": dataset[i]["prompt_text"],
|
655 |
+
"target_audio_path": dataset[i]["id"],
|
656 |
+
"target_text": dataset[i]["target_text"],
|
657 |
+
}
|
658 |
+
)
|
659 |
+
else:
|
660 |
+
manifest_item_list = load_manifests(args.manifest_path)
|
661 |
+
|
662 |
+
num_tasks = min(args.num_tasks, len(manifest_item_list))
|
663 |
+
manifest_item_list = split_data(manifest_item_list, num_tasks)
|
664 |
+
|
665 |
+
os.makedirs(args.log_dir, exist_ok=True)
|
666 |
+
tasks = []
|
667 |
+
start_time = time.time()
|
668 |
+
for i in range(num_tasks):
|
669 |
+
# --- Task Creation based on mode ---
|
670 |
+
if args.mode == "offline":
|
671 |
+
task = asyncio.create_task(
|
672 |
+
send(
|
673 |
+
manifest_item_list[i],
|
674 |
+
name=f"task-{i}",
|
675 |
+
triton_client=triton_client,
|
676 |
+
protocol_client=protocol_client,
|
677 |
+
log_interval=args.log_interval,
|
678 |
+
model_name=args.model_name,
|
679 |
+
audio_save_dir=args.log_dir,
|
680 |
+
padding_duration=1,
|
681 |
+
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
|
682 |
+
)
|
683 |
+
)
|
684 |
+
elif args.mode == "streaming":
|
685 |
+
task = asyncio.create_task(
|
686 |
+
send_streaming(
|
687 |
+
manifest_item_list[i],
|
688 |
+
name=f"task-{i}",
|
689 |
+
server_url=url, # Pass URL instead of client
|
690 |
+
protocol_client=protocol_client,
|
691 |
+
log_interval=args.log_interval,
|
692 |
+
model_name=args.model_name,
|
693 |
+
audio_save_dir=args.log_dir,
|
694 |
+
padding_duration=10,
|
695 |
+
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
|
696 |
+
chunk_overlap_duration=args.chunk_overlap_duration,
|
697 |
+
)
|
698 |
+
)
|
699 |
+
# --- End Task Creation ---
|
700 |
+
tasks.append(task)
|
701 |
+
|
702 |
+
ans_list = await asyncio.gather(*tasks)
|
703 |
+
|
704 |
+
end_time = time.time()
|
705 |
+
elapsed = end_time - start_time
|
706 |
+
|
707 |
+
total_duration = 0.0
|
708 |
+
latency_data = []
|
709 |
+
for ans in ans_list:
|
710 |
+
if ans:
|
711 |
+
total_duration += ans[0]
|
712 |
+
latency_data.extend(ans[1]) # Use extend for list of lists
|
713 |
+
else:
|
714 |
+
print("Warning: A task returned None, possibly due to an error.")
|
715 |
+
|
716 |
+
|
717 |
+
if total_duration == 0:
|
718 |
+
print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
|
719 |
+
rtf = float('inf')
|
720 |
+
else:
|
721 |
+
rtf = elapsed / total_duration
|
722 |
+
|
723 |
+
s = f"Mode: {args.mode}\n"
|
724 |
+
s += f"RTF: {rtf:.4f}\n"
|
725 |
+
s += f"total_duration: {total_duration:.3f} seconds\n"
|
726 |
+
s += f"({total_duration / 3600:.2f} hours)\n"
|
727 |
+
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
|
728 |
+
|
729 |
+
# --- Statistics Reporting based on mode ---
|
730 |
+
if latency_data:
|
731 |
+
if args.mode == "offline":
|
732 |
+
# Original offline latency calculation
|
733 |
+
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
|
734 |
+
if latency_list:
|
735 |
+
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
|
736 |
+
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
|
737 |
+
s += f"latency_variance: {latency_variance:.2f}\n"
|
738 |
+
s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
|
739 |
+
s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
|
740 |
+
s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
|
741 |
+
s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
|
742 |
+
s += f"average_latency_ms: {latency_ms:.2f}\n"
|
743 |
+
else:
|
744 |
+
s += "No latency data collected for offline mode.\n"
|
745 |
+
|
746 |
+
elif args.mode == "streaming":
|
747 |
+
# Calculate stats for total request latency and first chunk latency
|
748 |
+
total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
|
749 |
+
first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]
|
750 |
+
|
751 |
+
s += "\n--- Total Request Latency ---\n"
|
752 |
+
if total_latency_list:
|
753 |
+
avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
|
754 |
+
variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
|
755 |
+
s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
|
756 |
+
s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
|
757 |
+
s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
|
758 |
+
s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
|
759 |
+
s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
|
760 |
+
s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
|
761 |
+
else:
|
762 |
+
s += "No total request latency data collected.\n"
|
763 |
+
|
764 |
+
s += "\n--- First Chunk Latency ---\n"
|
765 |
+
if first_chunk_latency_list:
|
766 |
+
avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
|
767 |
+
variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
|
768 |
+
s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
|
769 |
+
s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
|
770 |
+
s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
|
771 |
+
s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
|
772 |
+
s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
|
773 |
+
s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
|
774 |
+
else:
|
775 |
+
s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
|
776 |
+
else:
|
777 |
+
s += "No latency data collected.\n"
|
778 |
+
# --- End Statistics Reporting ---
|
779 |
+
|
780 |
+
print(s)
|
781 |
+
if args.manifest_path:
|
782 |
+
name = Path(args.manifest_path).stem
|
783 |
+
elif args.split_name:
|
784 |
+
name = args.split_name
|
785 |
+
elif args.reference_audio:
|
786 |
+
name = Path(args.reference_audio).stem
|
787 |
+
else:
|
788 |
+
name = "results" # Default name if no manifest/split/audio provided
|
789 |
+
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
|
790 |
+
f.write(s)
|
791 |
+
|
792 |
+
# --- Statistics Fetching using temporary Async Client ---
|
793 |
+
# Use a separate async client for fetching stats regardless of mode
|
794 |
+
stats_client = None
|
795 |
+
try:
|
796 |
+
print("Initializing temporary async client for fetching stats...")
|
797 |
+
stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
798 |
+
print("Fetching inference statistics...")
|
799 |
+
# Fetching for all models, filtering might be needed depending on server setup
|
800 |
+
stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
|
801 |
+
print("Fetching model config...")
|
802 |
+
metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
|
803 |
+
|
804 |
+
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
|
805 |
+
|
806 |
+
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
|
807 |
+
json.dump(metadata, f, indent=4)
|
808 |
+
|
809 |
+
except Exception as e:
|
810 |
+
print(f"Could not retrieve statistics or config: {e}")
|
811 |
+
finally:
|
812 |
+
if stats_client:
|
813 |
+
try:
|
814 |
+
print("Closing temporary async stats client...")
|
815 |
+
await stats_client.close()
|
816 |
+
except Exception as e:
|
817 |
+
print(f"Error closing async stats client: {e}")
|
818 |
+
# --- End Statistics Fetching ---
|
819 |
+
|
820 |
+
|
821 |
+
if __name__ == "__main__":
|
822 |
+
# asyncio.run(main()) # Use TaskGroup for better exception handling if needed
|
823 |
+
async def run_main():
|
824 |
+
try:
|
825 |
+
await main()
|
826 |
+
except Exception as e:
|
827 |
+
print(f"An error occurred in main: {e}")
|
828 |
+
import traceback
|
829 |
+
traceback.print_exc()
|
830 |
+
|
831 |
+
asyncio.run(run_main())
|
runtime/triton_trtllm/client_http.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
import requests
|
27 |
+
import soundfile as sf
|
28 |
+
import json
|
29 |
+
import numpy as np
|
30 |
+
import argparse
|
31 |
+
|
32 |
+
def get_args():
|
33 |
+
parser = argparse.ArgumentParser(
|
34 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
35 |
+
)
|
36 |
+
|
37 |
+
parser.add_argument(
|
38 |
+
"--server-url",
|
39 |
+
type=str,
|
40 |
+
default="localhost:8000",
|
41 |
+
help="Address of the server",
|
42 |
+
)
|
43 |
+
|
44 |
+
parser.add_argument(
|
45 |
+
"--reference-audio",
|
46 |
+
type=str,
|
47 |
+
default="../../example/prompt_audio.wav",
|
48 |
+
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
|
49 |
+
)
|
50 |
+
|
51 |
+
parser.add_argument(
|
52 |
+
"--reference-text",
|
53 |
+
type=str,
|
54 |
+
default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。",
|
55 |
+
help="",
|
56 |
+
)
|
57 |
+
|
58 |
+
parser.add_argument(
|
59 |
+
"--target-text",
|
60 |
+
type=str,
|
61 |
+
default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
|
62 |
+
help="",
|
63 |
+
)
|
64 |
+
|
65 |
+
parser.add_argument(
|
66 |
+
"--model-name",
|
67 |
+
type=str,
|
68 |
+
default="spark_tts",
|
69 |
+
choices=[
|
70 |
+
"f5_tts", "spark_tts"
|
71 |
+
],
|
72 |
+
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
|
73 |
+
)
|
74 |
+
|
75 |
+
parser.add_argument(
|
76 |
+
"--output-audio",
|
77 |
+
type=str,
|
78 |
+
default="output.wav",
|
79 |
+
help="Path to save the output audio",
|
80 |
+
)
|
81 |
+
return parser.parse_args()
|
82 |
+
|
83 |
+
def prepare_request(
|
84 |
+
waveform,
|
85 |
+
reference_text,
|
86 |
+
target_text,
|
87 |
+
sample_rate=16000,
|
88 |
+
padding_duration: int = None,
|
89 |
+
audio_save_dir: str = "./",
|
90 |
+
):
|
91 |
+
assert len(waveform.shape) == 1, "waveform should be 1D"
|
92 |
+
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
93 |
+
if padding_duration:
|
94 |
+
# padding to nearset 10 seconds
|
95 |
+
samples = np.zeros(
|
96 |
+
(
|
97 |
+
1,
|
98 |
+
padding_duration
|
99 |
+
* sample_rate
|
100 |
+
* ((int(duration) // padding_duration) + 1),
|
101 |
+
),
|
102 |
+
dtype=np.float32,
|
103 |
+
)
|
104 |
+
|
105 |
+
samples[0, : len(waveform)] = waveform
|
106 |
+
else:
|
107 |
+
samples = waveform
|
108 |
+
|
109 |
+
samples = samples.reshape(1, -1).astype(np.float32)
|
110 |
+
|
111 |
+
data = {
|
112 |
+
"inputs":[
|
113 |
+
{
|
114 |
+
"name": "reference_wav",
|
115 |
+
"shape": samples.shape,
|
116 |
+
"datatype": "FP32",
|
117 |
+
"data": samples.tolist()
|
118 |
+
},
|
119 |
+
{
|
120 |
+
"name": "reference_wav_len",
|
121 |
+
"shape": lengths.shape,
|
122 |
+
"datatype": "INT32",
|
123 |
+
"data": lengths.tolist(),
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"name": "reference_text",
|
127 |
+
"shape": [1, 1],
|
128 |
+
"datatype": "BYTES",
|
129 |
+
"data": [reference_text]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"name": "target_text",
|
133 |
+
"shape": [1, 1],
|
134 |
+
"datatype": "BYTES",
|
135 |
+
"data": [target_text]
|
136 |
+
}
|
137 |
+
]
|
138 |
+
}
|
139 |
+
|
140 |
+
return data
|
141 |
+
|
142 |
+
if __name__ == "__main__":
|
143 |
+
args = get_args()
|
144 |
+
server_url = args.server_url
|
145 |
+
if not server_url.startswith(("http://", "https://")):
|
146 |
+
server_url = f"http://{server_url}"
|
147 |
+
|
148 |
+
url = f"{server_url}/v2/models/{args.model_name}/infer"
|
149 |
+
waveform, sr = sf.read(args.reference_audio)
|
150 |
+
assert sr == 16000, "sample rate hardcoded in server"
|
151 |
+
|
152 |
+
samples = np.array(waveform, dtype=np.float32)
|
153 |
+
data = prepare_request(samples, args.reference_text, args.target_text)
|
154 |
+
|
155 |
+
rsp = requests.post(
|
156 |
+
url,
|
157 |
+
headers={"Content-Type": "application/json"},
|
158 |
+
json=data,
|
159 |
+
verify=False,
|
160 |
+
params={"request_id": '0'}
|
161 |
+
)
|
162 |
+
result = rsp.json()
|
163 |
+
audio = result["outputs"][0]["data"]
|
164 |
+
audio = np.array(audio, dtype=np.float32)
|
165 |
+
sf.write(args.output_audio, audio, 16000, "PCM_16")
|
runtime/triton_trtllm/docker-compose.yml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
services:
|
2 |
+
tts:
|
3 |
+
image: soar97/triton-spark-tts:25.02
|
4 |
+
shm_size: '1gb'
|
5 |
+
ports:
|
6 |
+
- "8000:8000"
|
7 |
+
- "8001:8001"
|
8 |
+
- "8002:8002"
|
9 |
+
environment:
|
10 |
+
- PYTHONIOENCODING=utf-8
|
11 |
+
- MODEL_ID=${MODEL_ID}
|
12 |
+
deploy:
|
13 |
+
resources:
|
14 |
+
reservations:
|
15 |
+
devices:
|
16 |
+
- driver: nvidia
|
17 |
+
device_ids: ['0']
|
18 |
+
capabilities: [gpu]
|
19 |
+
command: >
|
20 |
+
/bin/bash -c "rm -rf Spark-TTS && git clone https://github.com/SparkAudio/Spark-TTS.git && cd Spark-TTS/runtime/triton_trtllm && bash run.sh 0 3"
|
runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
import json
|
27 |
+
import torch
|
28 |
+
from torch.utils.dlpack import to_dlpack
|
29 |
+
|
30 |
+
import triton_python_backend_utils as pb_utils
|
31 |
+
|
32 |
+
import os
|
33 |
+
import numpy as np
|
34 |
+
|
35 |
+
from sparktts.models.audio_tokenizer import BiCodecTokenizer
|
36 |
+
|
37 |
+
class TritonPythonModel:
|
38 |
+
"""Triton Python model for audio tokenization.
|
39 |
+
|
40 |
+
This model takes reference audio input and extracts semantic and global tokens
|
41 |
+
using BiCodec tokenizer.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def initialize(self, args):
|
45 |
+
"""Initialize the model.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
args: Dictionary containing model configuration
|
49 |
+
"""
|
50 |
+
# Parse model parameters
|
51 |
+
parameters = json.loads(args['model_config'])['parameters']
|
52 |
+
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
53 |
+
|
54 |
+
# Initialize tokenizer
|
55 |
+
self.device = torch.device("cuda")
|
56 |
+
self.audio_tokenizer = BiCodecTokenizer(model_params["model_dir"],
|
57 |
+
device=self.device)
|
58 |
+
|
59 |
+
def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
|
60 |
+
"""Extract reference audio clip for speaker embedding.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
wav: Input waveform array
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
Reference clip of fixed duration
|
67 |
+
"""
|
68 |
+
SAMPLE_RATE = 16000
|
69 |
+
REF_SEGMENT_DURATION = 6 # seconds
|
70 |
+
LATENT_HOP_LENGTH = 320
|
71 |
+
|
72 |
+
ref_segment_length = (
|
73 |
+
int(SAMPLE_RATE * REF_SEGMENT_DURATION)
|
74 |
+
// LATENT_HOP_LENGTH
|
75 |
+
* LATENT_HOP_LENGTH
|
76 |
+
)
|
77 |
+
wav_length = len(wav)
|
78 |
+
|
79 |
+
if ref_segment_length > wav_length:
|
80 |
+
# Repeat and truncate if input is too short
|
81 |
+
repeat_times = ref_segment_length // wav_length + 1
|
82 |
+
wav = np.tile(wav, repeat_times)
|
83 |
+
|
84 |
+
return wav[:ref_segment_length]
|
85 |
+
|
86 |
+
def execute(self, requests):
|
87 |
+
"""Execute inference on the batched requests.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
requests: List of inference requests
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
List of inference responses containing tokenized outputs
|
94 |
+
"""
|
95 |
+
reference_wav_list = []
|
96 |
+
reference_wav_ref_clip_list = []
|
97 |
+
|
98 |
+
# Process each request in batch
|
99 |
+
for request in requests:
|
100 |
+
# Extract input tensors
|
101 |
+
wav_array = pb_utils.get_input_tensor_by_name(
|
102 |
+
request, "reference_wav").as_numpy()
|
103 |
+
wav_len = pb_utils.get_input_tensor_by_name(
|
104 |
+
request, "reference_wav_len").as_numpy().item()
|
105 |
+
|
106 |
+
# Prepare inputs
|
107 |
+
wav = wav_array[:, :wav_len].squeeze(0)
|
108 |
+
reference_wav_list.append(wav)
|
109 |
+
|
110 |
+
wav_ref_clip = self.get_ref_clip(wav)
|
111 |
+
reference_wav_ref_clip_list.append(torch.from_numpy(wav_ref_clip))
|
112 |
+
|
113 |
+
# Batch process through tokenizer
|
114 |
+
ref_wav_clip_tensor = torch.stack(reference_wav_ref_clip_list, dim=0)
|
115 |
+
wav2vec2_features = self.audio_tokenizer.extract_wav2vec2_features(
|
116 |
+
reference_wav_list)
|
117 |
+
|
118 |
+
audio_tokenizer_input = {
|
119 |
+
"ref_wav": ref_wav_clip_tensor.to(self.device),
|
120 |
+
"feat": wav2vec2_features.to(self.device),
|
121 |
+
}
|
122 |
+
semantic_tokens, global_tokens = self.audio_tokenizer.model.tokenize(
|
123 |
+
audio_tokenizer_input)
|
124 |
+
|
125 |
+
# Prepare responses
|
126 |
+
responses = []
|
127 |
+
for i in range(len(requests)):
|
128 |
+
global_tokens_tensor = pb_utils.Tensor.from_dlpack(
|
129 |
+
"global_tokens", to_dlpack(global_tokens[i]))
|
130 |
+
semantic_tokens_tensor = pb_utils.Tensor.from_dlpack(
|
131 |
+
"semantic_tokens", to_dlpack(semantic_tokens[i]))
|
132 |
+
|
133 |
+
inference_response = pb_utils.InferenceResponse(
|
134 |
+
output_tensors=[global_tokens_tensor, semantic_tokens_tensor])
|
135 |
+
responses.append(inference_response)
|
136 |
+
|
137 |
+
return responses
|
runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
name: "audio_tokenizer"
|
16 |
+
backend: "python"
|
17 |
+
max_batch_size: ${triton_max_batch_size}
|
18 |
+
dynamic_batching {
|
19 |
+
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
20 |
+
}
|
21 |
+
parameters [
|
22 |
+
{
|
23 |
+
key: "model_dir",
|
24 |
+
value: {string_value:"${model_dir}"}
|
25 |
+
}
|
26 |
+
]
|
27 |
+
|
28 |
+
input [
|
29 |
+
{
|
30 |
+
name: "reference_wav"
|
31 |
+
data_type: TYPE_FP32
|
32 |
+
dims: [-1]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
name: "reference_wav_len"
|
36 |
+
data_type: TYPE_INT32
|
37 |
+
dims: [1]
|
38 |
+
}
|
39 |
+
]
|
40 |
+
output [
|
41 |
+
{
|
42 |
+
name: "global_tokens"
|
43 |
+
data_type: TYPE_INT32
|
44 |
+
dims: [-1]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
name: "semantic_tokens"
|
48 |
+
data_type: TYPE_INT32
|
49 |
+
dims: [-1]
|
50 |
+
}
|
51 |
+
]
|
52 |
+
|
53 |
+
instance_group [
|
54 |
+
{
|
55 |
+
count: 1
|
56 |
+
kind: KIND_CPU
|
57 |
+
}
|
58 |
+
]
|
runtime/triton_trtllm/model_repo/spark_tts/1/model.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
|
27 |
+
import json
|
28 |
+
import math
|
29 |
+
import os
|
30 |
+
import re
|
31 |
+
from typing import Dict, List, Tuple, Optional, Union
|
32 |
+
|
33 |
+
import numpy as np
|
34 |
+
import torch
|
35 |
+
from torch.utils.dlpack import from_dlpack, to_dlpack
|
36 |
+
import triton_python_backend_utils as pb_utils
|
37 |
+
from transformers import AutoTokenizer
|
38 |
+
|
39 |
+
from sparktts.utils.token_parser import TASK_TOKEN_MAP
|
40 |
+
|
41 |
+
def process_prompt(
|
42 |
+
text: str,
|
43 |
+
prompt_text: Optional[str] = None,
|
44 |
+
global_token_ids: torch.Tensor = None,
|
45 |
+
semantic_token_ids: torch.Tensor = None,
|
46 |
+
) -> Tuple[str, torch.Tensor]:
|
47 |
+
"""
|
48 |
+
Process input for voice cloning.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
text: The text input to be converted to speech.
|
52 |
+
prompt_text: Transcript of the prompt audio.
|
53 |
+
global_token_ids: Global token IDs extracted from reference audio.
|
54 |
+
semantic_token_ids: Semantic token IDs extracted from reference audio.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
Tuple containing the formatted input prompt and global token IDs.
|
58 |
+
"""
|
59 |
+
# Convert global tokens to string format
|
60 |
+
global_tokens = "".join(
|
61 |
+
[f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
|
62 |
+
)
|
63 |
+
|
64 |
+
|
65 |
+
# Prepare the input tokens for the model
|
66 |
+
if prompt_text is not None:
|
67 |
+
# Include semantic tokens when prompt text is provided
|
68 |
+
semantic_tokens = "".join(
|
69 |
+
[f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
|
70 |
+
)
|
71 |
+
|
72 |
+
inputs = [
|
73 |
+
TASK_TOKEN_MAP["tts"],
|
74 |
+
"<|start_content|>",
|
75 |
+
prompt_text,
|
76 |
+
text,
|
77 |
+
"<|end_content|>",
|
78 |
+
"<|start_global_token|>",
|
79 |
+
global_tokens,
|
80 |
+
"<|end_global_token|>",
|
81 |
+
"<|start_semantic_token|>",
|
82 |
+
semantic_tokens,
|
83 |
+
]
|
84 |
+
else:
|
85 |
+
# Without prompt text, exclude semantic tokens
|
86 |
+
inputs = [
|
87 |
+
TASK_TOKEN_MAP["tts"],
|
88 |
+
"<|start_content|>",
|
89 |
+
text,
|
90 |
+
"<|end_content|>",
|
91 |
+
"<|start_global_token|>",
|
92 |
+
global_tokens,
|
93 |
+
"<|end_global_token|>",
|
94 |
+
]
|
95 |
+
|
96 |
+
# Join all input components into a single string
|
97 |
+
inputs = "".join(inputs)
|
98 |
+
return inputs, global_token_ids
|
99 |
+
|
100 |
+
|
101 |
+
class TritonPythonModel:
|
102 |
+
"""Triton Python model for Spark TTS.
|
103 |
+
|
104 |
+
This model orchestrates the end-to-end TTS pipeline by coordinating
|
105 |
+
between audio tokenizer, LLM, and vocoder components.
|
106 |
+
"""
|
107 |
+
|
108 |
+
def initialize(self, args):
|
109 |
+
"""Initialize the model.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
args: Dictionary containing model configuration
|
113 |
+
"""
|
114 |
+
self.logger = pb_utils.Logger
|
115 |
+
# Parse model parameters
|
116 |
+
self.model_config = json.loads(args['model_config'])
|
117 |
+
parameters = self.model_config['parameters']
|
118 |
+
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
119 |
+
self.logger.log_info(f"model_params:{model_params}")
|
120 |
+
# streaming TTS parameters
|
121 |
+
assert (
|
122 |
+
float(model_params["audio_chunk_duration"]) >= 0.5
|
123 |
+
), f"audio_chunk_duration at least 0.5 seconds"
|
124 |
+
self.audio_chunk_duration = float(model_params["audio_chunk_duration"])
|
125 |
+
self.max_audio_chunk_duration = float(model_params["max_audio_chunk_duration"])
|
126 |
+
assert (
|
127 |
+
float(model_params["audio_chunk_size_scale_factor"]) >= 1.0
|
128 |
+
), "audio_chunk_size_scale_factor should be greater than 1, change it according to your actual rtf"
|
129 |
+
self.audio_chunk_size_scale_factor = float(model_params["audio_chunk_size_scale_factor"]) # scale speed
|
130 |
+
self.audio_chunk_overlap_duration = float(model_params["audio_chunk_overlap_duration"])
|
131 |
+
self.audio_tokenizer_frame_rate = int(model_params["audio_tokenizer_frame_rate"])
|
132 |
+
|
133 |
+
# Initialize tokenizer
|
134 |
+
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
|
135 |
+
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
|
136 |
+
self.device = torch.device("cuda")
|
137 |
+
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
|
138 |
+
|
139 |
+
def forward_llm(self, input_ids):
|
140 |
+
"""
|
141 |
+
Prepares the response from the language model based on the provided
|
142 |
+
inputs. Creates a `pb_utils.InferenceRequest` object with passed
|
143 |
+
`llm_request_inputs` to send to a decoupled TensorRTLLM model.
|
144 |
+
For each response from the language model:
|
145 |
+
- Checks for errors and raise an exception if any are found.
|
146 |
+
- Extracts the "output_ids" tensor from the response.
|
147 |
+
- Determines the finish reason based on the presence of the
|
148 |
+
end-of-sequence token or reaching the maximum length.
|
149 |
+
- Appends the generated token IDs to `output_ids`.
|
150 |
+
- If the finish reason is determined, decodes the output IDs to text
|
151 |
+
and prepares the final response.
|
152 |
+
|
153 |
+
The final response includes the generated text, finish reason,
|
154 |
+
completion tokens, prompt tokens, and total tokens.
|
155 |
+
|
156 |
+
Parameters
|
157 |
+
----------
|
158 |
+
- llm_request_inputs (dict): A dictionary containing the inputs for the language model.
|
159 |
+
|
160 |
+
Returns
|
161 |
+
-------
|
162 |
+
- pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
|
163 |
+
"""
|
164 |
+
# convert input_ids to numpy, with shape [1, sequence_length]
|
165 |
+
input_ids = input_ids.cpu().numpy()
|
166 |
+
max_tokens = 512
|
167 |
+
input_dict = {
|
168 |
+
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
|
169 |
+
"end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32),
|
170 |
+
"pad_id": np.array([[self.tokenizer.pad_token_id]], dtype=np.int32),
|
171 |
+
"streaming": np.array([[self.decoupled]], dtype=np.bool_),
|
172 |
+
"runtime_top_p": np.array([[0.95]], dtype=np.float32),
|
173 |
+
"runtime_top_k": np.array([[50]], dtype=np.int32),
|
174 |
+
"temperature": np.array([[0.8]], dtype=np.float32),
|
175 |
+
"input_ids": input_ids,
|
176 |
+
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
|
177 |
+
}
|
178 |
+
|
179 |
+
# Convert inputs to Triton tensors
|
180 |
+
input_tensor_list = [
|
181 |
+
pb_utils.Tensor(k, v) for k, v in input_dict.items()
|
182 |
+
]
|
183 |
+
|
184 |
+
# Create and execute inference request
|
185 |
+
llm_request = pb_utils.InferenceRequest(
|
186 |
+
model_name="tensorrt_llm",
|
187 |
+
requested_output_names=["output_ids", "sequence_length"],
|
188 |
+
inputs=input_tensor_list,
|
189 |
+
)
|
190 |
+
|
191 |
+
llm_responses = llm_request.exec(decoupled=self.decoupled)
|
192 |
+
if self.decoupled:
|
193 |
+
for llm_response in llm_responses:
|
194 |
+
if llm_response.has_error():
|
195 |
+
raise pb_utils.TritonModelException(llm_response.error().message())
|
196 |
+
|
197 |
+
# Extract and process output
|
198 |
+
output_ids = pb_utils.get_output_tensor_by_name(
|
199 |
+
llm_response, "output_ids").as_numpy()
|
200 |
+
seq_lens = pb_utils.get_output_tensor_by_name(
|
201 |
+
llm_response, "sequence_length").as_numpy()
|
202 |
+
|
203 |
+
# Get actual output IDs up to the sequence length
|
204 |
+
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
|
205 |
+
|
206 |
+
yield actual_output_ids
|
207 |
+
else:
|
208 |
+
llm_response = llm_responses
|
209 |
+
if llm_response.has_error():
|
210 |
+
raise pb_utils.TritonModelException(llm_response.error().message())
|
211 |
+
|
212 |
+
# Extract and process output
|
213 |
+
output_ids = pb_utils.get_output_tensor_by_name(
|
214 |
+
llm_response, "output_ids").as_numpy()
|
215 |
+
seq_lens = pb_utils.get_output_tensor_by_name(
|
216 |
+
llm_response, "sequence_length").as_numpy()
|
217 |
+
|
218 |
+
# Get actual output IDs up to the sequence length
|
219 |
+
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
|
220 |
+
|
221 |
+
yield actual_output_ids
|
222 |
+
|
223 |
+
def forward_audio_tokenizer(self, wav, wav_len):
|
224 |
+
"""Forward pass through the audio tokenizer component.
|
225 |
+
|
226 |
+
Args:
|
227 |
+
wav: Input waveform tensor
|
228 |
+
wav_len: Waveform length tensor
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
Tuple of global and semantic tokens
|
232 |
+
"""
|
233 |
+
inference_request = pb_utils.InferenceRequest(
|
234 |
+
model_name='audio_tokenizer',
|
235 |
+
requested_output_names=['global_tokens', 'semantic_tokens'],
|
236 |
+
inputs=[wav, wav_len]
|
237 |
+
)
|
238 |
+
|
239 |
+
inference_response = inference_request.exec()
|
240 |
+
if inference_response.has_error():
|
241 |
+
raise pb_utils.TritonModelException(inference_response.error().message())
|
242 |
+
|
243 |
+
# Extract and convert output tensors
|
244 |
+
global_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'global_tokens')
|
245 |
+
global_tokens = torch.utils.dlpack.from_dlpack(global_tokens.to_dlpack()).cpu()
|
246 |
+
|
247 |
+
semantic_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'semantic_tokens')
|
248 |
+
semantic_tokens = torch.utils.dlpack.from_dlpack(semantic_tokens.to_dlpack()).cpu()
|
249 |
+
|
250 |
+
return global_tokens, semantic_tokens
|
251 |
+
|
252 |
+
def forward_vocoder(self, global_token_ids: torch.Tensor, pred_semantic_ids: torch.Tensor) -> torch.Tensor:
|
253 |
+
"""Forward pass through the vocoder component.
|
254 |
+
|
255 |
+
Args:
|
256 |
+
global_token_ids: Global token IDs tensor
|
257 |
+
pred_semantic_ids: Predicted semantic token IDs tensor
|
258 |
+
|
259 |
+
Returns:
|
260 |
+
Generated waveform tensor
|
261 |
+
"""
|
262 |
+
# Convert tensors to Triton format
|
263 |
+
global_token_ids_tensor = pb_utils.Tensor.from_dlpack("global_tokens", to_dlpack(global_token_ids))
|
264 |
+
pred_semantic_ids_tensor = pb_utils.Tensor.from_dlpack("semantic_tokens", to_dlpack(pred_semantic_ids))
|
265 |
+
|
266 |
+
# Create and execute inference request
|
267 |
+
inference_request = pb_utils.InferenceRequest(
|
268 |
+
model_name='vocoder',
|
269 |
+
requested_output_names=['waveform'],
|
270 |
+
inputs=[global_token_ids_tensor, pred_semantic_ids_tensor]
|
271 |
+
)
|
272 |
+
|
273 |
+
inference_response = inference_request.exec()
|
274 |
+
if inference_response.has_error():
|
275 |
+
raise pb_utils.TritonModelException(inference_response.error().message())
|
276 |
+
|
277 |
+
# Extract and convert output waveform
|
278 |
+
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
|
279 |
+
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
|
280 |
+
|
281 |
+
return waveform
|
282 |
+
|
283 |
+
def token2wav(self, generated_token_ids, global_token_ids):
|
284 |
+
# Decode and extract semantic token IDs from generated text
|
285 |
+
predicted_text = self.tokenizer.batch_decode(
|
286 |
+
[generated_token_ids],
|
287 |
+
skip_special_tokens=True,
|
288 |
+
)[0]
|
289 |
+
pred_semantic_ids = (
|
290 |
+
torch.tensor(
|
291 |
+
[int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicted_text)]
|
292 |
+
)
|
293 |
+
.unsqueeze(0)
|
294 |
+
.to(torch.int32)
|
295 |
+
)
|
296 |
+
|
297 |
+
# Generate audio with vocoder
|
298 |
+
audio = self.forward_vocoder(
|
299 |
+
global_token_ids.to(self.device),
|
300 |
+
pred_semantic_ids.to(self.device),
|
301 |
+
)
|
302 |
+
|
303 |
+
return audio
|
304 |
+
|
305 |
+
def execute(self, requests):
|
306 |
+
"""Execute inference on the batched requests.
|
307 |
+
|
308 |
+
Args:
|
309 |
+
requests: List of inference requests
|
310 |
+
|
311 |
+
Returns:
|
312 |
+
List of inference responses containing generated audio
|
313 |
+
"""
|
314 |
+
responses = []
|
315 |
+
|
316 |
+
for request in requests:
|
317 |
+
# Extract input tensors
|
318 |
+
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
319 |
+
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
320 |
+
|
321 |
+
# Process reference audio through audio tokenizer
|
322 |
+
global_tokens, semantic_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
323 |
+
|
324 |
+
# Extract text inputs
|
325 |
+
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
326 |
+
reference_text = reference_text[0][0].decode('utf-8')
|
327 |
+
|
328 |
+
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
329 |
+
target_text = target_text[0][0].decode('utf-8')
|
330 |
+
|
331 |
+
# Prepare prompt for LLM
|
332 |
+
prompt, global_token_ids = process_prompt(
|
333 |
+
text=target_text,
|
334 |
+
prompt_text=reference_text,
|
335 |
+
global_token_ids=global_tokens,
|
336 |
+
semantic_token_ids=semantic_tokens,
|
337 |
+
)
|
338 |
+
|
339 |
+
|
340 |
+
# Tokenize prompt for LLM
|
341 |
+
model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
|
342 |
+
input_ids = model_inputs.input_ids.to(torch.int32)
|
343 |
+
|
344 |
+
# Generate semantic tokens with LLM
|
345 |
+
generated_ids_iter = self.forward_llm(input_ids)
|
346 |
+
|
347 |
+
if self.decoupled:
|
348 |
+
response_sender = request.get_response_sender()
|
349 |
+
request_id = request.request_id()
|
350 |
+
semantic_token_ids_arr = []
|
351 |
+
max_chunk_size = math.ceil(self.max_audio_chunk_duration * self.audio_tokenizer_frame_rate)
|
352 |
+
chunk_size = math.ceil(self.audio_chunk_duration * self.audio_tokenizer_frame_rate)
|
353 |
+
overlap_chunk_size = math.ceil(self.audio_chunk_overlap_duration * self.audio_tokenizer_frame_rate)
|
354 |
+
self.logger.log_info(
|
355 |
+
f"[{request_id}] init chunk_size: {chunk_size} max_chunk_size: {max_chunk_size}"
|
356 |
+
)
|
357 |
+
for generated_ids in generated_ids_iter:
|
358 |
+
if generated_ids is None or len(generated_ids) == 0:
|
359 |
+
break
|
360 |
+
|
361 |
+
semantic_token_ids_arr.append(generated_ids)
|
362 |
+
if len(semantic_token_ids_arr) >= chunk_size:
|
363 |
+
chunk = semantic_token_ids_arr[:chunk_size]
|
364 |
+
generated_semantic_token_ids = np.hstack(chunk)
|
365 |
+
# Process each chunk
|
366 |
+
sub_tts_speech = self.token2wav(generated_semantic_token_ids, global_token_ids)
|
367 |
+
# Prepare response to send
|
368 |
+
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
369 |
+
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
370 |
+
response_sender.send(inference_response)
|
371 |
+
|
372 |
+
semantic_token_ids_arr = semantic_token_ids_arr[chunk_size - overlap_chunk_size:]
|
373 |
+
# increase chunk size for better speech quality
|
374 |
+
chunk_size = min(max_chunk_size, int(chunk_size * self.audio_chunk_size_scale_factor))
|
375 |
+
self.logger.log_info(f"[{request_id}] increase chunk_size: {chunk_size}")
|
376 |
+
|
377 |
+
if len(semantic_token_ids_arr) > 0: # end to finalize
|
378 |
+
generated_semantic_token_ids = np.hstack(semantic_token_ids_arr)
|
379 |
+
# Process each chunk
|
380 |
+
sub_tts_speech = self.token2wav(generated_semantic_token_ids, global_token_ids)
|
381 |
+
# Prepare response to send
|
382 |
+
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
383 |
+
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
384 |
+
response_sender.send(inference_response)
|
385 |
+
self.logger.log_info(f"[{request_id}] last chunk len: {len(semantic_token_ids_arr)}")
|
386 |
+
else:
|
387 |
+
generated_ids = next(generated_ids_iter)
|
388 |
+
if generated_ids is None or len(generated_ids) == 0:
|
389 |
+
raise pb_utils.TritonModelException("Generated IDs is None or empty")
|
390 |
+
|
391 |
+
audio = self.token2wav(generated_ids, global_token_ids)
|
392 |
+
|
393 |
+
# Prepare response
|
394 |
+
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
|
395 |
+
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
396 |
+
responses.append(inference_response)
|
397 |
+
|
398 |
+
if self.decoupled:
|
399 |
+
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
400 |
+
self.logger.log_info(f"send tritonserver_response_complete_final to end")
|
401 |
+
|
402 |
+
if not self.decoupled:
|
403 |
+
return responses
|
404 |
+
|
runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
name: "spark_tts"
|
16 |
+
backend: "python"
|
17 |
+
max_batch_size: ${triton_max_batch_size}
|
18 |
+
dynamic_batching {
|
19 |
+
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
20 |
+
}
|
21 |
+
model_transaction_policy {
|
22 |
+
decoupled: ${decoupled_mode}
|
23 |
+
}
|
24 |
+
parameters [
|
25 |
+
{
|
26 |
+
key: "llm_tokenizer_dir",
|
27 |
+
value: {string_value:"${llm_tokenizer_dir}"}
|
28 |
+
},
|
29 |
+
{
|
30 |
+
key: "audio_chunk_duration",
|
31 |
+
value: {string_value:"${audio_chunk_duration}"}
|
32 |
+
},
|
33 |
+
{
|
34 |
+
key: "audio_chunk_size_scale_factor",
|
35 |
+
value: {string_value:"${audio_chunk_size_scale_factor}"}
|
36 |
+
},
|
37 |
+
{
|
38 |
+
key: "max_audio_chunk_duration",
|
39 |
+
value: {string_value:"${max_audio_chunk_duration}"}
|
40 |
+
},
|
41 |
+
{
|
42 |
+
key: "audio_chunk_overlap_duration",
|
43 |
+
value: {string_value:"${audio_chunk_overlap_duration}"}
|
44 |
+
},
|
45 |
+
{
|
46 |
+
key: "audio_tokenizer_frame_rate",
|
47 |
+
value: {string_value:"50"}
|
48 |
+
}
|
49 |
+
]
|
50 |
+
|
51 |
+
input [
|
52 |
+
{
|
53 |
+
name: "reference_wav"
|
54 |
+
data_type: TYPE_FP32
|
55 |
+
dims: [-1]
|
56 |
+
},
|
57 |
+
{
|
58 |
+
name: "reference_wav_len"
|
59 |
+
data_type: TYPE_INT32
|
60 |
+
dims: [1]
|
61 |
+
},
|
62 |
+
{
|
63 |
+
name: "reference_text"
|
64 |
+
data_type: TYPE_STRING
|
65 |
+
dims: [1]
|
66 |
+
},
|
67 |
+
{
|
68 |
+
name: "target_text"
|
69 |
+
data_type: TYPE_STRING
|
70 |
+
dims: [1]
|
71 |
+
}
|
72 |
+
]
|
73 |
+
output [
|
74 |
+
{
|
75 |
+
name: "waveform"
|
76 |
+
data_type: TYPE_FP32
|
77 |
+
dims: [ -1 ]
|
78 |
+
}
|
79 |
+
]
|
80 |
+
|
81 |
+
instance_group [
|
82 |
+
{
|
83 |
+
count: ${bls_instance_num}
|
84 |
+
kind: KIND_CPU
|
85 |
+
}
|
86 |
+
]
|
runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep
ADDED
File without changes
|
runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt
ADDED
@@ -0,0 +1,857 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
|
27 |
+
name: "tensorrt_llm"
|
28 |
+
backend: "${triton_backend}"
|
29 |
+
max_batch_size: ${triton_max_batch_size}
|
30 |
+
|
31 |
+
model_transaction_policy {
|
32 |
+
decoupled: ${decoupled_mode}
|
33 |
+
}
|
34 |
+
|
35 |
+
dynamic_batching {
|
36 |
+
preferred_batch_size: [ ${triton_max_batch_size} ]
|
37 |
+
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
38 |
+
default_queue_policy: { max_queue_size: ${max_queue_size} }
|
39 |
+
}
|
40 |
+
|
41 |
+
input [
|
42 |
+
{
|
43 |
+
name: "input_ids"
|
44 |
+
data_type: TYPE_INT32
|
45 |
+
dims: [ -1 ]
|
46 |
+
allow_ragged_batch: true
|
47 |
+
optional: true
|
48 |
+
},
|
49 |
+
{
|
50 |
+
name: "encoder_input_features"
|
51 |
+
data_type: ${encoder_input_features_data_type}
|
52 |
+
dims: [ -1, -1 ]
|
53 |
+
allow_ragged_batch: true
|
54 |
+
optional: true
|
55 |
+
},
|
56 |
+
{
|
57 |
+
name: "encoder_output_lengths"
|
58 |
+
data_type: TYPE_INT32
|
59 |
+
dims: [ 1 ]
|
60 |
+
reshape: { shape: [ ] }
|
61 |
+
optional: true
|
62 |
+
},
|
63 |
+
{
|
64 |
+
name: "input_lengths"
|
65 |
+
data_type: TYPE_INT32
|
66 |
+
dims: [ 1 ]
|
67 |
+
reshape: { shape: [ ] }
|
68 |
+
},
|
69 |
+
{
|
70 |
+
name: "request_output_len"
|
71 |
+
data_type: TYPE_INT32
|
72 |
+
dims: [ 1 ]
|
73 |
+
reshape: { shape: [ ] }
|
74 |
+
},
|
75 |
+
{
|
76 |
+
name: "num_return_sequences"
|
77 |
+
data_type: TYPE_INT32
|
78 |
+
dims: [ 1 ]
|
79 |
+
reshape: { shape: [ ] }
|
80 |
+
optional: true
|
81 |
+
},
|
82 |
+
{
|
83 |
+
name: "draft_input_ids"
|
84 |
+
data_type: TYPE_INT32
|
85 |
+
dims: [ -1 ]
|
86 |
+
optional: true
|
87 |
+
allow_ragged_batch: true
|
88 |
+
},
|
89 |
+
{
|
90 |
+
name: "decoder_input_ids"
|
91 |
+
data_type: TYPE_INT32
|
92 |
+
dims: [ -1 ]
|
93 |
+
optional: true
|
94 |
+
allow_ragged_batch: true
|
95 |
+
},
|
96 |
+
{
|
97 |
+
name: "decoder_input_lengths"
|
98 |
+
data_type: TYPE_INT32
|
99 |
+
dims: [ 1 ]
|
100 |
+
optional: true
|
101 |
+
reshape: { shape: [ ] }
|
102 |
+
},
|
103 |
+
{
|
104 |
+
name: "draft_logits"
|
105 |
+
data_type: ${logits_datatype}
|
106 |
+
dims: [ -1, -1 ]
|
107 |
+
optional: true
|
108 |
+
allow_ragged_batch: true
|
109 |
+
},
|
110 |
+
{
|
111 |
+
name: "draft_acceptance_threshold"
|
112 |
+
data_type: TYPE_FP32
|
113 |
+
dims: [ 1 ]
|
114 |
+
reshape: { shape: [ ] }
|
115 |
+
optional: true
|
116 |
+
},
|
117 |
+
{
|
118 |
+
name: "end_id"
|
119 |
+
data_type: TYPE_INT32
|
120 |
+
dims: [ 1 ]
|
121 |
+
reshape: { shape: [ ] }
|
122 |
+
optional: true
|
123 |
+
},
|
124 |
+
{
|
125 |
+
name: "pad_id"
|
126 |
+
data_type: TYPE_INT32
|
127 |
+
dims: [ 1 ]
|
128 |
+
reshape: { shape: [ ] }
|
129 |
+
optional: true
|
130 |
+
},
|
131 |
+
{
|
132 |
+
name: "stop_words_list"
|
133 |
+
data_type: TYPE_INT32
|
134 |
+
dims: [ 2, -1 ]
|
135 |
+
optional: true
|
136 |
+
allow_ragged_batch: true
|
137 |
+
},
|
138 |
+
{
|
139 |
+
name: "bad_words_list"
|
140 |
+
data_type: TYPE_INT32
|
141 |
+
dims: [ 2, -1 ]
|
142 |
+
optional: true
|
143 |
+
allow_ragged_batch: true
|
144 |
+
},
|
145 |
+
{
|
146 |
+
name: "embedding_bias"
|
147 |
+
data_type: TYPE_FP32
|
148 |
+
dims: [ -1 ]
|
149 |
+
optional: true
|
150 |
+
allow_ragged_batch: true
|
151 |
+
},
|
152 |
+
{
|
153 |
+
name: "beam_width"
|
154 |
+
data_type: TYPE_INT32
|
155 |
+
dims: [ 1 ]
|
156 |
+
reshape: { shape: [ ] }
|
157 |
+
optional: true
|
158 |
+
},
|
159 |
+
{
|
160 |
+
name: "temperature"
|
161 |
+
data_type: TYPE_FP32
|
162 |
+
dims: [ 1 ]
|
163 |
+
reshape: { shape: [ ] }
|
164 |
+
optional: true
|
165 |
+
},
|
166 |
+
{
|
167 |
+
name: "runtime_top_k"
|
168 |
+
data_type: TYPE_INT32
|
169 |
+
dims: [ 1 ]
|
170 |
+
reshape: { shape: [ ] }
|
171 |
+
optional: true
|
172 |
+
},
|
173 |
+
{
|
174 |
+
name: "runtime_top_p"
|
175 |
+
data_type: TYPE_FP32
|
176 |
+
dims: [ 1 ]
|
177 |
+
reshape: { shape: [ ] }
|
178 |
+
optional: true
|
179 |
+
},
|
180 |
+
{
|
181 |
+
name: "runtime_top_p_min"
|
182 |
+
data_type: TYPE_FP32
|
183 |
+
dims: [ 1 ]
|
184 |
+
reshape: { shape: [ ] }
|
185 |
+
optional: true
|
186 |
+
},
|
187 |
+
{
|
188 |
+
name: "runtime_top_p_decay"
|
189 |
+
data_type: TYPE_FP32
|
190 |
+
dims: [ 1 ]
|
191 |
+
reshape: { shape: [ ] }
|
192 |
+
optional: true
|
193 |
+
},
|
194 |
+
{
|
195 |
+
name: "runtime_top_p_reset_ids"
|
196 |
+
data_type: TYPE_INT32
|
197 |
+
dims: [ 1 ]
|
198 |
+
reshape: { shape: [ ] }
|
199 |
+
optional: true
|
200 |
+
},
|
201 |
+
{
|
202 |
+
name: "len_penalty"
|
203 |
+
data_type: TYPE_FP32
|
204 |
+
dims: [ 1 ]
|
205 |
+
reshape: { shape: [ ] }
|
206 |
+
optional: true
|
207 |
+
},
|
208 |
+
{
|
209 |
+
name: "early_stopping"
|
210 |
+
data_type: TYPE_BOOL
|
211 |
+
dims: [ 1 ]
|
212 |
+
reshape: { shape: [ ] }
|
213 |
+
optional: true
|
214 |
+
},
|
215 |
+
{
|
216 |
+
name: "repetition_penalty"
|
217 |
+
data_type: TYPE_FP32
|
218 |
+
dims: [ 1 ]
|
219 |
+
reshape: { shape: [ ] }
|
220 |
+
optional: true
|
221 |
+
},
|
222 |
+
{
|
223 |
+
name: "min_length"
|
224 |
+
data_type: TYPE_INT32
|
225 |
+
dims: [ 1 ]
|
226 |
+
reshape: { shape: [ ] }
|
227 |
+
optional: true
|
228 |
+
},
|
229 |
+
{
|
230 |
+
name: "beam_search_diversity_rate"
|
231 |
+
data_type: TYPE_FP32
|
232 |
+
dims: [ 1 ]
|
233 |
+
reshape: { shape: [ ] }
|
234 |
+
optional: true
|
235 |
+
},
|
236 |
+
{
|
237 |
+
name: "presence_penalty"
|
238 |
+
data_type: TYPE_FP32
|
239 |
+
dims: [ 1 ]
|
240 |
+
reshape: { shape: [ ] }
|
241 |
+
optional: true
|
242 |
+
},
|
243 |
+
{
|
244 |
+
name: "frequency_penalty"
|
245 |
+
data_type: TYPE_FP32
|
246 |
+
dims: [ 1 ]
|
247 |
+
reshape: { shape: [ ] }
|
248 |
+
optional: true
|
249 |
+
},
|
250 |
+
{
|
251 |
+
name: "random_seed"
|
252 |
+
data_type: TYPE_UINT64
|
253 |
+
dims: [ 1 ]
|
254 |
+
reshape: { shape: [ ] }
|
255 |
+
optional: true
|
256 |
+
},
|
257 |
+
{
|
258 |
+
name: "return_log_probs"
|
259 |
+
data_type: TYPE_BOOL
|
260 |
+
dims: [ 1 ]
|
261 |
+
reshape: { shape: [ ] }
|
262 |
+
optional: true
|
263 |
+
},
|
264 |
+
{
|
265 |
+
name: "return_context_logits"
|
266 |
+
data_type: TYPE_BOOL
|
267 |
+
dims: [ 1 ]
|
268 |
+
reshape: { shape: [ ] }
|
269 |
+
optional: true
|
270 |
+
},
|
271 |
+
{
|
272 |
+
name: "return_generation_logits"
|
273 |
+
data_type: TYPE_BOOL
|
274 |
+
dims: [ 1 ]
|
275 |
+
reshape: { shape: [ ] }
|
276 |
+
optional: true
|
277 |
+
},
|
278 |
+
{
|
279 |
+
name: "return_perf_metrics"
|
280 |
+
data_type: TYPE_BOOL
|
281 |
+
dims: [ 1 ]
|
282 |
+
reshape: { shape: [ ] }
|
283 |
+
optional: true
|
284 |
+
},
|
285 |
+
{
|
286 |
+
name: "exclude_input_in_output"
|
287 |
+
data_type: TYPE_BOOL
|
288 |
+
dims: [ 1 ]
|
289 |
+
reshape: { shape: [ ] }
|
290 |
+
optional: true
|
291 |
+
},
|
292 |
+
{
|
293 |
+
name: "stop"
|
294 |
+
data_type: TYPE_BOOL
|
295 |
+
dims: [ 1 ]
|
296 |
+
reshape: { shape: [ ] }
|
297 |
+
optional: true
|
298 |
+
},
|
299 |
+
{
|
300 |
+
name: "streaming"
|
301 |
+
data_type: TYPE_BOOL
|
302 |
+
dims: [ 1 ]
|
303 |
+
reshape: { shape: [ ] }
|
304 |
+
optional: true
|
305 |
+
},
|
306 |
+
{
|
307 |
+
name: "prompt_embedding_table"
|
308 |
+
data_type: TYPE_FP16
|
309 |
+
dims: [ -1, -1 ]
|
310 |
+
optional: true
|
311 |
+
allow_ragged_batch: true
|
312 |
+
},
|
313 |
+
{
|
314 |
+
name: "prompt_table_extra_ids"
|
315 |
+
data_type: TYPE_UINT64
|
316 |
+
dims: [ -1 ]
|
317 |
+
optional: true
|
318 |
+
allow_ragged_batch: true
|
319 |
+
},
|
320 |
+
{
|
321 |
+
name: "prompt_vocab_size"
|
322 |
+
data_type: TYPE_INT32
|
323 |
+
dims: [ 1 ]
|
324 |
+
reshape: { shape: [ ] }
|
325 |
+
optional: true
|
326 |
+
},
|
327 |
+
# cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
|
328 |
+
{
|
329 |
+
name: "cross_attention_mask"
|
330 |
+
data_type: TYPE_BOOL
|
331 |
+
dims: [ -1, -1 ]
|
332 |
+
optional: true
|
333 |
+
allow_ragged_batch: true
|
334 |
+
},
|
335 |
+
# Mrope param when mrope is used
|
336 |
+
{
|
337 |
+
name: "mrope_rotary_cos_sin"
|
338 |
+
data_type: TYPE_FP32
|
339 |
+
dims: [ -1 ]
|
340 |
+
optional: true
|
341 |
+
},
|
342 |
+
{
|
343 |
+
name: "mrope_position_deltas"
|
344 |
+
data_type: TYPE_INT64
|
345 |
+
dims: [ 1 ]
|
346 |
+
optional: true
|
347 |
+
},
|
348 |
+
# the unique task ID for the given LoRA.
|
349 |
+
# To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given.
|
350 |
+
# The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
|
351 |
+
# If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
|
352 |
+
{
|
353 |
+
name: "lora_task_id"
|
354 |
+
data_type: TYPE_UINT64
|
355 |
+
dims: [ 1 ]
|
356 |
+
reshape: { shape: [ ] }
|
357 |
+
optional: true
|
358 |
+
},
|
359 |
+
# weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
|
360 |
+
# where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
|
361 |
+
# each of the in / out tensors are first flattened and then concatenated together in the format above.
|
362 |
+
# D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
|
363 |
+
{
|
364 |
+
name: "lora_weights"
|
365 |
+
data_type: TYPE_FP16
|
366 |
+
dims: [ -1, -1 ]
|
367 |
+
optional: true
|
368 |
+
allow_ragged_batch: true
|
369 |
+
},
|
370 |
+
# module identifier (same size a first dimension of lora_weights)
|
371 |
+
# See LoraModule::ModuleType for model id mapping
|
372 |
+
#
|
373 |
+
# "attn_qkv": 0 # compbined qkv adapter
|
374 |
+
# "attn_q": 1 # q adapter
|
375 |
+
# "attn_k": 2 # k adapter
|
376 |
+
# "attn_v": 3 # v adapter
|
377 |
+
# "attn_dense": 4 # adapter for the dense layer in attention
|
378 |
+
# "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
|
379 |
+
# "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
|
380 |
+
# "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
|
381 |
+
#
|
382 |
+
# last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
|
383 |
+
{
|
384 |
+
name: "lora_config"
|
385 |
+
data_type: TYPE_INT32
|
386 |
+
dims: [ -1, 3 ]
|
387 |
+
optional: true
|
388 |
+
allow_ragged_batch: true
|
389 |
+
},
|
390 |
+
{
|
391 |
+
name: "context_phase_params"
|
392 |
+
data_type: TYPE_UINT8
|
393 |
+
dims: [ -1 ]
|
394 |
+
optional: true
|
395 |
+
allow_ragged_batch: true
|
396 |
+
},
|
397 |
+
# skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
|
398 |
+
{
|
399 |
+
name: "skip_cross_attn_blocks"
|
400 |
+
data_type: TYPE_BOOL
|
401 |
+
dims: [ 1 ]
|
402 |
+
optional: true
|
403 |
+
allow_ragged_batch: true
|
404 |
+
},
|
405 |
+
{
|
406 |
+
name: "retention_token_range_starts"
|
407 |
+
data_type: TYPE_INT32
|
408 |
+
dims: [ -1 ]
|
409 |
+
optional: true
|
410 |
+
allow_ragged_batch: true
|
411 |
+
},
|
412 |
+
{
|
413 |
+
name: "retention_token_range_ends"
|
414 |
+
data_type: TYPE_INT32
|
415 |
+
dims: [ -1 ]
|
416 |
+
optional: true
|
417 |
+
allow_ragged_batch: true
|
418 |
+
},
|
419 |
+
{
|
420 |
+
name: "retention_token_range_priorities"
|
421 |
+
data_type: TYPE_INT32
|
422 |
+
dims: [ -1 ]
|
423 |
+
optional: true
|
424 |
+
allow_ragged_batch: true
|
425 |
+
},
|
426 |
+
{
|
427 |
+
name: "retention_token_range_durations_ms"
|
428 |
+
data_type: TYPE_INT32
|
429 |
+
dims: [ -1 ]
|
430 |
+
optional: true
|
431 |
+
allow_ragged_batch: true
|
432 |
+
},
|
433 |
+
{
|
434 |
+
name: "retention_decode_priority"
|
435 |
+
data_type: TYPE_INT32
|
436 |
+
dims: [ 1 ]
|
437 |
+
optional: true
|
438 |
+
allow_ragged_batch: true
|
439 |
+
},
|
440 |
+
{
|
441 |
+
name: "retention_decode_duration_ms"
|
442 |
+
data_type: TYPE_INT32
|
443 |
+
dims: [ 1 ]
|
444 |
+
optional: true
|
445 |
+
allow_ragged_batch: true
|
446 |
+
},
|
447 |
+
{
|
448 |
+
name: "guided_decoding_guide_type"
|
449 |
+
data_type: TYPE_STRING
|
450 |
+
dims: [ 1 ]
|
451 |
+
optional: true
|
452 |
+
allow_ragged_batch: true
|
453 |
+
},
|
454 |
+
{
|
455 |
+
name: "guided_decoding_guide"
|
456 |
+
data_type: TYPE_STRING
|
457 |
+
dims: [ 1 ]
|
458 |
+
optional: true
|
459 |
+
allow_ragged_batch: true
|
460 |
+
},
|
461 |
+
{
|
462 |
+
name: "lookahead_window_size"
|
463 |
+
data_type: TYPE_INT32
|
464 |
+
dims: [ 1 ]
|
465 |
+
optional: true
|
466 |
+
allow_ragged_batch: true
|
467 |
+
},
|
468 |
+
{
|
469 |
+
name: "lookahead_ngram_size"
|
470 |
+
data_type: TYPE_INT32
|
471 |
+
dims: [ 1 ]
|
472 |
+
optional: true
|
473 |
+
allow_ragged_batch: true
|
474 |
+
},
|
475 |
+
{
|
476 |
+
name: "lookahead_verification_set_size"
|
477 |
+
data_type: TYPE_INT32
|
478 |
+
dims: [ 1 ]
|
479 |
+
optional: true
|
480 |
+
allow_ragged_batch: true
|
481 |
+
}
|
482 |
+
]
|
483 |
+
output [
|
484 |
+
{
|
485 |
+
name: "output_ids"
|
486 |
+
data_type: TYPE_INT32
|
487 |
+
dims: [ -1, -1 ]
|
488 |
+
},
|
489 |
+
{
|
490 |
+
name: "sequence_length"
|
491 |
+
data_type: TYPE_INT32
|
492 |
+
dims: [ -1 ]
|
493 |
+
},
|
494 |
+
{
|
495 |
+
name: "cum_log_probs"
|
496 |
+
data_type: TYPE_FP32
|
497 |
+
dims: [ -1 ]
|
498 |
+
},
|
499 |
+
{
|
500 |
+
name: "output_log_probs"
|
501 |
+
data_type: TYPE_FP32
|
502 |
+
dims: [ -1, -1 ]
|
503 |
+
},
|
504 |
+
{
|
505 |
+
name: "context_logits"
|
506 |
+
data_type: ${logits_datatype}
|
507 |
+
dims: [ -1, -1 ]
|
508 |
+
},
|
509 |
+
{
|
510 |
+
name: "generation_logits"
|
511 |
+
data_type: ${logits_datatype}
|
512 |
+
dims: [ -1, -1, -1 ]
|
513 |
+
},
|
514 |
+
{
|
515 |
+
name: "batch_index"
|
516 |
+
data_type: TYPE_INT32
|
517 |
+
dims: [ 1 ]
|
518 |
+
},
|
519 |
+
{
|
520 |
+
name: "sequence_index"
|
521 |
+
data_type: TYPE_INT32
|
522 |
+
dims: [ 1 ]
|
523 |
+
},
|
524 |
+
{
|
525 |
+
name: "context_phase_params"
|
526 |
+
data_type: TYPE_UINT8
|
527 |
+
dims: [ -1 ]
|
528 |
+
},
|
529 |
+
{
|
530 |
+
name: "kv_cache_alloc_new_blocks"
|
531 |
+
data_type: TYPE_INT32
|
532 |
+
dims: [ 1 ]
|
533 |
+
},
|
534 |
+
{
|
535 |
+
name: "kv_cache_reused_blocks"
|
536 |
+
data_type: TYPE_INT32
|
537 |
+
dims: [ 1 ]
|
538 |
+
},
|
539 |
+
{
|
540 |
+
name: "kv_cache_alloc_total_blocks"
|
541 |
+
data_type: TYPE_INT32
|
542 |
+
dims: [ 1 ]
|
543 |
+
},
|
544 |
+
{
|
545 |
+
name: "arrival_time_ns"
|
546 |
+
data_type: TYPE_INT64
|
547 |
+
dims: [ 1 ]
|
548 |
+
},
|
549 |
+
{
|
550 |
+
name: "first_scheduled_time_ns"
|
551 |
+
data_type: TYPE_INT64
|
552 |
+
dims: [ 1 ]
|
553 |
+
},
|
554 |
+
{
|
555 |
+
name: "first_token_time_ns"
|
556 |
+
data_type: TYPE_INT64
|
557 |
+
dims: [ 1 ]
|
558 |
+
},
|
559 |
+
{
|
560 |
+
name: "last_token_time_ns"
|
561 |
+
data_type: TYPE_INT64
|
562 |
+
dims: [ 1 ]
|
563 |
+
},
|
564 |
+
{
|
565 |
+
name: "acceptance_rate"
|
566 |
+
data_type: TYPE_FP32
|
567 |
+
dims: [ 1 ]
|
568 |
+
},
|
569 |
+
{
|
570 |
+
name: "total_accepted_draft_tokens"
|
571 |
+
data_type: TYPE_INT32
|
572 |
+
dims: [ 1 ]
|
573 |
+
},
|
574 |
+
{
|
575 |
+
name: "total_draft_tokens"
|
576 |
+
data_type: TYPE_INT32
|
577 |
+
dims: [ 1 ]
|
578 |
+
}
|
579 |
+
]
|
580 |
+
instance_group [
|
581 |
+
{
|
582 |
+
count: 1
|
583 |
+
kind : KIND_CPU
|
584 |
+
}
|
585 |
+
]
|
586 |
+
parameters: {
|
587 |
+
key: "max_beam_width"
|
588 |
+
value: {
|
589 |
+
string_value: "${max_beam_width}"
|
590 |
+
}
|
591 |
+
}
|
592 |
+
parameters: {
|
593 |
+
key: "FORCE_CPU_ONLY_INPUT_TENSORS"
|
594 |
+
value: {
|
595 |
+
string_value: "no"
|
596 |
+
}
|
597 |
+
}
|
598 |
+
parameters: {
|
599 |
+
key: "gpt_model_type"
|
600 |
+
value: {
|
601 |
+
string_value: "${batching_strategy}"
|
602 |
+
}
|
603 |
+
}
|
604 |
+
parameters: {
|
605 |
+
key: "gpt_model_path"
|
606 |
+
value: {
|
607 |
+
string_value: "${engine_dir}"
|
608 |
+
}
|
609 |
+
}
|
610 |
+
parameters: {
|
611 |
+
key: "encoder_model_path"
|
612 |
+
value: {
|
613 |
+
string_value: "${encoder_engine_dir}"
|
614 |
+
}
|
615 |
+
}
|
616 |
+
parameters: {
|
617 |
+
key: "max_tokens_in_paged_kv_cache"
|
618 |
+
value: {
|
619 |
+
string_value: "${max_tokens_in_paged_kv_cache}"
|
620 |
+
}
|
621 |
+
}
|
622 |
+
parameters: {
|
623 |
+
key: "max_attention_window_size"
|
624 |
+
value: {
|
625 |
+
string_value: "${max_attention_window_size}"
|
626 |
+
}
|
627 |
+
}
|
628 |
+
parameters: {
|
629 |
+
key: "sink_token_length"
|
630 |
+
value: {
|
631 |
+
string_value: "${sink_token_length}"
|
632 |
+
}
|
633 |
+
}
|
634 |
+
parameters: {
|
635 |
+
key: "batch_scheduler_policy"
|
636 |
+
value: {
|
637 |
+
string_value: "${batch_scheduler_policy}"
|
638 |
+
}
|
639 |
+
}
|
640 |
+
parameters: {
|
641 |
+
key: "kv_cache_free_gpu_mem_fraction"
|
642 |
+
value: {
|
643 |
+
string_value: "${kv_cache_free_gpu_mem_fraction}"
|
644 |
+
}
|
645 |
+
}
|
646 |
+
parameters: {
|
647 |
+
key: "cross_kv_cache_fraction"
|
648 |
+
value: {
|
649 |
+
string_value: "${cross_kv_cache_fraction}"
|
650 |
+
}
|
651 |
+
}
|
652 |
+
parameters: {
|
653 |
+
key: "kv_cache_host_memory_bytes"
|
654 |
+
value: {
|
655 |
+
string_value: "${kv_cache_host_memory_bytes}"
|
656 |
+
}
|
657 |
+
}
|
658 |
+
# kv_cache_onboard_blocks is for internal implementation.
|
659 |
+
parameters: {
|
660 |
+
key: "kv_cache_onboard_blocks"
|
661 |
+
value: {
|
662 |
+
string_value: "${kv_cache_onboard_blocks}"
|
663 |
+
}
|
664 |
+
}
|
665 |
+
# enable_trt_overlap is deprecated and doesn't have any effect on the runtime
|
666 |
+
# parameters: {
|
667 |
+
# key: "enable_trt_overlap"
|
668 |
+
# value: {
|
669 |
+
# string_value: "${enable_trt_overlap}"
|
670 |
+
# }
|
671 |
+
# }
|
672 |
+
parameters: {
|
673 |
+
key: "exclude_input_in_output"
|
674 |
+
value: {
|
675 |
+
string_value: "${exclude_input_in_output}"
|
676 |
+
}
|
677 |
+
}
|
678 |
+
parameters: {
|
679 |
+
key: "cancellation_check_period_ms"
|
680 |
+
value: {
|
681 |
+
string_value: "${cancellation_check_period_ms}"
|
682 |
+
}
|
683 |
+
}
|
684 |
+
parameters: {
|
685 |
+
key: "stats_check_period_ms"
|
686 |
+
value: {
|
687 |
+
string_value: "${stats_check_period_ms}"
|
688 |
+
}
|
689 |
+
}
|
690 |
+
parameters: {
|
691 |
+
key: "iter_stats_max_iterations"
|
692 |
+
value: {
|
693 |
+
string_value: "${iter_stats_max_iterations}"
|
694 |
+
}
|
695 |
+
}
|
696 |
+
parameters: {
|
697 |
+
key: "request_stats_max_iterations"
|
698 |
+
value: {
|
699 |
+
string_value: "${request_stats_max_iterations}"
|
700 |
+
}
|
701 |
+
}
|
702 |
+
parameters: {
|
703 |
+
key: "enable_kv_cache_reuse"
|
704 |
+
value: {
|
705 |
+
string_value: "${enable_kv_cache_reuse}"
|
706 |
+
}
|
707 |
+
}
|
708 |
+
parameters: {
|
709 |
+
key: "normalize_log_probs"
|
710 |
+
value: {
|
711 |
+
string_value: "${normalize_log_probs}"
|
712 |
+
}
|
713 |
+
}
|
714 |
+
parameters: {
|
715 |
+
key: "enable_chunked_context"
|
716 |
+
value: {
|
717 |
+
string_value: "${enable_chunked_context}"
|
718 |
+
}
|
719 |
+
}
|
720 |
+
parameters: {
|
721 |
+
key: "gpu_device_ids"
|
722 |
+
value: {
|
723 |
+
string_value: "${gpu_device_ids}"
|
724 |
+
}
|
725 |
+
}
|
726 |
+
parameters: {
|
727 |
+
key: "participant_ids"
|
728 |
+
value: {
|
729 |
+
string_value: "${participant_ids}"
|
730 |
+
}
|
731 |
+
}
|
732 |
+
parameters: {
|
733 |
+
key: "lora_cache_optimal_adapter_size"
|
734 |
+
value: {
|
735 |
+
string_value: "${lora_cache_optimal_adapter_size}"
|
736 |
+
}
|
737 |
+
}
|
738 |
+
parameters: {
|
739 |
+
key: "lora_cache_max_adapter_size"
|
740 |
+
value: {
|
741 |
+
string_value: "${lora_cache_max_adapter_size}"
|
742 |
+
}
|
743 |
+
}
|
744 |
+
parameters: {
|
745 |
+
key: "lora_cache_gpu_memory_fraction"
|
746 |
+
value: {
|
747 |
+
string_value: "${lora_cache_gpu_memory_fraction}"
|
748 |
+
}
|
749 |
+
}
|
750 |
+
parameters: {
|
751 |
+
key: "lora_cache_host_memory_bytes"
|
752 |
+
value: {
|
753 |
+
string_value: "${lora_cache_host_memory_bytes}"
|
754 |
+
}
|
755 |
+
}
|
756 |
+
parameters: {
|
757 |
+
key: "lora_prefetch_dir"
|
758 |
+
value: {
|
759 |
+
string_value: "${lora_prefetch_dir}"
|
760 |
+
}
|
761 |
+
}
|
762 |
+
parameters: {
|
763 |
+
key: "decoding_mode"
|
764 |
+
value: {
|
765 |
+
string_value: "${decoding_mode}"
|
766 |
+
}
|
767 |
+
}
|
768 |
+
parameters: {
|
769 |
+
key: "executor_worker_path"
|
770 |
+
value: {
|
771 |
+
string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
|
772 |
+
}
|
773 |
+
}
|
774 |
+
parameters: {
|
775 |
+
key: "lookahead_window_size"
|
776 |
+
value: {
|
777 |
+
string_value: "${lookahead_window_size}"
|
778 |
+
}
|
779 |
+
}
|
780 |
+
parameters: {
|
781 |
+
key: "lookahead_ngram_size"
|
782 |
+
value: {
|
783 |
+
string_value: "${lookahead_ngram_size}"
|
784 |
+
}
|
785 |
+
}
|
786 |
+
parameters: {
|
787 |
+
key: "lookahead_verification_set_size"
|
788 |
+
value: {
|
789 |
+
string_value: "${lookahead_verification_set_size}"
|
790 |
+
}
|
791 |
+
}
|
792 |
+
parameters: {
|
793 |
+
key: "medusa_choices"
|
794 |
+
value: {
|
795 |
+
string_value: "${medusa_choices}"
|
796 |
+
}
|
797 |
+
}
|
798 |
+
parameters: {
|
799 |
+
key: "eagle_choices"
|
800 |
+
value: {
|
801 |
+
string_value: "${eagle_choices}"
|
802 |
+
}
|
803 |
+
}
|
804 |
+
parameters: {
|
805 |
+
key: "gpu_weights_percent"
|
806 |
+
value: {
|
807 |
+
string_value: "${gpu_weights_percent}"
|
808 |
+
}
|
809 |
+
}
|
810 |
+
parameters: {
|
811 |
+
key: "enable_context_fmha_fp32_acc"
|
812 |
+
value: {
|
813 |
+
string_value: "${enable_context_fmha_fp32_acc}"
|
814 |
+
}
|
815 |
+
}
|
816 |
+
parameters: {
|
817 |
+
key: "multi_block_mode"
|
818 |
+
value: {
|
819 |
+
string_value: "${multi_block_mode}"
|
820 |
+
}
|
821 |
+
}
|
822 |
+
parameters: {
|
823 |
+
key: "cuda_graph_mode"
|
824 |
+
value: {
|
825 |
+
string_value: "${cuda_graph_mode}"
|
826 |
+
}
|
827 |
+
}
|
828 |
+
parameters: {
|
829 |
+
key: "cuda_graph_cache_size"
|
830 |
+
value: {
|
831 |
+
string_value: "${cuda_graph_cache_size}"
|
832 |
+
}
|
833 |
+
}
|
834 |
+
parameters: {
|
835 |
+
key: "speculative_decoding_fast_logits"
|
836 |
+
value: {
|
837 |
+
string_value: "${speculative_decoding_fast_logits}"
|
838 |
+
}
|
839 |
+
}
|
840 |
+
parameters: {
|
841 |
+
key: "tokenizer_dir"
|
842 |
+
value: {
|
843 |
+
string_value: "${tokenizer_dir}"
|
844 |
+
}
|
845 |
+
}
|
846 |
+
parameters: {
|
847 |
+
key: "guided_decoding_backend"
|
848 |
+
value: {
|
849 |
+
string_value: "${guided_decoding_backend}"
|
850 |
+
}
|
851 |
+
}
|
852 |
+
parameters: {
|
853 |
+
key: "xgrammar_tokenizer_info_path"
|
854 |
+
value: {
|
855 |
+
string_value: "${xgrammar_tokenizer_info_path}"
|
856 |
+
}
|
857 |
+
}
|
runtime/triton_trtllm/model_repo/vocoder/1/model.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
|
27 |
+
import json
|
28 |
+
import os
|
29 |
+
import logging
|
30 |
+
from typing import List, Dict
|
31 |
+
|
32 |
+
import torch
|
33 |
+
from torch.utils.dlpack import to_dlpack
|
34 |
+
|
35 |
+
import triton_python_backend_utils as pb_utils
|
36 |
+
|
37 |
+
from sparktts.models.bicodec import BiCodec
|
38 |
+
|
39 |
+
# Configure logging
|
40 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
41 |
+
logger = logging.getLogger(__name__)
|
42 |
+
|
43 |
+
class TritonPythonModel:
|
44 |
+
"""Triton Python model for vocoder.
|
45 |
+
|
46 |
+
This model takes global and semantic tokens as input and generates audio waveforms
|
47 |
+
using the BiCodec vocoder.
|
48 |
+
"""
|
49 |
+
|
50 |
+
def initialize(self, args):
|
51 |
+
"""Initialize the model.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
args: Dictionary containing model configuration
|
55 |
+
"""
|
56 |
+
# Parse model parameters
|
57 |
+
parameters = json.loads(args['model_config'])['parameters']
|
58 |
+
model_params = {key: value["string_value"] for key, value in parameters.items()}
|
59 |
+
model_dir = model_params["model_dir"]
|
60 |
+
|
61 |
+
# Initialize device and vocoder
|
62 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
63 |
+
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
|
64 |
+
|
65 |
+
self.vocoder = BiCodec.load_from_checkpoint(f"{model_dir}/BiCodec")
|
66 |
+
del self.vocoder.encoder, self.vocoder.postnet
|
67 |
+
self.vocoder.eval().to(self.device) # Set model to evaluation mode
|
68 |
+
|
69 |
+
logger.info("Vocoder initialized successfully")
|
70 |
+
|
71 |
+
|
72 |
+
def execute(self, requests):
|
73 |
+
"""Execute inference on the batched requests.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
requests: List of inference requests
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
List of inference responses containing generated waveforms
|
80 |
+
"""
|
81 |
+
global_tokens_list, semantic_tokens_list = [], []
|
82 |
+
|
83 |
+
# Process each request in batch
|
84 |
+
for request in requests:
|
85 |
+
global_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "global_tokens").as_numpy()
|
86 |
+
semantic_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "semantic_tokens").as_numpy()
|
87 |
+
global_tokens_list.append(torch.from_numpy(global_tokens_tensor).to(self.device))
|
88 |
+
semantic_tokens_list.append(torch.from_numpy(semantic_tokens_tensor).to(self.device))
|
89 |
+
|
90 |
+
# Concatenate tokens for batch processing
|
91 |
+
global_tokens = torch.cat(global_tokens_list, dim=0)
|
92 |
+
semantic_tokens = torch.cat(semantic_tokens_list, dim=0)
|
93 |
+
|
94 |
+
|
95 |
+
# Generate waveforms
|
96 |
+
with torch.no_grad():
|
97 |
+
wavs = self.vocoder.detokenize(semantic_tokens, global_tokens.unsqueeze(1))
|
98 |
+
|
99 |
+
# Prepare responses
|
100 |
+
responses = []
|
101 |
+
for i in range(len(requests)):
|
102 |
+
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(wavs[i]))
|
103 |
+
inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
|
104 |
+
responses.append(inference_response)
|
105 |
+
|
106 |
+
return responses
|
runtime/triton_trtllm/model_repo/vocoder/config.pbtxt
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
name: "vocoder"
|
16 |
+
backend: "python"
|
17 |
+
max_batch_size: ${triton_max_batch_size}
|
18 |
+
dynamic_batching {
|
19 |
+
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
20 |
+
}
|
21 |
+
parameters [
|
22 |
+
{
|
23 |
+
key: "model_dir",
|
24 |
+
value: {string_value:"${model_dir}"}
|
25 |
+
}
|
26 |
+
]
|
27 |
+
|
28 |
+
input [
|
29 |
+
{
|
30 |
+
name: "global_tokens"
|
31 |
+
data_type: TYPE_INT32
|
32 |
+
dims: [-1]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
name: "semantic_tokens"
|
36 |
+
data_type: TYPE_INT32
|
37 |
+
dims: [-1]
|
38 |
+
}
|
39 |
+
]
|
40 |
+
output [
|
41 |
+
{
|
42 |
+
name: "waveform"
|
43 |
+
data_type: TYPE_FP32
|
44 |
+
dims: [ -1 ]
|
45 |
+
}
|
46 |
+
]
|
47 |
+
|
48 |
+
instance_group [
|
49 |
+
{
|
50 |
+
count: 1
|
51 |
+
kind: KIND_CPU
|
52 |
+
}
|
53 |
+
]
|
runtime/triton_trtllm/run.sh
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export PYTHONPATH=../../../Spark-TTS/
|
2 |
+
export CUDA_VISIBLE_DEVICES=0
|
3 |
+
stage=$1
|
4 |
+
stop_stage=$2
|
5 |
+
service_type=$3
|
6 |
+
echo "Start stage: $stage, Stop stage: $stop_stage service_type: $service_type"
|
7 |
+
|
8 |
+
huggingface_model_local_dir=../../pretrained_models/Spark-TTS-0.5B
|
9 |
+
trt_dtype=bfloat16
|
10 |
+
trt_weights_dir=./tllm_checkpoint_${trt_dtype}
|
11 |
+
trt_engines_dir=./trt_engines_${trt_dtype}
|
12 |
+
|
13 |
+
model_repo=./model_repo_test
|
14 |
+
|
15 |
+
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
16 |
+
echo "Downloading Spark-TTS-0.5B from HuggingFace"
|
17 |
+
huggingface-cli download SparkAudio/Spark-TTS-0.5B --local-dir $huggingface_model_local_dir || exit 1
|
18 |
+
fi
|
19 |
+
|
20 |
+
|
21 |
+
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
22 |
+
echo "Converting checkpoint to TensorRT weights"
|
23 |
+
python scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir/LLM \
|
24 |
+
--output_dir $trt_weights_dir \
|
25 |
+
--dtype $trt_dtype || exit 1
|
26 |
+
|
27 |
+
echo "Building TensorRT engines"
|
28 |
+
trtllm-build --checkpoint_dir $trt_weights_dir \
|
29 |
+
--output_dir $trt_engines_dir \
|
30 |
+
--max_batch_size 16 \
|
31 |
+
--max_num_tokens 32768 \
|
32 |
+
--gemm_plugin $trt_dtype || exit 1
|
33 |
+
fi
|
34 |
+
|
35 |
+
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
36 |
+
echo "Creating model repository"
|
37 |
+
rm -rf $model_repo
|
38 |
+
mkdir -p $model_repo
|
39 |
+
spark_tts_dir="spark_tts"
|
40 |
+
|
41 |
+
cp -r ./model_repo/${spark_tts_dir} $model_repo
|
42 |
+
cp -r ./model_repo/audio_tokenizer $model_repo
|
43 |
+
cp -r ./model_repo/tensorrt_llm $model_repo
|
44 |
+
cp -r ./model_repo/vocoder $model_repo
|
45 |
+
|
46 |
+
ENGINE_PATH=$trt_engines_dir
|
47 |
+
MAX_QUEUE_DELAY_MICROSECONDS=0
|
48 |
+
MODEL_DIR=$huggingface_model_local_dir
|
49 |
+
LLM_TOKENIZER_DIR=$huggingface_model_local_dir/LLM
|
50 |
+
BLS_INSTANCE_NUM=4
|
51 |
+
TRITON_MAX_BATCH_SIZE=16
|
52 |
+
# streaming TTS parameters
|
53 |
+
AUDIO_CHUNK_DURATION=1.0
|
54 |
+
MAX_AUDIO_CHUNK_DURATION=30.0
|
55 |
+
AUDIO_CHUNK_SIZE_SCALE_FACTOR=8.0
|
56 |
+
AUDIO_CHUNK_OVERLAP_DURATION=0.1
|
57 |
+
python3 scripts/fill_template.py -i ${model_repo}/vocoder/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
58 |
+
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
59 |
+
if [ "$service_type" == "streaming" ]; then
|
60 |
+
DECOUPLED_MODE=True
|
61 |
+
else
|
62 |
+
DECOUPLED_MODE=False
|
63 |
+
fi
|
64 |
+
python3 scripts/fill_template.py -i ${model_repo}/${spark_tts_dir}/config.pbtxt bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},audio_chunk_duration:${AUDIO_CHUNK_DURATION},max_audio_chunk_duration:${MAX_AUDIO_CHUNK_DURATION},audio_chunk_size_scale_factor:${AUDIO_CHUNK_SIZE_SCALE_FACTOR},audio_chunk_overlap_duration:${AUDIO_CHUNK_OVERLAP_DURATION}
|
65 |
+
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
|
66 |
+
|
67 |
+
fi
|
68 |
+
|
69 |
+
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
70 |
+
echo "Starting Triton server"
|
71 |
+
tritonserver --model-repository ${model_repo}
|
72 |
+
fi
|
73 |
+
|
74 |
+
|
75 |
+
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
76 |
+
echo "Running benchmark client"
|
77 |
+
num_task=2
|
78 |
+
if [ "$service_type" == "streaming" ]; then
|
79 |
+
mode="streaming"
|
80 |
+
else
|
81 |
+
mode="offline"
|
82 |
+
fi
|
83 |
+
python3 client_grpc.py \
|
84 |
+
--server-addr localhost \
|
85 |
+
--model-name spark_tts \
|
86 |
+
--num-tasks $num_task \
|
87 |
+
--mode $mode \
|
88 |
+
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_new
|
89 |
+
fi
|
90 |
+
|
91 |
+
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
92 |
+
echo "Running single utterance client"
|
93 |
+
if [ "$service_type" == "streaming" ]; then
|
94 |
+
python client_grpc.py \
|
95 |
+
--server-addr localhost \
|
96 |
+
--reference-audio ../../example/prompt_audio.wav \
|
97 |
+
--reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
|
98 |
+
--target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
|
99 |
+
--model-name spark_tts \
|
100 |
+
--chunk-overlap-duration 0.1 \
|
101 |
+
--mode streaming
|
102 |
+
else
|
103 |
+
python client_http.py \
|
104 |
+
--reference-audio ../../example/prompt_audio.wav \
|
105 |
+
--reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
|
106 |
+
--target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
|
107 |
+
--model-name spark_tts
|
108 |
+
fi
|
109 |
+
fi
|
runtime/triton_trtllm/scripts/convert_checkpoint.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import traceback
|
5 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
6 |
+
|
7 |
+
from transformers import AutoConfig
|
8 |
+
|
9 |
+
import tensorrt_llm
|
10 |
+
from tensorrt_llm._utils import release_gc
|
11 |
+
from tensorrt_llm.logger import logger
|
12 |
+
from tensorrt_llm.mapping import Mapping
|
13 |
+
from tensorrt_llm.models import QWenForCausalLM
|
14 |
+
from tensorrt_llm.models.modeling_utils import QuantConfig
|
15 |
+
from tensorrt_llm.quantization import QuantAlgo
|
16 |
+
|
17 |
+
|
18 |
+
def parse_arguments():
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument('--model_dir', type=str, default=None, required=True)
|
21 |
+
parser.add_argument('--tp_size',
|
22 |
+
type=int,
|
23 |
+
default=1,
|
24 |
+
help='N-way tensor parallelism size')
|
25 |
+
parser.add_argument('--pp_size',
|
26 |
+
type=int,
|
27 |
+
default=1,
|
28 |
+
help='N-way pipeline parallelism size')
|
29 |
+
parser.add_argument(
|
30 |
+
'--dtype',
|
31 |
+
type=str,
|
32 |
+
default='auto',
|
33 |
+
choices=['auto', 'float16', 'bfloat16', 'float32'],
|
34 |
+
help=
|
35 |
+
"The data type for the model weights and activations if not quantized. "
|
36 |
+
"If 'auto', the data type is automatically inferred from the source model; "
|
37 |
+
"however, if the source dtype is float32, it is converted to float16.")
|
38 |
+
parser.add_argument(
|
39 |
+
'--use_weight_only',
|
40 |
+
default=False,
|
41 |
+
action="store_true",
|
42 |
+
help='Quantize weights for the various GEMMs to INT4/INT8.'
|
43 |
+
'See --weight_only_precision to set the precision')
|
44 |
+
parser.add_argument(
|
45 |
+
'--disable_weight_only_quant_plugin',
|
46 |
+
default=False,
|
47 |
+
action="store_true",
|
48 |
+
help=
|
49 |
+
'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
|
50 |
+
'You must also use --use_weight_only for that argument to have an impact.'
|
51 |
+
)
|
52 |
+
parser.add_argument(
|
53 |
+
'--weight_only_precision',
|
54 |
+
const='int8',
|
55 |
+
type=str,
|
56 |
+
nargs='?',
|
57 |
+
default='int8',
|
58 |
+
choices=['int8', 'int4', 'int4_gptq'],
|
59 |
+
help=
|
60 |
+
'Define the precision for the weights when using weight-only quantization.'
|
61 |
+
'You must also use --use_weight_only for that argument to have an impact.'
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
'--calib_dataset',
|
65 |
+
type=str,
|
66 |
+
default='ccdv/cnn_dailymail',
|
67 |
+
help=
|
68 |
+
"The huggingface dataset name or the local directory of the dataset for calibration."
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--smoothquant",
|
72 |
+
"-sq",
|
73 |
+
type=float,
|
74 |
+
default=None,
|
75 |
+
help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
|
76 |
+
" to Smoothquant the model, and output int8 weights."
|
77 |
+
" A good first try is 0.5. Must be in [0, 1]")
|
78 |
+
parser.add_argument(
|
79 |
+
'--per_channel',
|
80 |
+
action="store_true",
|
81 |
+
default=False,
|
82 |
+
help=
|
83 |
+
'By default, we use a single static scaling factor for the GEMM\'s result. '
|
84 |
+
'per_channel instead uses a different static scaling factor for each channel. '
|
85 |
+
'The latter is usually more accurate, but a little slower.')
|
86 |
+
parser.add_argument(
|
87 |
+
'--per_token',
|
88 |
+
action="store_true",
|
89 |
+
default=False,
|
90 |
+
help=
|
91 |
+
'By default, we use a single static scaling factor to scale activations in the int8 range. '
|
92 |
+
'per_token chooses at run time, and for each token, a custom scaling factor. '
|
93 |
+
'The latter is usually more accurate, but a little slower.')
|
94 |
+
parser.add_argument(
|
95 |
+
'--int8_kv_cache',
|
96 |
+
default=False,
|
97 |
+
action="store_true",
|
98 |
+
help=
|
99 |
+
'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
'--per_group',
|
103 |
+
default=False,
|
104 |
+
action="store_true",
|
105 |
+
help=
|
106 |
+
'By default, we use a single static scaling factor to scale weights in the int4 range. '
|
107 |
+
'per_group chooses at run time, and for each group, a custom scaling factor. '
|
108 |
+
'The flag is built for GPTQ/AWQ quantization.')
|
109 |
+
|
110 |
+
parser.add_argument('--group_size',
|
111 |
+
type=int,
|
112 |
+
default=128,
|
113 |
+
help='Group size used in GPTQ quantization.')
|
114 |
+
|
115 |
+
parser.add_argument("--load_model_on_cpu", action="store_true")
|
116 |
+
parser.add_argument(
|
117 |
+
'--use_parallel_embedding',
|
118 |
+
action="store_true",
|
119 |
+
default=False,
|
120 |
+
help=
|
121 |
+
'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
|
122 |
+
)
|
123 |
+
parser.add_argument(
|
124 |
+
'--embedding_sharding_dim',
|
125 |
+
type=int,
|
126 |
+
default=0,
|
127 |
+
choices=[0, 1],
|
128 |
+
help=
|
129 |
+
'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
|
130 |
+
'To shard it along hidden dimension, set embedding_sharding_dim=1'
|
131 |
+
'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
|
132 |
+
)
|
133 |
+
parser.add_argument('--output_dir',
|
134 |
+
type=str,
|
135 |
+
default='tllm_checkpoint',
|
136 |
+
help='The path to save the TensorRT-LLM checkpoint')
|
137 |
+
parser.add_argument(
|
138 |
+
'--workers',
|
139 |
+
type=int,
|
140 |
+
default=1,
|
141 |
+
help='The number of workers for converting checkpoint in parallel')
|
142 |
+
parser.add_argument(
|
143 |
+
'--moe_tp_size',
|
144 |
+
type=int,
|
145 |
+
default=-1,
|
146 |
+
help=
|
147 |
+
'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
|
148 |
+
)
|
149 |
+
parser.add_argument(
|
150 |
+
'--moe_ep_size',
|
151 |
+
type=int,
|
152 |
+
default=-1,
|
153 |
+
help=
|
154 |
+
'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
|
155 |
+
)
|
156 |
+
args = parser.parse_args()
|
157 |
+
return args
|
158 |
+
|
159 |
+
|
160 |
+
def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
|
161 |
+
'''return config dict with quantization info based on the command line args
|
162 |
+
'''
|
163 |
+
quant_config = QuantConfig()
|
164 |
+
if args.use_weight_only:
|
165 |
+
if args.weight_only_precision == 'int8':
|
166 |
+
quant_config.quant_algo = QuantAlgo.W8A16
|
167 |
+
elif args.weight_only_precision == 'int4':
|
168 |
+
quant_config.quant_algo = QuantAlgo.W4A16
|
169 |
+
elif args.smoothquant:
|
170 |
+
quant_config.smoothquant_val = args.smoothquant
|
171 |
+
if args.per_channel:
|
172 |
+
if args.per_token:
|
173 |
+
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
|
174 |
+
else:
|
175 |
+
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
|
176 |
+
else:
|
177 |
+
if args.per_token:
|
178 |
+
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
|
179 |
+
else:
|
180 |
+
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN
|
181 |
+
|
182 |
+
if args.int8_kv_cache:
|
183 |
+
quant_config.kv_cache_quant_algo = QuantAlgo.INT8
|
184 |
+
|
185 |
+
if args.weight_only_precision == 'int4_gptq':
|
186 |
+
quant_config.group_size = args.group_size
|
187 |
+
quant_config.has_zero_point = True
|
188 |
+
quant_config.pre_quant_scale = False
|
189 |
+
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
|
190 |
+
|
191 |
+
return quant_config
|
192 |
+
|
193 |
+
|
194 |
+
def update_quant_config_from_hf(quant_config, hf_config,
|
195 |
+
override_fields) -> tuple[QuantConfig, dict]:
|
196 |
+
hf_config_dict = hf_config.to_dict()
|
197 |
+
if hf_config_dict.get('quantization_config'):
|
198 |
+
# update the quant_algo, and clamp_val.
|
199 |
+
if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
|
200 |
+
logger.info(
|
201 |
+
"Load quantization configs from huggingface model_config.")
|
202 |
+
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
|
203 |
+
quant_config.group_size = hf_config_dict['quantization_config'].get(
|
204 |
+
'group_size', 128)
|
205 |
+
quant_config.has_zero_point = hf_config_dict[
|
206 |
+
'quantization_config'].get('zero_point', False)
|
207 |
+
override_fields.update({"use_autoawq": True})
|
208 |
+
elif hf_config_dict['quantization_config'].get(
|
209 |
+
'quant_method') == 'gptq':
|
210 |
+
logger.info(
|
211 |
+
"Load quantization configs from huggingface model_config.")
|
212 |
+
desc_act = hf_config_dict['quantization_config'].get(
|
213 |
+
'desc_act', False)
|
214 |
+
if desc_act:
|
215 |
+
raise ValueError("GPTQ with desc_act=True is not implemented!")
|
216 |
+
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
|
217 |
+
quant_config.group_size = hf_config_dict['quantization_config'].get(
|
218 |
+
'group_size', 128)
|
219 |
+
quant_config.has_zero_point = hf_config_dict[
|
220 |
+
'quantization_config'].get('sym', False)
|
221 |
+
return quant_config, override_fields
|
222 |
+
|
223 |
+
|
224 |
+
def args_to_build_options(args):
|
225 |
+
return {
|
226 |
+
'use_parallel_embedding': args.use_parallel_embedding,
|
227 |
+
'embedding_sharding_dim': args.embedding_sharding_dim,
|
228 |
+
'disable_weight_only_quant_plugin':
|
229 |
+
args.disable_weight_only_quant_plugin
|
230 |
+
}
|
231 |
+
|
232 |
+
|
233 |
+
def convert_and_save_hf(args):
|
234 |
+
model_dir = args.model_dir
|
235 |
+
world_size = args.tp_size * args.pp_size
|
236 |
+
# Need to convert the cli args to the kay-value pairs and override them in the generate config dict.
|
237 |
+
# Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now,
|
238 |
+
# before the refactor is done.
|
239 |
+
override_fields = {}
|
240 |
+
override_fields.update(args_to_build_options(args))
|
241 |
+
quant_config = args_to_quant_config(args)
|
242 |
+
|
243 |
+
try:
|
244 |
+
hf_config = AutoConfig.from_pretrained(model_dir,
|
245 |
+
trust_remote_code=True)
|
246 |
+
quant_config, override_fields = update_quant_config_from_hf(
|
247 |
+
quant_config, hf_config, override_fields)
|
248 |
+
except:
|
249 |
+
logger.warning("AutoConfig cannot load the huggingface config.")
|
250 |
+
|
251 |
+
if args.smoothquant is not None or args.int8_kv_cache:
|
252 |
+
mapping = Mapping(
|
253 |
+
world_size=world_size,
|
254 |
+
tp_size=args.tp_size,
|
255 |
+
pp_size=args.pp_size,
|
256 |
+
moe_tp_size=args.moe_tp_size,
|
257 |
+
moe_ep_size=args.moe_ep_size,
|
258 |
+
)
|
259 |
+
QWenForCausalLM.quantize(args.model_dir,
|
260 |
+
args.output_dir,
|
261 |
+
dtype=args.dtype,
|
262 |
+
mapping=mapping,
|
263 |
+
quant_config=quant_config,
|
264 |
+
calib_dataset=args.calib_dataset,
|
265 |
+
**override_fields)
|
266 |
+
else:
|
267 |
+
|
268 |
+
def convert_and_save_rank(args, rank):
|
269 |
+
mapping = Mapping(world_size=world_size,
|
270 |
+
rank=rank,
|
271 |
+
tp_size=args.tp_size,
|
272 |
+
pp_size=args.pp_size,
|
273 |
+
moe_tp_size=args.moe_tp_size,
|
274 |
+
moe_ep_size=args.moe_ep_size)
|
275 |
+
qwen = QWenForCausalLM.from_hugging_face(model_dir,
|
276 |
+
args.dtype,
|
277 |
+
mapping=mapping,
|
278 |
+
quant_config=quant_config,
|
279 |
+
**override_fields)
|
280 |
+
qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
|
281 |
+
del qwen
|
282 |
+
|
283 |
+
execute(args.workers, [convert_and_save_rank] * world_size, args)
|
284 |
+
release_gc()
|
285 |
+
|
286 |
+
|
287 |
+
def execute(workers, func, args):
|
288 |
+
if workers == 1:
|
289 |
+
for rank, f in enumerate(func):
|
290 |
+
f(args, rank)
|
291 |
+
else:
|
292 |
+
with ThreadPoolExecutor(max_workers=workers) as p:
|
293 |
+
futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
|
294 |
+
exceptions = []
|
295 |
+
for future in as_completed(futures):
|
296 |
+
try:
|
297 |
+
future.result()
|
298 |
+
except Exception as e:
|
299 |
+
traceback.print_exc()
|
300 |
+
exceptions.append(e)
|
301 |
+
assert len(
|
302 |
+
exceptions
|
303 |
+
) == 0, "Checkpoint conversion failed, please check error log."
|
304 |
+
|
305 |
+
|
306 |
+
def main():
|
307 |
+
print(tensorrt_llm.__version__)
|
308 |
+
args = parse_arguments()
|
309 |
+
|
310 |
+
if (args.moe_tp_size == -1 and args.moe_ep_size == -1):
|
311 |
+
# moe default to tp-only
|
312 |
+
args.moe_tp_size = args.tp_size
|
313 |
+
args.moe_ep_size = 1
|
314 |
+
elif (args.moe_tp_size == -1):
|
315 |
+
args.moe_tp_size = args.tp_size // args.moe_ep_size
|
316 |
+
elif (args.moe_ep_size == -1):
|
317 |
+
args.moe_ep_size = args.tp_size // args.moe_tp_size
|
318 |
+
assert (args.moe_tp_size * args.moe_ep_size == args.tp_size
|
319 |
+
), "moe_tp_size * moe_ep_size must equal to tp_size"
|
320 |
+
|
321 |
+
tik = time.time()
|
322 |
+
|
323 |
+
if not os.path.exists(args.output_dir):
|
324 |
+
os.makedirs(args.output_dir)
|
325 |
+
|
326 |
+
assert args.model_dir is not None
|
327 |
+
convert_and_save_hf(args)
|
328 |
+
|
329 |
+
tok = time.time()
|
330 |
+
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
331 |
+
print(f'Total time of converting checkpoints: {t}')
|
332 |
+
|
333 |
+
|
334 |
+
if __name__ == '__main__':
|
335 |
+
main()
|
runtime/triton_trtllm/scripts/fill_template.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
from string import Template
|
4 |
+
|
5 |
+
|
6 |
+
def split(string, delimiter):
|
7 |
+
"""Split a string using delimiter. Supports escaping.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
string (str): The string to split.
|
11 |
+
delimiter (str): The delimiter to split the string with.
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
list: A list of strings.
|
15 |
+
"""
|
16 |
+
result = []
|
17 |
+
current = ""
|
18 |
+
escape = False
|
19 |
+
for char in string:
|
20 |
+
if escape:
|
21 |
+
current += char
|
22 |
+
escape = False
|
23 |
+
elif char == delimiter:
|
24 |
+
result.append(current)
|
25 |
+
current = ""
|
26 |
+
elif char == "\\":
|
27 |
+
escape = True
|
28 |
+
else:
|
29 |
+
current += char
|
30 |
+
result.append(current)
|
31 |
+
return result
|
32 |
+
|
33 |
+
|
34 |
+
def main(file_path, substitutions, in_place):
|
35 |
+
with open(file_path) as f:
|
36 |
+
pbtxt = Template(f.read())
|
37 |
+
|
38 |
+
sub_dict = {
|
39 |
+
"max_queue_size": 0,
|
40 |
+
'max_queue_delay_microseconds': 0,
|
41 |
+
}
|
42 |
+
for sub in split(substitutions, ","):
|
43 |
+
key, value = split(sub, ":")
|
44 |
+
sub_dict[key] = value
|
45 |
+
|
46 |
+
assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}."
|
47 |
+
|
48 |
+
pbtxt = pbtxt.safe_substitute(sub_dict)
|
49 |
+
|
50 |
+
if in_place:
|
51 |
+
with open(file_path, "w") as f:
|
52 |
+
f.write(pbtxt)
|
53 |
+
else:
|
54 |
+
print(pbtxt)
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == "__main__":
|
58 |
+
parser = ArgumentParser()
|
59 |
+
parser.add_argument("file_path", help="path of the .pbtxt to modify")
|
60 |
+
parser.add_argument(
|
61 |
+
"substitutions",
|
62 |
+
help=
|
63 |
+
"substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
|
64 |
+
)
|
65 |
+
parser.add_argument("--in_place",
|
66 |
+
"-i",
|
67 |
+
action="store_true",
|
68 |
+
help="do the operation in-place")
|
69 |
+
args = parser.parse_args()
|
70 |
+
main(**vars(args))
|
sparktts/models/audio_tokenizer.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
from pathlib import Path
|
21 |
+
from typing import Any, Dict, Tuple
|
22 |
+
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
|
23 |
+
|
24 |
+
from sparktts.utils.file import load_config
|
25 |
+
from sparktts.utils.audio import load_audio
|
26 |
+
from sparktts.models.bicodec import BiCodec
|
27 |
+
|
28 |
+
|
29 |
+
class BiCodecTokenizer:
|
30 |
+
"""BiCodec tokenizer for handling audio input and tokenization."""
|
31 |
+
|
32 |
+
def __init__(self, model_dir: Path, device: torch.device = None, **kwargs):
|
33 |
+
super().__init__()
|
34 |
+
"""
|
35 |
+
Args:
|
36 |
+
model_dir: Path to the model directory.
|
37 |
+
device: Device to run the model on (default is GPU if available).
|
38 |
+
"""
|
39 |
+
self.device = device
|
40 |
+
self.model_dir = model_dir
|
41 |
+
self.config = load_config(f"{model_dir}/config.yaml")
|
42 |
+
self._initialize_model()
|
43 |
+
|
44 |
+
def _initialize_model(self):
|
45 |
+
"""Load and initialize the BiCodec model and Wav2Vec2 feature extractor."""
|
46 |
+
self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to(
|
47 |
+
self.device
|
48 |
+
)
|
49 |
+
self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
|
50 |
+
f"{self.model_dir}/wav2vec2-large-xlsr-53"
|
51 |
+
)
|
52 |
+
self.feature_extractor = Wav2Vec2Model.from_pretrained(
|
53 |
+
f"{self.model_dir}/wav2vec2-large-xlsr-53"
|
54 |
+
).to(self.device)
|
55 |
+
self.feature_extractor.config.output_hidden_states = True
|
56 |
+
|
57 |
+
def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
|
58 |
+
"""Get reference audio clip for speaker embedding."""
|
59 |
+
ref_segment_length = (
|
60 |
+
int(self.config["sample_rate"] * self.config["ref_segment_duration"])
|
61 |
+
// self.config["latent_hop_length"]
|
62 |
+
* self.config["latent_hop_length"]
|
63 |
+
)
|
64 |
+
wav_length = len(wav)
|
65 |
+
|
66 |
+
if ref_segment_length > wav_length:
|
67 |
+
# Repeat and truncate to handle insufficient length
|
68 |
+
wav = np.tile(wav, ref_segment_length // wav_length + 1)
|
69 |
+
|
70 |
+
return wav[:ref_segment_length]
|
71 |
+
|
72 |
+
def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]:
|
73 |
+
"""load auido and get reference audio from wav path"""
|
74 |
+
wav = load_audio(
|
75 |
+
wav_path,
|
76 |
+
sampling_rate=self.config["sample_rate"],
|
77 |
+
volume_normalize=self.config["volume_normalize"],
|
78 |
+
)
|
79 |
+
|
80 |
+
wav_ref = self.get_ref_clip(wav)
|
81 |
+
|
82 |
+
wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float()
|
83 |
+
return wav, wav_ref
|
84 |
+
|
85 |
+
def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
|
86 |
+
"""extract wav2vec2 features"""
|
87 |
+
inputs = self.processor(
|
88 |
+
wavs,
|
89 |
+
sampling_rate=16000,
|
90 |
+
return_tensors="pt",
|
91 |
+
padding=True,
|
92 |
+
output_hidden_states=True,
|
93 |
+
).input_values
|
94 |
+
feat = self.feature_extractor(inputs.to(self.feature_extractor.device))
|
95 |
+
feats_mix = (
|
96 |
+
feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
|
97 |
+
) / 3
|
98 |
+
|
99 |
+
return feats_mix
|
100 |
+
|
101 |
+
def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor:
|
102 |
+
"""tokenize the batch of audio
|
103 |
+
|
104 |
+
Args:
|
105 |
+
batch:
|
106 |
+
wavs (List[np.ndarray]): batch of audio
|
107 |
+
ref_wavs (torch.Tensor): reference audio. shape: (batch_size, seq_len)
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
semantic_tokens: semantic tokens. shape: (batch_size, seq_len, latent_dim)
|
111 |
+
global_tokens: global tokens. shape: (batch_size, seq_len, global_dim)
|
112 |
+
"""
|
113 |
+
feats = self.extract_wav2vec2_features(batch["wav"])
|
114 |
+
batch["feat"] = feats
|
115 |
+
semantic_tokens, global_tokens = self.model.tokenize(batch)
|
116 |
+
|
117 |
+
return global_tokens, semantic_tokens
|
118 |
+
|
119 |
+
def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
120 |
+
"""tokenize the audio"""
|
121 |
+
wav, ref_wav = self.process_audio(audio_path)
|
122 |
+
feat = self.extract_wav2vec2_features(wav)
|
123 |
+
batch = {
|
124 |
+
"wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device),
|
125 |
+
"ref_wav": ref_wav.to(self.device),
|
126 |
+
"feat": feat.to(self.device),
|
127 |
+
}
|
128 |
+
semantic_tokens, global_tokens = self.model.tokenize(batch)
|
129 |
+
|
130 |
+
return global_tokens, semantic_tokens
|
131 |
+
|
132 |
+
def detokenize(
|
133 |
+
self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor
|
134 |
+
) -> np.array:
|
135 |
+
"""detokenize the tokens to waveform
|
136 |
+
|
137 |
+
Args:
|
138 |
+
global_tokens: global tokens. shape: (batch_size, global_dim)
|
139 |
+
semantic_tokens: semantic tokens. shape: (batch_size, latent_dim)
|
140 |
+
|
141 |
+
Returns:
|
142 |
+
wav_rec: waveform. shape: (batch_size, seq_len) for batch or (seq_len,) for single
|
143 |
+
"""
|
144 |
+
global_tokens = global_tokens.unsqueeze(1)
|
145 |
+
wav_rec = self.model.detokenize(semantic_tokens, global_tokens)
|
146 |
+
return wav_rec.detach().squeeze().cpu().numpy()
|
147 |
+
|
148 |
+
|
149 |
+
# test
|
150 |
+
if __name__ == "__main__":
|
151 |
+
import soundfile as sf
|
152 |
+
|
153 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
154 |
+
tokenizer = BiCodecTokenizer(
|
155 |
+
model_dir="pretrained_models/Spark-TTS-0.5B",
|
156 |
+
device=device,
|
157 |
+
)
|
158 |
+
wav_path = "example/prompt_audio.wav"
|
159 |
+
|
160 |
+
global_tokens, semantic_tokens = tokenizer.tokenize(wav_path)
|
161 |
+
|
162 |
+
wav_rec = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens)
|
163 |
+
sf.write("example/prompt_recon.wav", wav_rec, 16000)
|
sparktts/models/bicodec.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
from pathlib import Path
|
19 |
+
from typing import Dict, Any
|
20 |
+
from omegaconf import DictConfig
|
21 |
+
from safetensors.torch import load_file
|
22 |
+
|
23 |
+
from sparktts.utils.file import load_config
|
24 |
+
from sparktts.modules.speaker.speaker_encoder import SpeakerEncoder
|
25 |
+
from sparktts.modules.encoder_decoder.feat_encoder import Encoder
|
26 |
+
from sparktts.modules.encoder_decoder.feat_decoder import Decoder
|
27 |
+
from sparktts.modules.encoder_decoder.wave_generator import WaveGenerator
|
28 |
+
from sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize
|
29 |
+
|
30 |
+
|
31 |
+
class BiCodec(nn.Module):
|
32 |
+
"""
|
33 |
+
BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder,
|
34 |
+
quantizer, and wave generator.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
mel_params: Dict[str, Any],
|
40 |
+
encoder: nn.Module,
|
41 |
+
decoder: nn.Module,
|
42 |
+
quantizer: nn.Module,
|
43 |
+
speaker_encoder: nn.Module,
|
44 |
+
prenet: nn.Module,
|
45 |
+
postnet: nn.Module,
|
46 |
+
**kwargs
|
47 |
+
) -> None:
|
48 |
+
"""
|
49 |
+
Initializes the BiCodec model with the required components.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
mel_params (dict): Parameters for the mel-spectrogram transformer.
|
53 |
+
encoder (nn.Module): Encoder module.
|
54 |
+
decoder (nn.Module): Decoder module.
|
55 |
+
quantizer (nn.Module): Quantizer module.
|
56 |
+
speaker_encoder (nn.Module): Speaker encoder module.
|
57 |
+
prenet (nn.Module): Prenet network.
|
58 |
+
postnet (nn.Module): Postnet network.
|
59 |
+
"""
|
60 |
+
super().__init__()
|
61 |
+
self.encoder = encoder
|
62 |
+
self.decoder = decoder
|
63 |
+
self.quantizer = quantizer
|
64 |
+
self.speaker_encoder = speaker_encoder
|
65 |
+
self.prenet = prenet
|
66 |
+
self.postnet = postnet
|
67 |
+
self.init_mel_transformer(mel_params)
|
68 |
+
|
69 |
+
@classmethod
|
70 |
+
def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":
|
71 |
+
"""
|
72 |
+
Loads the model from a checkpoint.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
model_dir (Path): Path to the model directory containing checkpoint and config.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
BiCodec: The initialized BiCodec model.
|
79 |
+
"""
|
80 |
+
ckpt_path = f'{model_dir}/model.safetensors'
|
81 |
+
config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer']
|
82 |
+
mel_params = config["mel_params"]
|
83 |
+
encoder = Encoder(**config["encoder"])
|
84 |
+
quantizer = FactorizedVectorQuantize(**config["quantizer"])
|
85 |
+
prenet = Decoder(**config["prenet"])
|
86 |
+
postnet = Decoder(**config["postnet"])
|
87 |
+
decoder = WaveGenerator(**config["decoder"])
|
88 |
+
speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])
|
89 |
+
|
90 |
+
model = cls(
|
91 |
+
mel_params=mel_params,
|
92 |
+
encoder=encoder,
|
93 |
+
decoder=decoder,
|
94 |
+
quantizer=quantizer,
|
95 |
+
speaker_encoder=speaker_encoder,
|
96 |
+
prenet=prenet,
|
97 |
+
postnet=postnet,
|
98 |
+
)
|
99 |
+
|
100 |
+
state_dict = load_file(ckpt_path)
|
101 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
102 |
+
|
103 |
+
for key in missing_keys:
|
104 |
+
print(f"Missing tensor: {key}")
|
105 |
+
for key in unexpected_keys:
|
106 |
+
print(f"Unexpected tensor: {key}")
|
107 |
+
|
108 |
+
model.eval()
|
109 |
+
model.remove_weight_norm()
|
110 |
+
|
111 |
+
return model
|
112 |
+
|
113 |
+
def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
|
114 |
+
"""
|
115 |
+
Performs a forward pass through the model.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
batch (dict): A dictionary containing features, reference waveform, and target waveform.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
dict: A dictionary containing the reconstruction, features, and other metrics.
|
122 |
+
"""
|
123 |
+
feat = batch["feat"]
|
124 |
+
mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
|
125 |
+
|
126 |
+
z = self.encoder(feat.transpose(1, 2))
|
127 |
+
vq_outputs = self.quantizer(z)
|
128 |
+
|
129 |
+
x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2))
|
130 |
+
|
131 |
+
conditions = d_vector
|
132 |
+
with_speaker_loss = False
|
133 |
+
|
134 |
+
x = self.prenet(vq_outputs["z_q"], conditions)
|
135 |
+
pred_feat = self.postnet(x)
|
136 |
+
x = x + conditions.unsqueeze(-1)
|
137 |
+
wav_recon = self.decoder(x)
|
138 |
+
|
139 |
+
return {
|
140 |
+
"vq_loss": vq_outputs["vq_loss"],
|
141 |
+
"perplexity": vq_outputs["perplexity"],
|
142 |
+
"cluster_size": vq_outputs["active_num"],
|
143 |
+
"recons": wav_recon,
|
144 |
+
"pred_feat": pred_feat,
|
145 |
+
"x_vector": x_vector,
|
146 |
+
"d_vector": d_vector,
|
147 |
+
"audios": batch["wav"].unsqueeze(1),
|
148 |
+
"with_speaker_loss": with_speaker_loss,
|
149 |
+
}
|
150 |
+
|
151 |
+
@torch.no_grad()
|
152 |
+
def tokenize(self, batch: Dict[str, Any]):
|
153 |
+
"""
|
154 |
+
Tokenizes the input audio into semantic and global tokens.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
batch (dict): The input audio features and reference waveform.
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
tuple: Semantic tokens and global tokens.
|
161 |
+
"""
|
162 |
+
feat = batch["feat"]
|
163 |
+
mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
|
164 |
+
|
165 |
+
z = self.encoder(feat.transpose(1, 2))
|
166 |
+
semantic_tokens = self.quantizer.tokenize(z)
|
167 |
+
global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2))
|
168 |
+
|
169 |
+
return semantic_tokens, global_tokens
|
170 |
+
|
171 |
+
@torch.no_grad()
|
172 |
+
def detokenize(self, semantic_tokens, global_tokens):
|
173 |
+
"""
|
174 |
+
Detokenizes the semantic and global tokens into a waveform.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
semantic_tokens (tensor): Semantic tokens.
|
178 |
+
global_tokens (tensor): Global tokens.
|
179 |
+
|
180 |
+
Returns:
|
181 |
+
tensor: Reconstructed waveform.
|
182 |
+
"""
|
183 |
+
z_q = self.quantizer.detokenize(semantic_tokens)
|
184 |
+
d_vector = self.speaker_encoder.detokenize(global_tokens)
|
185 |
+
x = self.prenet(z_q, d_vector)
|
186 |
+
x = x + d_vector.unsqueeze(-1)
|
187 |
+
wav_recon = self.decoder(x)
|
188 |
+
|
189 |
+
return wav_recon
|
190 |
+
|
191 |
+
def init_mel_transformer(self, config: Dict[str, Any]):
|
192 |
+
"""
|
193 |
+
Initializes the MelSpectrogram transformer based on the provided configuration.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
config (dict): Configuration parameters for MelSpectrogram.
|
197 |
+
"""
|
198 |
+
import torchaudio.transforms as TT
|
199 |
+
|
200 |
+
self.mel_transformer = TT.MelSpectrogram(
|
201 |
+
config["sample_rate"],
|
202 |
+
config["n_fft"],
|
203 |
+
config["win_length"],
|
204 |
+
config["hop_length"],
|
205 |
+
config["mel_fmin"],
|
206 |
+
config["mel_fmax"],
|
207 |
+
n_mels=config["num_mels"],
|
208 |
+
power=1,
|
209 |
+
norm="slaney",
|
210 |
+
mel_scale="slaney",
|
211 |
+
)
|
212 |
+
|
213 |
+
def remove_weight_norm(self):
|
214 |
+
"""Removes weight normalization from all layers."""
|
215 |
+
def _remove_weight_norm(m):
|
216 |
+
try:
|
217 |
+
torch.nn.utils.remove_weight_norm(m)
|
218 |
+
except ValueError:
|
219 |
+
pass # The module didn't have weight norm
|
220 |
+
|
221 |
+
self.apply(_remove_weight_norm)
|
222 |
+
|
223 |
+
|
224 |
+
# Test the model
|
225 |
+
if __name__ == "__main__":
|
226 |
+
|
227 |
+
config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml")
|
228 |
+
model = BiCodec.load_from_checkpoint(
|
229 |
+
model_dir="pretrained_models/SparkTTS-0.5B/BiCodec",
|
230 |
+
)
|
231 |
+
|
232 |
+
# Generate random inputs for testing
|
233 |
+
duration = 0.96
|
234 |
+
x = torch.randn(20, 1, int(duration * 16000))
|
235 |
+
feat = torch.randn(20, int(duration * 50), 1024)
|
236 |
+
inputs = {"feat": feat, "wav": x, "ref_wav": x}
|
237 |
+
|
238 |
+
# Forward pass
|
239 |
+
outputs = model(inputs)
|
240 |
+
semantic_tokens, global_tokens = model.tokenize(inputs)
|
241 |
+
wav_recon = model.detokenize(semantic_tokens, global_tokens)
|
242 |
+
|
243 |
+
# Verify if the reconstruction matches
|
244 |
+
if torch.allclose(outputs["recons"].detach(), wav_recon):
|
245 |
+
print("Test successful")
|
246 |
+
else:
|
247 |
+
print("Test failed")
|
sparktts/modules/blocks/layers.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
|
17 |
+
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
from torch.nn.utils import weight_norm
|
22 |
+
|
23 |
+
|
24 |
+
def WNConv1d(*args, **kwargs):
|
25 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
26 |
+
|
27 |
+
|
28 |
+
def WNConvTranspose1d(*args, **kwargs):
|
29 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
30 |
+
|
31 |
+
|
32 |
+
# Scripting this brings model speed up 1.4x
|
33 |
+
@torch.jit.script
|
34 |
+
def snake(x, alpha):
|
35 |
+
shape = x.shape
|
36 |
+
x = x.reshape(shape[0], shape[1], -1)
|
37 |
+
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
38 |
+
x = x.reshape(shape)
|
39 |
+
return x
|
40 |
+
|
41 |
+
|
42 |
+
class Snake1d(nn.Module):
|
43 |
+
def __init__(self, channels):
|
44 |
+
super().__init__()
|
45 |
+
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
return snake(x, self.alpha)
|
49 |
+
|
50 |
+
|
51 |
+
class ResidualUnit(nn.Module):
|
52 |
+
def __init__(self, dim: int = 16, dilation: int = 1):
|
53 |
+
super().__init__()
|
54 |
+
pad = ((7 - 1) * dilation) // 2
|
55 |
+
self.block = nn.Sequential(
|
56 |
+
Snake1d(dim),
|
57 |
+
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
58 |
+
Snake1d(dim),
|
59 |
+
WNConv1d(dim, dim, kernel_size=1),
|
60 |
+
)
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
y = self.block(x)
|
64 |
+
pad = (x.shape[-1] - y.shape[-1]) // 2
|
65 |
+
if pad > 0:
|
66 |
+
x = x[..., pad:-pad]
|
67 |
+
return x + y
|
68 |
+
|
69 |
+
|
70 |
+
def init_weights(m):
|
71 |
+
if isinstance(m, nn.Conv1d):
|
72 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
73 |
+
nn.init.constant_(m.bias, 0)
|
sparktts/modules/blocks/samper.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
|
21 |
+
|
22 |
+
class SamplingBlock(nn.Module):
|
23 |
+
"""Sampling block for upsampling or downsampling"""
|
24 |
+
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
dim: int,
|
28 |
+
groups: int = 1,
|
29 |
+
upsample_scale: int = 1,
|
30 |
+
downsample_scale: int = 1,
|
31 |
+
) -> None:
|
32 |
+
"""
|
33 |
+
Args:
|
34 |
+
dim: input dimension
|
35 |
+
groups: number of groups
|
36 |
+
upsample_scale: upsampling scale
|
37 |
+
downsample_scale: downsampling scale
|
38 |
+
"""
|
39 |
+
super(SamplingBlock, self).__init__()
|
40 |
+
|
41 |
+
self.upsample_scale = upsample_scale
|
42 |
+
self.downsample_scale = downsample_scale
|
43 |
+
|
44 |
+
if self.upsample_scale > 1:
|
45 |
+
self.de_conv_upsampler = nn.Sequential(
|
46 |
+
nn.LeakyReLU(0.2),
|
47 |
+
nn.ConvTranspose1d(
|
48 |
+
dim,
|
49 |
+
dim,
|
50 |
+
kernel_size=upsample_scale * 2,
|
51 |
+
stride=upsample_scale,
|
52 |
+
padding=upsample_scale // 2 + upsample_scale % 2,
|
53 |
+
output_padding=upsample_scale % 2,
|
54 |
+
groups=groups,
|
55 |
+
),
|
56 |
+
)
|
57 |
+
|
58 |
+
if self.downsample_scale > 1:
|
59 |
+
self.conv_downsampler = nn.Sequential(
|
60 |
+
nn.LeakyReLU(0.2),
|
61 |
+
nn.Conv1d(
|
62 |
+
dim,
|
63 |
+
dim,
|
64 |
+
kernel_size=2 * downsample_scale,
|
65 |
+
stride=downsample_scale,
|
66 |
+
padding=downsample_scale // 2 + downsample_scale % 2,
|
67 |
+
groups=groups,
|
68 |
+
),
|
69 |
+
)
|
70 |
+
|
71 |
+
@staticmethod
|
72 |
+
def repeat_upsampler(x, upsample_scale):
|
73 |
+
return x.repeat_interleave(upsample_scale, dim=2)
|
74 |
+
|
75 |
+
@staticmethod
|
76 |
+
def skip_downsampler(x, downsample_scale):
|
77 |
+
return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale)
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
x = x.transpose(1, 2)
|
81 |
+
if self.upsample_scale > 1:
|
82 |
+
repeat_res = self.repeat_upsampler(x, self.upsample_scale)
|
83 |
+
deconv_res = self.de_conv_upsampler(x)
|
84 |
+
upmerge_res = repeat_res + deconv_res
|
85 |
+
else:
|
86 |
+
upmerge_res = x
|
87 |
+
repeat_res = x
|
88 |
+
|
89 |
+
if self.downsample_scale > 1:
|
90 |
+
conv_res = self.conv_downsampler(upmerge_res)
|
91 |
+
skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale)
|
92 |
+
skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale)
|
93 |
+
else:
|
94 |
+
conv_res = upmerge_res
|
95 |
+
skip2_res = upmerge_res
|
96 |
+
skip1_res = repeat_res
|
97 |
+
|
98 |
+
final_res = conv_res + skip1_res + skip2_res
|
99 |
+
|
100 |
+
return final_res
|
101 |
+
|
102 |
+
|
103 |
+
# test
|
104 |
+
if __name__ == "__main__":
|
105 |
+
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
106 |
+
model = SamplingBlock(1024, 1024, upsample_scale=2)
|
107 |
+
model_down = SamplingBlock(1024, 1024, downsample_scale=2)
|
108 |
+
output = model(test_input)
|
109 |
+
output_down = model_down(test_input)
|
110 |
+
print("shape after upsample * 2", output.shape) # torch.Size([8, 1024, 100])
|
111 |
+
print("shape after downsample * 2", output_down.shape) # torch.Size([8, 1024, 25])
|
112 |
+
if output.shape == torch.Size([8, 1024, 100]) and output_down.shape == torch.Size(
|
113 |
+
[8, 1024, 25]
|
114 |
+
):
|
115 |
+
print("test successful")
|
sparktts/modules/blocks/vocos.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from typing import Tuple
|
21 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
22 |
+
|
23 |
+
from typing import Optional
|
24 |
+
|
25 |
+
|
26 |
+
class ConvNeXtBlock(nn.Module):
|
27 |
+
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
dim (int): Number of input channels.
|
31 |
+
intermediate_dim (int): Dimensionality of the intermediate layer.
|
32 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
33 |
+
Defaults to None.
|
34 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
35 |
+
None means non-conditional LayerNorm. Defaults to None.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
dim: int,
|
41 |
+
intermediate_dim: int,
|
42 |
+
layer_scale_init_value: float,
|
43 |
+
condition_dim: Optional[int] = None,
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
self.dwconv = nn.Conv1d(
|
47 |
+
dim, dim, kernel_size=7, padding=3, groups=dim
|
48 |
+
) # depthwise conv
|
49 |
+
self.adanorm = condition_dim is not None
|
50 |
+
if condition_dim:
|
51 |
+
self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
|
52 |
+
else:
|
53 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
54 |
+
self.pwconv1 = nn.Linear(
|
55 |
+
dim, intermediate_dim
|
56 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
57 |
+
self.act = nn.GELU()
|
58 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
59 |
+
self.gamma = (
|
60 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
61 |
+
if layer_scale_init_value > 0
|
62 |
+
else None
|
63 |
+
)
|
64 |
+
|
65 |
+
def forward(
|
66 |
+
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
|
67 |
+
) -> torch.Tensor:
|
68 |
+
residual = x
|
69 |
+
x = self.dwconv(x)
|
70 |
+
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
71 |
+
if self.adanorm:
|
72 |
+
assert cond_embedding_id is not None
|
73 |
+
x = self.norm(x, cond_embedding_id)
|
74 |
+
else:
|
75 |
+
x = self.norm(x)
|
76 |
+
x = self.pwconv1(x)
|
77 |
+
x = self.act(x)
|
78 |
+
x = self.pwconv2(x)
|
79 |
+
if self.gamma is not None:
|
80 |
+
x = self.gamma * x
|
81 |
+
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
82 |
+
|
83 |
+
x = residual + x
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
class AdaLayerNorm(nn.Module):
|
88 |
+
"""
|
89 |
+
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
|
90 |
+
|
91 |
+
Args:
|
92 |
+
condition_dim (int): Dimension of the condition.
|
93 |
+
embedding_dim (int): Dimension of the embeddings.
|
94 |
+
"""
|
95 |
+
|
96 |
+
def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6):
|
97 |
+
super().__init__()
|
98 |
+
self.eps = eps
|
99 |
+
self.dim = embedding_dim
|
100 |
+
self.scale = nn.Linear(condition_dim, embedding_dim)
|
101 |
+
self.shift = nn.Linear(condition_dim, embedding_dim)
|
102 |
+
torch.nn.init.ones_(self.scale.weight)
|
103 |
+
torch.nn.init.zeros_(self.shift.weight)
|
104 |
+
|
105 |
+
def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor:
|
106 |
+
scale = self.scale(cond_embedding)
|
107 |
+
shift = self.shift(cond_embedding)
|
108 |
+
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
|
109 |
+
x = x * scale.unsqueeze(1) + shift.unsqueeze(1)
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class ResBlock1(nn.Module):
|
114 |
+
"""
|
115 |
+
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
|
116 |
+
but without upsampling layers.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
dim (int): Number of input channels.
|
120 |
+
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
|
121 |
+
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
|
122 |
+
Defaults to (1, 3, 5).
|
123 |
+
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
|
124 |
+
Defaults to 0.1.
|
125 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
126 |
+
Defaults to None.
|
127 |
+
"""
|
128 |
+
|
129 |
+
def __init__(
|
130 |
+
self,
|
131 |
+
dim: int,
|
132 |
+
kernel_size: int = 3,
|
133 |
+
dilation: Tuple[int, int, int] = (1, 3, 5),
|
134 |
+
lrelu_slope: float = 0.1,
|
135 |
+
layer_scale_init_value: Optional[float] = None,
|
136 |
+
):
|
137 |
+
super().__init__()
|
138 |
+
self.lrelu_slope = lrelu_slope
|
139 |
+
self.convs1 = nn.ModuleList(
|
140 |
+
[
|
141 |
+
weight_norm(
|
142 |
+
nn.Conv1d(
|
143 |
+
dim,
|
144 |
+
dim,
|
145 |
+
kernel_size,
|
146 |
+
1,
|
147 |
+
dilation=dilation[0],
|
148 |
+
padding=self.get_padding(kernel_size, dilation[0]),
|
149 |
+
)
|
150 |
+
),
|
151 |
+
weight_norm(
|
152 |
+
nn.Conv1d(
|
153 |
+
dim,
|
154 |
+
dim,
|
155 |
+
kernel_size,
|
156 |
+
1,
|
157 |
+
dilation=dilation[1],
|
158 |
+
padding=self.get_padding(kernel_size, dilation[1]),
|
159 |
+
)
|
160 |
+
),
|
161 |
+
weight_norm(
|
162 |
+
nn.Conv1d(
|
163 |
+
dim,
|
164 |
+
dim,
|
165 |
+
kernel_size,
|
166 |
+
1,
|
167 |
+
dilation=dilation[2],
|
168 |
+
padding=self.get_padding(kernel_size, dilation[2]),
|
169 |
+
)
|
170 |
+
),
|
171 |
+
]
|
172 |
+
)
|
173 |
+
|
174 |
+
self.convs2 = nn.ModuleList(
|
175 |
+
[
|
176 |
+
weight_norm(
|
177 |
+
nn.Conv1d(
|
178 |
+
dim,
|
179 |
+
dim,
|
180 |
+
kernel_size,
|
181 |
+
1,
|
182 |
+
dilation=1,
|
183 |
+
padding=self.get_padding(kernel_size, 1),
|
184 |
+
)
|
185 |
+
),
|
186 |
+
weight_norm(
|
187 |
+
nn.Conv1d(
|
188 |
+
dim,
|
189 |
+
dim,
|
190 |
+
kernel_size,
|
191 |
+
1,
|
192 |
+
dilation=1,
|
193 |
+
padding=self.get_padding(kernel_size, 1),
|
194 |
+
)
|
195 |
+
),
|
196 |
+
weight_norm(
|
197 |
+
nn.Conv1d(
|
198 |
+
dim,
|
199 |
+
dim,
|
200 |
+
kernel_size,
|
201 |
+
1,
|
202 |
+
dilation=1,
|
203 |
+
padding=self.get_padding(kernel_size, 1),
|
204 |
+
)
|
205 |
+
),
|
206 |
+
]
|
207 |
+
)
|
208 |
+
|
209 |
+
self.gamma = nn.ParameterList(
|
210 |
+
[
|
211 |
+
(
|
212 |
+
nn.Parameter(
|
213 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
214 |
+
)
|
215 |
+
if layer_scale_init_value is not None
|
216 |
+
else None
|
217 |
+
),
|
218 |
+
(
|
219 |
+
nn.Parameter(
|
220 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
221 |
+
)
|
222 |
+
if layer_scale_init_value is not None
|
223 |
+
else None
|
224 |
+
),
|
225 |
+
(
|
226 |
+
nn.Parameter(
|
227 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
228 |
+
)
|
229 |
+
if layer_scale_init_value is not None
|
230 |
+
else None
|
231 |
+
),
|
232 |
+
]
|
233 |
+
)
|
234 |
+
|
235 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
236 |
+
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
|
237 |
+
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
|
238 |
+
xt = c1(xt)
|
239 |
+
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
|
240 |
+
xt = c2(xt)
|
241 |
+
if gamma is not None:
|
242 |
+
xt = gamma * xt
|
243 |
+
x = xt + x
|
244 |
+
return x
|
245 |
+
|
246 |
+
def remove_weight_norm(self):
|
247 |
+
for l in self.convs1:
|
248 |
+
remove_weight_norm(l)
|
249 |
+
for l in self.convs2:
|
250 |
+
remove_weight_norm(l)
|
251 |
+
|
252 |
+
@staticmethod
|
253 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
254 |
+
return int((kernel_size * dilation - dilation) / 2)
|
255 |
+
|
256 |
+
|
257 |
+
class Backbone(nn.Module):
|
258 |
+
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
|
259 |
+
|
260 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
261 |
+
"""
|
262 |
+
Args:
|
263 |
+
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
|
264 |
+
C denotes output features, and L is the sequence length.
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
|
268 |
+
and H denotes the model dimension.
|
269 |
+
"""
|
270 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
271 |
+
|
272 |
+
|
273 |
+
class VocosBackbone(Backbone):
|
274 |
+
"""
|
275 |
+
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
|
276 |
+
|
277 |
+
Args:
|
278 |
+
input_channels (int): Number of input features channels.
|
279 |
+
dim (int): Hidden dimension of the model.
|
280 |
+
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
|
281 |
+
num_layers (int): Number of ConvNeXtBlock layers.
|
282 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
|
283 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
284 |
+
None means non-conditional model. Defaults to None.
|
285 |
+
"""
|
286 |
+
|
287 |
+
def __init__(
|
288 |
+
self,
|
289 |
+
input_channels: int,
|
290 |
+
dim: int,
|
291 |
+
intermediate_dim: int,
|
292 |
+
num_layers: int,
|
293 |
+
layer_scale_init_value: Optional[float] = None,
|
294 |
+
condition_dim: Optional[int] = None,
|
295 |
+
):
|
296 |
+
super().__init__()
|
297 |
+
self.input_channels = input_channels
|
298 |
+
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
|
299 |
+
self.adanorm = condition_dim is not None
|
300 |
+
if condition_dim:
|
301 |
+
self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
|
302 |
+
else:
|
303 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
304 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
|
305 |
+
self.convnext = nn.ModuleList(
|
306 |
+
[
|
307 |
+
ConvNeXtBlock(
|
308 |
+
dim=dim,
|
309 |
+
intermediate_dim=intermediate_dim,
|
310 |
+
layer_scale_init_value=layer_scale_init_value,
|
311 |
+
condition_dim=condition_dim,
|
312 |
+
)
|
313 |
+
for _ in range(num_layers)
|
314 |
+
]
|
315 |
+
)
|
316 |
+
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
|
317 |
+
self.apply(self._init_weights)
|
318 |
+
|
319 |
+
def _init_weights(self, m):
|
320 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
321 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
322 |
+
nn.init.constant_(m.bias, 0)
|
323 |
+
|
324 |
+
def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
|
325 |
+
x = self.embed(x)
|
326 |
+
if self.adanorm:
|
327 |
+
assert condition is not None
|
328 |
+
x = self.norm(x.transpose(1, 2), condition)
|
329 |
+
else:
|
330 |
+
x = self.norm(x.transpose(1, 2))
|
331 |
+
x = x.transpose(1, 2)
|
332 |
+
for conv_block in self.convnext:
|
333 |
+
x = conv_block(x, condition)
|
334 |
+
x = self.final_layer_norm(x.transpose(1, 2))
|
335 |
+
return x
|
336 |
+
|
337 |
+
|
338 |
+
class VocosResNetBackbone(Backbone):
|
339 |
+
"""
|
340 |
+
Vocos backbone module built with ResBlocks.
|
341 |
+
|
342 |
+
Args:
|
343 |
+
input_channels (int): Number of input features channels.
|
344 |
+
dim (int): Hidden dimension of the model.
|
345 |
+
num_blocks (int): Number of ResBlock1 blocks.
|
346 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
|
347 |
+
"""
|
348 |
+
|
349 |
+
def __init__(
|
350 |
+
self,
|
351 |
+
input_channels,
|
352 |
+
dim,
|
353 |
+
num_blocks,
|
354 |
+
layer_scale_init_value=None,
|
355 |
+
):
|
356 |
+
super().__init__()
|
357 |
+
self.input_channels = input_channels
|
358 |
+
self.embed = weight_norm(
|
359 |
+
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
|
360 |
+
)
|
361 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
|
362 |
+
self.resnet = nn.Sequential(
|
363 |
+
*[
|
364 |
+
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
|
365 |
+
for _ in range(num_blocks)
|
366 |
+
]
|
367 |
+
)
|
368 |
+
|
369 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
370 |
+
x = self.embed(x)
|
371 |
+
x = self.resnet(x)
|
372 |
+
x = x.transpose(1, 2)
|
373 |
+
return x
|
sparktts/modules/encoder_decoder/feat_decoder.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from typing import List
|
21 |
+
|
22 |
+
from sparktts.modules.blocks.vocos import VocosBackbone
|
23 |
+
from sparktts.modules.blocks.samper import SamplingBlock
|
24 |
+
|
25 |
+
|
26 |
+
class Decoder(nn.Module):
|
27 |
+
"""Decoder module with convnext and upsampling blocks
|
28 |
+
|
29 |
+
Args:
|
30 |
+
sample_ratios (List[int]): sample ratios
|
31 |
+
example: [2, 2] means downsample by 2x and then upsample by 2x
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
input_channels: int,
|
37 |
+
vocos_dim: int,
|
38 |
+
vocos_intermediate_dim: int,
|
39 |
+
vocos_num_layers: int,
|
40 |
+
out_channels: int,
|
41 |
+
condition_dim: int = None,
|
42 |
+
sample_ratios: List[int] = [1, 1],
|
43 |
+
use_tanh_at_final: bool = False,
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
|
47 |
+
self.linear_pre = nn.Linear(input_channels, vocos_dim)
|
48 |
+
modules = [
|
49 |
+
nn.Sequential(
|
50 |
+
SamplingBlock(
|
51 |
+
dim=vocos_dim,
|
52 |
+
groups=vocos_dim,
|
53 |
+
upsample_scale=ratio,
|
54 |
+
),
|
55 |
+
VocosBackbone(
|
56 |
+
input_channels=vocos_dim,
|
57 |
+
dim=vocos_dim,
|
58 |
+
intermediate_dim=vocos_intermediate_dim,
|
59 |
+
num_layers=2,
|
60 |
+
condition_dim=None,
|
61 |
+
),
|
62 |
+
)
|
63 |
+
for ratio in sample_ratios
|
64 |
+
]
|
65 |
+
|
66 |
+
self.downsample = nn.Sequential(*modules)
|
67 |
+
|
68 |
+
self.vocos_backbone = VocosBackbone(
|
69 |
+
input_channels=vocos_dim,
|
70 |
+
dim=vocos_dim,
|
71 |
+
intermediate_dim=vocos_intermediate_dim,
|
72 |
+
num_layers=vocos_num_layers,
|
73 |
+
condition_dim=condition_dim,
|
74 |
+
)
|
75 |
+
self.linear = nn.Linear(vocos_dim, out_channels)
|
76 |
+
self.use_tanh_at_final = use_tanh_at_final
|
77 |
+
|
78 |
+
def forward(self, x: torch.Tensor, c: torch.Tensor = None):
|
79 |
+
"""encoder forward.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
x (torch.Tensor): (batch_size, input_channels, length)
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
x (torch.Tensor): (batch_size, encode_channels, length)
|
86 |
+
"""
|
87 |
+
x = self.linear_pre(x.transpose(1, 2))
|
88 |
+
x = self.downsample(x).transpose(1, 2)
|
89 |
+
x = self.vocos_backbone(x, condition=c)
|
90 |
+
x = self.linear(x).transpose(1, 2)
|
91 |
+
if self.use_tanh_at_final:
|
92 |
+
x = torch.tanh(x)
|
93 |
+
|
94 |
+
return x
|
95 |
+
|
96 |
+
|
97 |
+
# test
|
98 |
+
if __name__ == "__main__":
|
99 |
+
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
100 |
+
condition = torch.randn(8, 256)
|
101 |
+
decoder = Decoder(
|
102 |
+
input_channels=1024,
|
103 |
+
vocos_dim=384,
|
104 |
+
vocos_intermediate_dim=2048,
|
105 |
+
vocos_num_layers=12,
|
106 |
+
out_channels=256,
|
107 |
+
condition_dim=256,
|
108 |
+
sample_ratios=[2, 2],
|
109 |
+
)
|
110 |
+
output = decoder(test_input, condition)
|
111 |
+
print(output.shape) # torch.Size([8, 256, 200])
|
112 |
+
if output.shape == torch.Size([8, 256, 200]):
|
113 |
+
print("Decoder test passed")
|
114 |
+
else:
|
115 |
+
print("Decoder test failed")
|
sparktts/modules/encoder_decoder/feat_encoder.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from typing import List
|
21 |
+
|
22 |
+
from sparktts.modules.blocks.vocos import VocosBackbone
|
23 |
+
from sparktts.modules.blocks.samper import SamplingBlock
|
24 |
+
|
25 |
+
|
26 |
+
class Encoder(nn.Module):
|
27 |
+
"""Encoder module with convnext and downsampling blocks"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
input_channels: int,
|
32 |
+
vocos_dim: int,
|
33 |
+
vocos_intermediate_dim: int,
|
34 |
+
vocos_num_layers: int,
|
35 |
+
out_channels: int,
|
36 |
+
sample_ratios: List[int] = [1, 1],
|
37 |
+
):
|
38 |
+
super().__init__()
|
39 |
+
"""
|
40 |
+
Encoder module with VocosBackbone and sampling blocks.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
sample_ratios (List[int]): sample ratios
|
44 |
+
example: [2, 2] means downsample by 2x and then upsample by 2x
|
45 |
+
"""
|
46 |
+
self.encoder = VocosBackbone(
|
47 |
+
input_channels=input_channels,
|
48 |
+
dim=vocos_dim,
|
49 |
+
intermediate_dim=vocos_intermediate_dim,
|
50 |
+
num_layers=vocos_num_layers,
|
51 |
+
condition_dim=None,
|
52 |
+
)
|
53 |
+
|
54 |
+
modules = [
|
55 |
+
nn.Sequential(
|
56 |
+
SamplingBlock(
|
57 |
+
dim=vocos_dim,
|
58 |
+
groups=vocos_dim,
|
59 |
+
downsample_scale=ratio,
|
60 |
+
),
|
61 |
+
VocosBackbone(
|
62 |
+
input_channels=vocos_dim,
|
63 |
+
dim=vocos_dim,
|
64 |
+
intermediate_dim=vocos_intermediate_dim,
|
65 |
+
num_layers=2,
|
66 |
+
condition_dim=None,
|
67 |
+
),
|
68 |
+
)
|
69 |
+
for ratio in sample_ratios
|
70 |
+
]
|
71 |
+
|
72 |
+
self.downsample = nn.Sequential(*modules)
|
73 |
+
|
74 |
+
self.project = nn.Linear(vocos_dim, out_channels)
|
75 |
+
|
76 |
+
def forward(self, x: torch.Tensor, *args):
|
77 |
+
"""
|
78 |
+
Args:
|
79 |
+
x (torch.Tensor): (batch_size, input_channels, length)
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
x (torch.Tensor): (batch_size, encode_channels, length)
|
83 |
+
"""
|
84 |
+
x = self.encoder(x)
|
85 |
+
x = self.downsample(x)
|
86 |
+
x = self.project(x)
|
87 |
+
return x.transpose(1, 2)
|
88 |
+
|
89 |
+
|
90 |
+
# test
|
91 |
+
if __name__ == "__main__":
|
92 |
+
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
93 |
+
encoder = Encoder(
|
94 |
+
input_channels=1024,
|
95 |
+
vocos_dim=384,
|
96 |
+
vocos_intermediate_dim=2048,
|
97 |
+
vocos_num_layers=12,
|
98 |
+
out_channels=256,
|
99 |
+
sample_ratios=[2, 2],
|
100 |
+
)
|
101 |
+
|
102 |
+
output = encoder(test_input)
|
103 |
+
print(output.shape) # torch.Size([8, 256, 12])
|
104 |
+
if output.shape == torch.Size([8, 256, 12]):
|
105 |
+
print("test successful")
|
sparktts/modules/encoder_decoder/wave_generator.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Xinsheng Wang ([email protected])
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
|
16 |
+
|
17 |
+
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from sparktts.modules.blocks.layers import (
|
21 |
+
Snake1d,
|
22 |
+
WNConv1d,
|
23 |
+
ResidualUnit,
|
24 |
+
WNConvTranspose1d,
|
25 |
+
init_weights,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
class DecoderBlock(nn.Module):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
input_dim: int = 16,
|
33 |
+
output_dim: int = 8,
|
34 |
+
kernel_size: int = 2,
|
35 |
+
stride: int = 1,
|
36 |
+
):
|
37 |
+
super().__init__()
|
38 |
+
self.block = nn.Sequential(
|
39 |
+
Snake1d(input_dim),
|
40 |
+
WNConvTranspose1d(
|
41 |
+
input_dim,
|
42 |
+
output_dim,
|
43 |
+
kernel_size=kernel_size,
|
44 |
+
stride=stride,
|
45 |
+
padding=(kernel_size - stride) // 2,
|
46 |
+
),
|
47 |
+
ResidualUnit(output_dim, dilation=1),
|
48 |
+
ResidualUnit(output_dim, dilation=3),
|
49 |
+
ResidualUnit(output_dim, dilation=9),
|
50 |
+
)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
return self.block(x)
|
54 |
+
|
55 |
+
|
56 |
+
class WaveGenerator(nn.Module):
|
57 |
+
def __init__(
|
58 |
+
self,
|
59 |
+
input_channel,
|
60 |
+
channels,
|
61 |
+
rates,
|
62 |
+
kernel_sizes,
|
63 |
+
d_out: int = 1,
|
64 |
+
):
|
65 |
+
super().__init__()
|
66 |
+
|
67 |
+
# Add first conv layer
|
68 |
+
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
69 |
+
|
70 |
+
# Add upsampling + MRF blocks
|
71 |
+
for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
|
72 |
+
input_dim = channels // 2**i
|
73 |
+
output_dim = channels // 2 ** (i + 1)
|
74 |
+
layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]
|
75 |
+
|
76 |
+
# Add final conv layer
|
77 |
+
layers += [
|
78 |
+
Snake1d(output_dim),
|
79 |
+
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
80 |
+
nn.Tanh(),
|
81 |
+
]
|
82 |
+
|
83 |
+
self.model = nn.Sequential(*layers)
|
84 |
+
|
85 |
+
self.apply(init_weights)
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
return self.model(x)
|
sparktts/modules/fsq/finite_scalar_quantization.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
|
3 |
+
Code adapted from Jax version in Appendix A.1
|
4 |
+
"""
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
from functools import wraps, partial
|
8 |
+
from contextlib import nullcontext
|
9 |
+
from typing import List, Tuple
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.nn import Module
|
14 |
+
from torch import Tensor, int32
|
15 |
+
from torch.amp import autocast
|
16 |
+
|
17 |
+
from einops import rearrange, pack, unpack
|
18 |
+
|
19 |
+
# helper functions
|
20 |
+
|
21 |
+
|
22 |
+
def exists(v):
|
23 |
+
return v is not None
|
24 |
+
|
25 |
+
|
26 |
+
def default(*args):
|
27 |
+
for arg in args:
|
28 |
+
if exists(arg):
|
29 |
+
return arg
|
30 |
+
return None
|
31 |
+
|
32 |
+
|
33 |
+
def maybe(fn):
|
34 |
+
@wraps(fn)
|
35 |
+
def inner(x, *args, **kwargs):
|
36 |
+
if not exists(x):
|
37 |
+
return x
|
38 |
+
return fn(x, *args, **kwargs)
|
39 |
+
|
40 |
+
return inner
|
41 |
+
|
42 |
+
|
43 |
+
def pack_one(t, pattern):
|
44 |
+
return pack([t], pattern)
|
45 |
+
|
46 |
+
|
47 |
+
def unpack_one(t, ps, pattern):
|
48 |
+
return unpack(t, ps, pattern)[0]
|
49 |
+
|
50 |
+
|
51 |
+
# tensor helpers
|
52 |
+
|
53 |
+
|
54 |
+
def round_ste(z: Tensor) -> Tensor:
|
55 |
+
"""Round with straight through gradients."""
|
56 |
+
zhat = z.round()
|
57 |
+
return z + (zhat - z).detach()
|
58 |
+
|
59 |
+
|
60 |
+
# main class
|
61 |
+
|
62 |
+
|
63 |
+
class FSQ(Module):
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
levels: List[int],
|
67 |
+
dim: int | None = None,
|
68 |
+
num_codebooks=1,
|
69 |
+
keep_num_codebooks_dim: bool | None = None,
|
70 |
+
scale: float | None = None,
|
71 |
+
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
|
72 |
+
channel_first: bool = False,
|
73 |
+
projection_has_bias: bool = True,
|
74 |
+
return_indices=True,
|
75 |
+
force_quantization_f32=True,
|
76 |
+
):
|
77 |
+
super().__init__()
|
78 |
+
_levels = torch.tensor(levels, dtype=int32)
|
79 |
+
self.register_buffer("_levels", _levels, persistent=False)
|
80 |
+
|
81 |
+
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
|
82 |
+
self.register_buffer("_basis", _basis, persistent=False)
|
83 |
+
|
84 |
+
self.scale = scale
|
85 |
+
|
86 |
+
codebook_dim = len(levels)
|
87 |
+
self.codebook_dim = codebook_dim
|
88 |
+
|
89 |
+
effective_codebook_dim = codebook_dim * num_codebooks
|
90 |
+
self.num_codebooks = num_codebooks
|
91 |
+
self.effective_codebook_dim = effective_codebook_dim
|
92 |
+
|
93 |
+
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
|
94 |
+
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
|
95 |
+
self.keep_num_codebooks_dim = keep_num_codebooks_dim
|
96 |
+
|
97 |
+
self.dim = default(dim, len(_levels) * num_codebooks)
|
98 |
+
|
99 |
+
self.channel_first = channel_first
|
100 |
+
|
101 |
+
has_projections = self.dim != effective_codebook_dim
|
102 |
+
self.project_in = (
|
103 |
+
nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
|
104 |
+
if has_projections
|
105 |
+
else nn.Identity()
|
106 |
+
)
|
107 |
+
self.project_out = (
|
108 |
+
nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
|
109 |
+
if has_projections
|
110 |
+
else nn.Identity()
|
111 |
+
)
|
112 |
+
|
113 |
+
self.has_projections = has_projections
|
114 |
+
|
115 |
+
self.return_indices = return_indices
|
116 |
+
if return_indices:
|
117 |
+
self.codebook_size = self._levels.prod().item()
|
118 |
+
implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
|
119 |
+
self.register_buffer(
|
120 |
+
"implicit_codebook", implicit_codebook, persistent=False
|
121 |
+
)
|
122 |
+
|
123 |
+
self.allowed_dtypes = allowed_dtypes
|
124 |
+
self.force_quantization_f32 = force_quantization_f32
|
125 |
+
|
126 |
+
def bound(self, z, eps: float = 1e-3):
|
127 |
+
"""Bound `z`, an array of shape (..., d)."""
|
128 |
+
half_l = (self._levels - 1) * (1 + eps) / 2
|
129 |
+
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
|
130 |
+
shift = (offset / half_l).atanh()
|
131 |
+
return (z + shift).tanh() * half_l - offset
|
132 |
+
|
133 |
+
def quantize(self, z):
|
134 |
+
"""Quantizes z, returns quantized zhat, same shape as z."""
|
135 |
+
quantized = round_ste(self.bound(z))
|
136 |
+
half_width = self._levels // 2 # Renormalize to [-1, 1].
|
137 |
+
return quantized / half_width
|
138 |
+
|
139 |
+
def _scale_and_shift(self, zhat_normalized):
|
140 |
+
half_width = self._levels // 2
|
141 |
+
return (zhat_normalized * half_width) + half_width
|
142 |
+
|
143 |
+
def _scale_and_shift_inverse(self, zhat):
|
144 |
+
half_width = self._levels // 2
|
145 |
+
return (zhat - half_width) / half_width
|
146 |
+
|
147 |
+
def _indices_to_codes(self, indices):
|
148 |
+
level_indices = self.indices_to_level_indices(indices)
|
149 |
+
codes = self._scale_and_shift_inverse(level_indices)
|
150 |
+
return codes
|
151 |
+
|
152 |
+
def codes_to_indices(self, zhat):
|
153 |
+
"""Converts a `code` to an index in the codebook."""
|
154 |
+
assert zhat.shape[-1] == self.codebook_dim
|
155 |
+
zhat = self._scale_and_shift(zhat)
|
156 |
+
return (zhat * self._basis).sum(dim=-1).to(int32)
|
157 |
+
|
158 |
+
def indices_to_level_indices(self, indices):
|
159 |
+
"""Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings"""
|
160 |
+
indices = rearrange(indices, "... -> ... 1")
|
161 |
+
codes_non_centered = (indices // self._basis) % self._levels
|
162 |
+
return codes_non_centered
|
163 |
+
|
164 |
+
def indices_to_codes(self, indices):
|
165 |
+
"""Inverse of `codes_to_indices`."""
|
166 |
+
assert exists(indices)
|
167 |
+
|
168 |
+
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
|
169 |
+
|
170 |
+
codes = self._indices_to_codes(indices)
|
171 |
+
|
172 |
+
if self.keep_num_codebooks_dim:
|
173 |
+
codes = rearrange(codes, "... c d -> ... (c d)")
|
174 |
+
|
175 |
+
codes = self.project_out(codes)
|
176 |
+
|
177 |
+
if is_img_or_video or self.channel_first:
|
178 |
+
codes = rearrange(codes, "b ... d -> b d ...")
|
179 |
+
|
180 |
+
return codes
|
181 |
+
|
182 |
+
def forward(self, z):
|
183 |
+
"""
|
184 |
+
einstein notation
|
185 |
+
b - batch
|
186 |
+
n - sequence (or flattened spatial dimensions)
|
187 |
+
d - feature dimension
|
188 |
+
c - number of codebook dim
|
189 |
+
"""
|
190 |
+
|
191 |
+
is_img_or_video = z.ndim >= 4
|
192 |
+
need_move_channel_last = is_img_or_video or self.channel_first
|
193 |
+
|
194 |
+
# standardize image or video into (batch, seq, dimension)
|
195 |
+
|
196 |
+
if need_move_channel_last:
|
197 |
+
z = rearrange(z, "b d ... -> b ... d")
|
198 |
+
z, ps = pack_one(z, "b * d")
|
199 |
+
|
200 |
+
assert (
|
201 |
+
z.shape[-1] == self.dim
|
202 |
+
), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
|
203 |
+
|
204 |
+
z = self.project_in(z)
|
205 |
+
|
206 |
+
z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
|
207 |
+
|
208 |
+
# whether to force quantization step to be full precision or not
|
209 |
+
|
210 |
+
force_f32 = self.force_quantization_f32
|
211 |
+
quantization_context = (
|
212 |
+
partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext
|
213 |
+
)
|
214 |
+
|
215 |
+
with quantization_context():
|
216 |
+
orig_dtype = z.dtype
|
217 |
+
|
218 |
+
if force_f32 and orig_dtype not in self.allowed_dtypes:
|
219 |
+
z = z.float()
|
220 |
+
|
221 |
+
codes = self.quantize(z)
|
222 |
+
|
223 |
+
# returning indices could be optional
|
224 |
+
|
225 |
+
indices = None
|
226 |
+
|
227 |
+
if self.return_indices:
|
228 |
+
indices = self.codes_to_indices(codes)
|
229 |
+
|
230 |
+
codes = rearrange(codes, "b n c d -> b n (c d)")
|
231 |
+
|
232 |
+
codes = codes.type(orig_dtype)
|
233 |
+
|
234 |
+
# project out
|
235 |
+
|
236 |
+
out = self.project_out(codes)
|
237 |
+
|
238 |
+
# reconstitute image or video dimensions
|
239 |
+
|
240 |
+
if need_move_channel_last:
|
241 |
+
out = unpack_one(out, ps, "b * d")
|
242 |
+
out = rearrange(out, "b ... d -> b d ...")
|
243 |
+
|
244 |
+
indices = maybe(unpack_one)(indices, ps, "b * c")
|
245 |
+
|
246 |
+
if not self.keep_num_codebooks_dim and self.return_indices:
|
247 |
+
indices = maybe(rearrange)(indices, "... 1 -> ...")
|
248 |
+
|
249 |
+
# return quantized output and indices
|
250 |
+
|
251 |
+
return out, indices
|
sparktts/modules/fsq/residual_fsq.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch.distributed as dist
|
5 |
+
|
6 |
+
from typing import List
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import Module
|
9 |
+
from torch.amp import autocast
|
10 |
+
from einx import get_at
|
11 |
+
from einops import rearrange, reduce, pack, unpack
|
12 |
+
|
13 |
+
from sparktts.modules.fsq.finite_scalar_quantization import FSQ
|
14 |
+
|
15 |
+
|
16 |
+
def exists(val):
|
17 |
+
return val is not None
|
18 |
+
|
19 |
+
|
20 |
+
def first(l):
|
21 |
+
return l[0]
|
22 |
+
|
23 |
+
|
24 |
+
def default(val, d):
|
25 |
+
return val if exists(val) else d
|
26 |
+
|
27 |
+
|
28 |
+
def round_up_multiple(num, mult):
|
29 |
+
return ceil(num / mult) * mult
|
30 |
+
|
31 |
+
|
32 |
+
# distributed helpers
|
33 |
+
|
34 |
+
|
35 |
+
def is_distributed():
|
36 |
+
return dist.is_initialized() and dist.get_world_size() > 1
|
37 |
+
|
38 |
+
|
39 |
+
def get_maybe_sync_seed(device, max_size=10_000):
|
40 |
+
rand_int = torch.randint(0, max_size, (), device=device)
|
41 |
+
|
42 |
+
if is_distributed():
|
43 |
+
dist.all_reduce(rand_int)
|
44 |
+
|
45 |
+
return rand_int.item()
|
46 |
+
|
47 |
+
|
48 |
+
class ResidualFSQ(Module):
|
49 |
+
"""Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
|
50 |
+
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
*,
|
54 |
+
levels: List[int],
|
55 |
+
num_quantizers,
|
56 |
+
dim=None,
|
57 |
+
is_channel_first=False,
|
58 |
+
quantize_dropout=False,
|
59 |
+
quantize_dropout_cutoff_index=0,
|
60 |
+
quantize_dropout_multiple_of=1,
|
61 |
+
**kwargs,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
codebook_dim = len(levels)
|
65 |
+
dim = default(dim, codebook_dim)
|
66 |
+
|
67 |
+
requires_projection = codebook_dim != dim
|
68 |
+
self.project_in = (
|
69 |
+
nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
|
70 |
+
)
|
71 |
+
self.project_out = (
|
72 |
+
nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
|
73 |
+
)
|
74 |
+
self.has_projections = requires_projection
|
75 |
+
|
76 |
+
self.is_channel_first = is_channel_first
|
77 |
+
self.num_quantizers = num_quantizers
|
78 |
+
|
79 |
+
self.levels = levels
|
80 |
+
self.layers = nn.ModuleList([])
|
81 |
+
|
82 |
+
levels_tensor = torch.Tensor(levels)
|
83 |
+
|
84 |
+
scales = []
|
85 |
+
|
86 |
+
for ind in range(num_quantizers):
|
87 |
+
scales.append((levels_tensor - 1) ** -ind)
|
88 |
+
|
89 |
+
fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs)
|
90 |
+
|
91 |
+
self.layers.append(fsq)
|
92 |
+
|
93 |
+
assert all([not fsq.has_projections for fsq in self.layers])
|
94 |
+
|
95 |
+
self.codebook_size = self.layers[0].codebook_size
|
96 |
+
|
97 |
+
self.register_buffer("scales", torch.stack(scales), persistent=False)
|
98 |
+
|
99 |
+
self.quantize_dropout = quantize_dropout and num_quantizers > 1
|
100 |
+
|
101 |
+
assert quantize_dropout_cutoff_index >= 0
|
102 |
+
|
103 |
+
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
|
104 |
+
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
|
105 |
+
|
106 |
+
@property
|
107 |
+
def codebooks(self):
|
108 |
+
codebooks = [layer.implicit_codebook for layer in self.layers]
|
109 |
+
codebooks = torch.stack(codebooks, dim=0)
|
110 |
+
return codebooks
|
111 |
+
|
112 |
+
def get_codes_from_indices(self, indices):
|
113 |
+
|
114 |
+
batch, quantize_dim = indices.shape[0], indices.shape[-1]
|
115 |
+
|
116 |
+
# may also receive indices in the shape of 'b h w q' (accept_image_fmap)
|
117 |
+
|
118 |
+
indices, ps = pack([indices], "b * q")
|
119 |
+
|
120 |
+
# because of quantize dropout, one can pass in indices that are coarse
|
121 |
+
# and the network should be able to reconstruct
|
122 |
+
|
123 |
+
if quantize_dim < self.num_quantizers:
|
124 |
+
assert (
|
125 |
+
self.quantize_dropout > 0.0
|
126 |
+
), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
|
127 |
+
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
|
128 |
+
|
129 |
+
# take care of quantizer dropout
|
130 |
+
|
131 |
+
mask = indices == -1
|
132 |
+
indices = indices.masked_fill(
|
133 |
+
mask, 0
|
134 |
+
) # have it fetch a dummy code to be masked out later
|
135 |
+
|
136 |
+
all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices)
|
137 |
+
|
138 |
+
# mask out any codes that were dropout-ed
|
139 |
+
|
140 |
+
all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0)
|
141 |
+
|
142 |
+
# scale the codes
|
143 |
+
|
144 |
+
scales = rearrange(self.scales, "q d -> q 1 1 d")
|
145 |
+
all_codes = all_codes * scales
|
146 |
+
|
147 |
+
# if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
|
148 |
+
|
149 |
+
(all_codes,) = unpack(all_codes, ps, "q b * d")
|
150 |
+
|
151 |
+
return all_codes
|
152 |
+
|
153 |
+
def get_output_from_indices(self, indices):
|
154 |
+
codes = self.get_codes_from_indices(indices)
|
155 |
+
codes_summed = reduce(codes, "q ... -> ...", "sum")
|
156 |
+
return self.project_out(codes_summed)
|
157 |
+
|
158 |
+
def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None):
|
159 |
+
num_quant, quant_dropout_multiple_of, device = (
|
160 |
+
self.num_quantizers,
|
161 |
+
self.quantize_dropout_multiple_of,
|
162 |
+
x.device,
|
163 |
+
)
|
164 |
+
|
165 |
+
# handle channel first
|
166 |
+
|
167 |
+
if self.is_channel_first:
|
168 |
+
x = rearrange(x, "b d ... -> b ... d")
|
169 |
+
x, ps = pack([x], "b * d")
|
170 |
+
|
171 |
+
# maybe project in
|
172 |
+
|
173 |
+
x = self.project_in(x)
|
174 |
+
|
175 |
+
quantized_out = 0.0
|
176 |
+
residual = x
|
177 |
+
|
178 |
+
all_indices = []
|
179 |
+
|
180 |
+
should_quantize_dropout = self.training and self.quantize_dropout
|
181 |
+
|
182 |
+
# sample a layer index at which to dropout further residual quantization
|
183 |
+
# also prepare null indices
|
184 |
+
|
185 |
+
if should_quantize_dropout:
|
186 |
+
|
187 |
+
# check if seed is manually passed in
|
188 |
+
|
189 |
+
if not exists(rand_quantize_dropout_fixed_seed):
|
190 |
+
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
|
191 |
+
|
192 |
+
rand = random.Random(rand_quantize_dropout_fixed_seed)
|
193 |
+
|
194 |
+
rand_quantize_dropout_index = rand.randrange(
|
195 |
+
self.quantize_dropout_cutoff_index, num_quant
|
196 |
+
)
|
197 |
+
|
198 |
+
if quant_dropout_multiple_of != 1:
|
199 |
+
rand_quantize_dropout_index = (
|
200 |
+
round_up_multiple(
|
201 |
+
rand_quantize_dropout_index + 1, quant_dropout_multiple_of
|
202 |
+
)
|
203 |
+
- 1
|
204 |
+
)
|
205 |
+
|
206 |
+
null_indices = torch.full(
|
207 |
+
x.shape[:2], -1.0, device=device, dtype=torch.long
|
208 |
+
)
|
209 |
+
|
210 |
+
# go through the layers
|
211 |
+
|
212 |
+
with autocast("cuda", enabled=False):
|
213 |
+
for quantizer_index, (layer, scale) in enumerate(
|
214 |
+
zip(self.layers, self.scales)
|
215 |
+
):
|
216 |
+
|
217 |
+
if (
|
218 |
+
should_quantize_dropout
|
219 |
+
and quantizer_index > rand_quantize_dropout_index
|
220 |
+
):
|
221 |
+
all_indices.append(null_indices)
|
222 |
+
continue
|
223 |
+
|
224 |
+
quantized, indices = layer(residual / scale)
|
225 |
+
|
226 |
+
quantized = quantized * scale
|
227 |
+
|
228 |
+
residual = residual - quantized.detach()
|
229 |
+
quantized_out = quantized_out + quantized
|
230 |
+
|
231 |
+
all_indices.append(indices)
|
232 |
+
|
233 |
+
# project out, if needed
|
234 |
+
|
235 |
+
quantized_out = self.project_out(quantized_out)
|
236 |
+
|
237 |
+
# stack all indices
|
238 |
+
|
239 |
+
all_indices = torch.stack(all_indices, dim=-1)
|
240 |
+
|
241 |
+
# channel first out
|
242 |
+
|
243 |
+
if self.is_channel_first:
|
244 |
+
(quantized_out,) = unpack(quantized_out, ps, "b * d")
|
245 |
+
(all_indices,) = unpack(all_indices, ps, "b * d")
|
246 |
+
|
247 |
+
quantized_out = rearrange(quantized_out, "b ... d -> b d ...")
|
248 |
+
all_indices = rearrange(all_indices, "b ... d -> b d ...")
|
249 |
+
|
250 |
+
# return
|
251 |
+
|
252 |
+
ret = (quantized_out, all_indices)
|
253 |
+
|
254 |
+
if not return_all_codes:
|
255 |
+
return ret
|
256 |
+
|
257 |
+
# whether to return all codes from all codebooks across layers
|
258 |
+
|
259 |
+
all_codes = self.get_codes_from_indices(all_indices)
|
260 |
+
|
261 |
+
# will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
|
262 |
+
|
263 |
+
return (*ret, all_codes)
|
264 |
+
|
265 |
+
|
266 |
+
# grouped residual fsq
|
267 |
+
|
268 |
+
|
269 |
+
class GroupedResidualFSQ(Module):
|
270 |
+
def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
|
271 |
+
super().__init__()
|
272 |
+
self.dim = dim
|
273 |
+
self.groups = groups
|
274 |
+
assert (dim % groups) == 0
|
275 |
+
dim_per_group = dim // groups
|
276 |
+
|
277 |
+
self.accept_image_fmap = accept_image_fmap
|
278 |
+
|
279 |
+
self.rvqs = nn.ModuleList([])
|
280 |
+
|
281 |
+
for _ in range(groups):
|
282 |
+
self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs))
|
283 |
+
|
284 |
+
self.codebook_size = self.rvqs[0].codebook_size
|
285 |
+
|
286 |
+
@property
|
287 |
+
def codebooks(self):
|
288 |
+
return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
|
289 |
+
|
290 |
+
@property
|
291 |
+
def split_dim(self):
|
292 |
+
return 1 if self.accept_image_fmap else -1
|
293 |
+
|
294 |
+
def get_codes_from_indices(self, indices):
|
295 |
+
codes = tuple(
|
296 |
+
rvq.get_codes_from_indices(chunk_indices)
|
297 |
+
for rvq, chunk_indices in zip(self.rvqs, indices)
|
298 |
+
)
|
299 |
+
return torch.stack(codes)
|
300 |
+
|
301 |
+
def get_output_from_indices(self, indices):
|
302 |
+
outputs = tuple(
|
303 |
+
rvq.get_output_from_indices(chunk_indices)
|
304 |
+
for rvq, chunk_indices in zip(self.rvqs, indices)
|
305 |
+
)
|
306 |
+
return torch.cat(outputs, dim=self.split_dim)
|
307 |
+
|
308 |
+
def forward(self, x, return_all_codes=False):
|
309 |
+
shape, split_dim, device = x.shape, self.split_dim, x.device
|
310 |
+
assert shape[split_dim] == self.dim
|
311 |
+
|
312 |
+
# split the feature dimension into groups
|
313 |
+
|
314 |
+
x = x.chunk(self.groups, dim=split_dim)
|
315 |
+
|
316 |
+
forward_kwargs = dict(
|
317 |
+
return_all_codes=return_all_codes,
|
318 |
+
rand_quantize_dropout_fixed_seed=(
|
319 |
+
get_maybe_sync_seed(device) if self.training else None
|
320 |
+
),
|
321 |
+
)
|
322 |
+
|
323 |
+
# invoke residual vq on each group
|
324 |
+
|
325 |
+
out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
|
326 |
+
out = tuple(zip(*out))
|
327 |
+
|
328 |
+
# otherwise, get all the zipped outputs and combine them
|
329 |
+
|
330 |
+
quantized, all_indices, *maybe_all_codes = out
|
331 |
+
|
332 |
+
quantized = torch.cat(quantized, dim=split_dim)
|
333 |
+
all_indices = torch.stack(all_indices)
|
334 |
+
|
335 |
+
ret = (quantized, all_indices, *maybe_all_codes)
|
336 |
+
return ret
|
337 |
+
|
338 |
+
|
339 |
+
if __name__ == "__main__":
|
340 |
+
model = ResidualFSQ(
|
341 |
+
levels=[4, 4, 4, 4, 4, 4],
|
342 |
+
num_quantizers=1,
|
343 |
+
dim=30,
|
344 |
+
is_channel_first=True,
|
345 |
+
quantize_dropout=False,
|
346 |
+
)
|
347 |
+
x = torch.randn(2, 30, 10)
|
348 |
+
quantize, embed_ind = model(x)
|
349 |
+
|
350 |
+
emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2))
|
351 |
+
|
352 |
+
print(quantize == emb_from_ind.transpose(1, 2))
|
353 |
+
|
354 |
+
print("quantize shape", quantize.shape)
|
355 |
+
print("embed_ind", embed_ind)
|
sparktts/modules/speaker/ecapa_tdnn.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Zhengyang Chen ([email protected])
|
2 |
+
# 2022 Hongji Wang ([email protected])
|
3 |
+
# 2023 Bing Han ([email protected])
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
""" This implementation is adapted from github repo:
|
18 |
+
https://github.com/lawlict/ECAPA-TDNN.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch.nn.functional as F
|
24 |
+
|
25 |
+
import sparktts.modules.speaker.pooling_layers as pooling_layers
|
26 |
+
|
27 |
+
|
28 |
+
class Res2Conv1dReluBn(nn.Module):
|
29 |
+
"""
|
30 |
+
in_channels == out_channels == channels
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
channels,
|
36 |
+
kernel_size=1,
|
37 |
+
stride=1,
|
38 |
+
padding=0,
|
39 |
+
dilation=1,
|
40 |
+
bias=True,
|
41 |
+
scale=4,
|
42 |
+
):
|
43 |
+
super().__init__()
|
44 |
+
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
|
45 |
+
self.scale = scale
|
46 |
+
self.width = channels // scale
|
47 |
+
self.nums = scale if scale == 1 else scale - 1
|
48 |
+
|
49 |
+
self.convs = []
|
50 |
+
self.bns = []
|
51 |
+
for i in range(self.nums):
|
52 |
+
self.convs.append(
|
53 |
+
nn.Conv1d(
|
54 |
+
self.width,
|
55 |
+
self.width,
|
56 |
+
kernel_size,
|
57 |
+
stride,
|
58 |
+
padding,
|
59 |
+
dilation,
|
60 |
+
bias=bias,
|
61 |
+
)
|
62 |
+
)
|
63 |
+
self.bns.append(nn.BatchNorm1d(self.width))
|
64 |
+
self.convs = nn.ModuleList(self.convs)
|
65 |
+
self.bns = nn.ModuleList(self.bns)
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
out = []
|
69 |
+
spx = torch.split(x, self.width, 1)
|
70 |
+
sp = spx[0]
|
71 |
+
for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
|
72 |
+
# Order: conv -> relu -> bn
|
73 |
+
if i >= 1:
|
74 |
+
sp = sp + spx[i]
|
75 |
+
sp = conv(sp)
|
76 |
+
sp = bn(F.relu(sp))
|
77 |
+
out.append(sp)
|
78 |
+
if self.scale != 1:
|
79 |
+
out.append(spx[self.nums])
|
80 |
+
out = torch.cat(out, dim=1)
|
81 |
+
|
82 |
+
return out
|
83 |
+
|
84 |
+
|
85 |
+
""" Conv1d + BatchNorm1d + ReLU
|
86 |
+
"""
|
87 |
+
|
88 |
+
|
89 |
+
class Conv1dReluBn(nn.Module):
|
90 |
+
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
in_channels,
|
94 |
+
out_channels,
|
95 |
+
kernel_size=1,
|
96 |
+
stride=1,
|
97 |
+
padding=0,
|
98 |
+
dilation=1,
|
99 |
+
bias=True,
|
100 |
+
):
|
101 |
+
super().__init__()
|
102 |
+
self.conv = nn.Conv1d(
|
103 |
+
in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
|
104 |
+
)
|
105 |
+
self.bn = nn.BatchNorm1d(out_channels)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return self.bn(F.relu(self.conv(x)))
|
109 |
+
|
110 |
+
|
111 |
+
""" The SE connection of 1D case.
|
112 |
+
"""
|
113 |
+
|
114 |
+
|
115 |
+
class SE_Connect(nn.Module):
|
116 |
+
|
117 |
+
def __init__(self, channels, se_bottleneck_dim=128):
|
118 |
+
super().__init__()
|
119 |
+
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
|
120 |
+
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
|
121 |
+
|
122 |
+
def forward(self, x):
|
123 |
+
out = x.mean(dim=2)
|
124 |
+
out = F.relu(self.linear1(out))
|
125 |
+
out = torch.sigmoid(self.linear2(out))
|
126 |
+
out = x * out.unsqueeze(2)
|
127 |
+
|
128 |
+
return out
|
129 |
+
|
130 |
+
|
131 |
+
""" SE-Res2Block of the ECAPA-TDNN architecture.
|
132 |
+
"""
|
133 |
+
|
134 |
+
|
135 |
+
class SE_Res2Block(nn.Module):
|
136 |
+
|
137 |
+
def __init__(self, channels, kernel_size, stride, padding, dilation, scale):
|
138 |
+
super().__init__()
|
139 |
+
self.se_res2block = nn.Sequential(
|
140 |
+
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
|
141 |
+
Res2Conv1dReluBn(
|
142 |
+
channels, kernel_size, stride, padding, dilation, scale=scale
|
143 |
+
),
|
144 |
+
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
|
145 |
+
SE_Connect(channels),
|
146 |
+
)
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
return x + self.se_res2block(x)
|
150 |
+
|
151 |
+
|
152 |
+
class ECAPA_TDNN(nn.Module):
|
153 |
+
|
154 |
+
def __init__(
|
155 |
+
self,
|
156 |
+
channels=512,
|
157 |
+
feat_dim=80,
|
158 |
+
embed_dim=192,
|
159 |
+
pooling_func="ASTP",
|
160 |
+
global_context_att=False,
|
161 |
+
emb_bn=False,
|
162 |
+
):
|
163 |
+
super().__init__()
|
164 |
+
|
165 |
+
self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2)
|
166 |
+
self.layer2 = SE_Res2Block(
|
167 |
+
channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8
|
168 |
+
)
|
169 |
+
self.layer3 = SE_Res2Block(
|
170 |
+
channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8
|
171 |
+
)
|
172 |
+
self.layer4 = SE_Res2Block(
|
173 |
+
channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8
|
174 |
+
)
|
175 |
+
|
176 |
+
cat_channels = channels * 3
|
177 |
+
out_channels = 512 * 3
|
178 |
+
self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1)
|
179 |
+
self.pool = getattr(pooling_layers, pooling_func)(
|
180 |
+
in_dim=out_channels, global_context_att=global_context_att
|
181 |
+
)
|
182 |
+
self.pool_out_dim = self.pool.get_out_dim()
|
183 |
+
self.bn = nn.BatchNorm1d(self.pool_out_dim)
|
184 |
+
self.linear = nn.Linear(self.pool_out_dim, embed_dim)
|
185 |
+
self.emb_bn = emb_bn
|
186 |
+
if emb_bn: # better in SSL for SV
|
187 |
+
self.bn2 = nn.BatchNorm1d(embed_dim)
|
188 |
+
else:
|
189 |
+
self.bn2 = nn.Identity()
|
190 |
+
|
191 |
+
def forward(self, x, return_latent=False):
|
192 |
+
x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)
|
193 |
+
|
194 |
+
out1 = self.layer1(x)
|
195 |
+
out2 = self.layer2(out1)
|
196 |
+
out3 = self.layer3(out2)
|
197 |
+
out4 = self.layer4(out3)
|
198 |
+
|
199 |
+
out = torch.cat([out2, out3, out4], dim=1)
|
200 |
+
latent = F.relu(self.conv(out))
|
201 |
+
out = self.bn(self.pool(latent))
|
202 |
+
out = self.linear(out)
|
203 |
+
if self.emb_bn:
|
204 |
+
out = self.bn2(out)
|
205 |
+
|
206 |
+
if return_latent:
|
207 |
+
return out, latent
|
208 |
+
return out
|
209 |
+
|
210 |
+
|
211 |
+
def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
212 |
+
return ECAPA_TDNN(
|
213 |
+
channels=1024,
|
214 |
+
feat_dim=feat_dim,
|
215 |
+
embed_dim=embed_dim,
|
216 |
+
pooling_func=pooling_func,
|
217 |
+
emb_bn=emb_bn,
|
218 |
+
)
|
219 |
+
|
220 |
+
|
221 |
+
def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
222 |
+
return ECAPA_TDNN(
|
223 |
+
channels=1024,
|
224 |
+
feat_dim=feat_dim,
|
225 |
+
embed_dim=embed_dim,
|
226 |
+
pooling_func=pooling_func,
|
227 |
+
global_context_att=True,
|
228 |
+
emb_bn=emb_bn,
|
229 |
+
)
|
230 |
+
|
231 |
+
|
232 |
+
def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
233 |
+
return ECAPA_TDNN(
|
234 |
+
channels=512,
|
235 |
+
feat_dim=feat_dim,
|
236 |
+
embed_dim=embed_dim,
|
237 |
+
pooling_func=pooling_func,
|
238 |
+
emb_bn=emb_bn,
|
239 |
+
)
|
240 |
+
|
241 |
+
|
242 |
+
def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
243 |
+
return ECAPA_TDNN(
|
244 |
+
channels=512,
|
245 |
+
feat_dim=feat_dim,
|
246 |
+
embed_dim=embed_dim,
|
247 |
+
pooling_func=pooling_func,
|
248 |
+
global_context_att=True,
|
249 |
+
emb_bn=emb_bn,
|
250 |
+
)
|
251 |
+
|
252 |
+
|
253 |
+
if __name__ == "__main__":
|
254 |
+
x = torch.zeros(1, 200, 100)
|
255 |
+
model = ECAPA_TDNN_GLOB_c512(feat_dim=100, embed_dim=256, pooling_func="ASTP")
|
256 |
+
model.eval()
|
257 |
+
out, latent = model(x, True)
|
258 |
+
print(out.shape)
|
259 |
+
print(latent.shape)
|
260 |
+
|
261 |
+
num_params = sum(param.numel() for param in model.parameters())
|
262 |
+
print("{} M".format(num_params / 1e6))
|
263 |
+
|
264 |
+
# from thop import profile
|
265 |
+
# x_np = torch.randn(1, 200, 80)
|
266 |
+
# flops, params = profile(model, inputs=(x_np, ))
|
267 |
+
# print("FLOPs: {} G, Params: {} M".format(flops / 1e9, params / 1e6))
|
sparktts/modules/speaker/perceiver_encoder.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
|
17 |
+
|
18 |
+
from collections import namedtuple
|
19 |
+
from functools import wraps
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from einops import rearrange, repeat
|
24 |
+
from einops.layers.torch import Rearrange
|
25 |
+
from packaging import version
|
26 |
+
from torch import einsum, nn
|
27 |
+
|
28 |
+
|
29 |
+
def exists(val):
|
30 |
+
return val is not None
|
31 |
+
|
32 |
+
|
33 |
+
def once(fn):
|
34 |
+
called = False
|
35 |
+
|
36 |
+
@wraps(fn)
|
37 |
+
def inner(x):
|
38 |
+
nonlocal called
|
39 |
+
if called:
|
40 |
+
return
|
41 |
+
called = True
|
42 |
+
return fn(x)
|
43 |
+
|
44 |
+
return inner
|
45 |
+
|
46 |
+
|
47 |
+
print_once = once(print)
|
48 |
+
|
49 |
+
# main class
|
50 |
+
|
51 |
+
|
52 |
+
class Attend(nn.Module):
|
53 |
+
def __init__(self, dropout=0.0, causal=False, use_flash=False):
|
54 |
+
super().__init__()
|
55 |
+
self.dropout = dropout
|
56 |
+
self.attn_dropout = nn.Dropout(dropout)
|
57 |
+
|
58 |
+
self.causal = causal
|
59 |
+
self.register_buffer("mask", None, persistent=False)
|
60 |
+
|
61 |
+
self.use_flash = use_flash
|
62 |
+
assert not (
|
63 |
+
use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
|
64 |
+
), "in order to use flash attention, you must be using pytorch 2.0 or above"
|
65 |
+
|
66 |
+
# determine efficient attention configs for cuda and cpu
|
67 |
+
self.config = namedtuple(
|
68 |
+
"EfficientAttentionConfig",
|
69 |
+
["enable_flash", "enable_math", "enable_mem_efficient"],
|
70 |
+
)
|
71 |
+
self.cpu_config = self.config(True, True, True)
|
72 |
+
self.cuda_config = None
|
73 |
+
|
74 |
+
if not torch.cuda.is_available() or not use_flash:
|
75 |
+
return
|
76 |
+
|
77 |
+
device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
|
78 |
+
|
79 |
+
if device_properties.major == 8 and device_properties.minor == 0:
|
80 |
+
print_once(
|
81 |
+
"A100 GPU detected, using flash attention if input tensor is on cuda"
|
82 |
+
)
|
83 |
+
self.cuda_config = self.config(True, False, False)
|
84 |
+
else:
|
85 |
+
print_once(
|
86 |
+
"Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
|
87 |
+
)
|
88 |
+
self.cuda_config = self.config(False, True, True)
|
89 |
+
|
90 |
+
def get_mask(self, n, device):
|
91 |
+
if exists(self.mask) and self.mask.shape[-1] >= n:
|
92 |
+
return self.mask[:n, :n]
|
93 |
+
|
94 |
+
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
|
95 |
+
self.register_buffer("mask", mask, persistent=False)
|
96 |
+
return mask
|
97 |
+
|
98 |
+
def flash_attn(self, q, k, v, mask=None):
|
99 |
+
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
|
100 |
+
|
101 |
+
# Recommended for multi-query single-key-value attention by Tri Dao
|
102 |
+
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
|
103 |
+
|
104 |
+
if k.ndim == 3:
|
105 |
+
k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
|
106 |
+
|
107 |
+
if v.ndim == 3:
|
108 |
+
v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
|
109 |
+
|
110 |
+
# Check if mask exists and expand to compatible shape
|
111 |
+
# The mask is B L, so it would have to be expanded to B H N L
|
112 |
+
|
113 |
+
if exists(mask):
|
114 |
+
mask = rearrange(mask, "b j -> b 1 1 j")
|
115 |
+
mask = mask.expand(-1, heads, q_len, -1)
|
116 |
+
|
117 |
+
# Check if there is a compatible device for flash attention
|
118 |
+
|
119 |
+
config = self.cuda_config if is_cuda else self.cpu_config
|
120 |
+
|
121 |
+
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
|
122 |
+
|
123 |
+
with torch.backends.cuda.sdp_kernel(**config._asdict()):
|
124 |
+
out = F.scaled_dot_product_attention(
|
125 |
+
q,
|
126 |
+
k,
|
127 |
+
v,
|
128 |
+
attn_mask=mask,
|
129 |
+
dropout_p=self.dropout if self.training else 0.0,
|
130 |
+
is_causal=self.causal,
|
131 |
+
)
|
132 |
+
|
133 |
+
return out
|
134 |
+
|
135 |
+
def forward(self, q, k, v, mask=None):
|
136 |
+
"""
|
137 |
+
einstein notation
|
138 |
+
b - batch
|
139 |
+
h - heads
|
140 |
+
n, i, j - sequence length (base sequence length, source, target)
|
141 |
+
d - feature dimension
|
142 |
+
"""
|
143 |
+
|
144 |
+
n, device = q.shape[-2], q.device
|
145 |
+
|
146 |
+
scale = q.shape[-1] ** -0.5
|
147 |
+
|
148 |
+
if self.use_flash:
|
149 |
+
return self.flash_attn(q, k, v, mask=mask)
|
150 |
+
|
151 |
+
kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
|
152 |
+
|
153 |
+
# similarity
|
154 |
+
|
155 |
+
sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
|
156 |
+
|
157 |
+
# key padding mask
|
158 |
+
|
159 |
+
if exists(mask):
|
160 |
+
mask = rearrange(mask, "b j -> b 1 1 j")
|
161 |
+
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
162 |
+
|
163 |
+
# causal mask
|
164 |
+
|
165 |
+
if self.causal:
|
166 |
+
causal_mask = self.get_mask(n, device)
|
167 |
+
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
|
168 |
+
|
169 |
+
# attention
|
170 |
+
|
171 |
+
attn = sim.softmax(dim=-1)
|
172 |
+
attn = self.attn_dropout(attn)
|
173 |
+
|
174 |
+
# aggregate values
|
175 |
+
|
176 |
+
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
|
177 |
+
|
178 |
+
return out
|
179 |
+
|
180 |
+
|
181 |
+
def Sequential(*mods):
|
182 |
+
return nn.Sequential(*filter(exists, mods))
|
183 |
+
|
184 |
+
|
185 |
+
def exists(x):
|
186 |
+
return x is not None
|
187 |
+
|
188 |
+
|
189 |
+
def default(val, d):
|
190 |
+
if exists(val):
|
191 |
+
return val
|
192 |
+
return d() if callable(d) else d
|
193 |
+
|
194 |
+
|
195 |
+
class RMSNorm(nn.Module):
|
196 |
+
def __init__(self, dim, scale=True, dim_cond=None):
|
197 |
+
super().__init__()
|
198 |
+
self.cond = exists(dim_cond)
|
199 |
+
self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
|
200 |
+
|
201 |
+
self.scale = dim**0.5
|
202 |
+
self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
|
203 |
+
|
204 |
+
def forward(self, x, cond=None):
|
205 |
+
gamma = default(self.gamma, 1)
|
206 |
+
out = F.normalize(x, dim=-1) * self.scale * gamma
|
207 |
+
|
208 |
+
if not self.cond:
|
209 |
+
return out
|
210 |
+
|
211 |
+
assert exists(cond)
|
212 |
+
gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
|
213 |
+
gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
|
214 |
+
return out * gamma + beta
|
215 |
+
|
216 |
+
|
217 |
+
class CausalConv1d(nn.Conv1d):
|
218 |
+
def __init__(self, *args, **kwargs):
|
219 |
+
super().__init__(*args, **kwargs)
|
220 |
+
(kernel_size,) = self.kernel_size
|
221 |
+
(dilation,) = self.dilation
|
222 |
+
(stride,) = self.stride
|
223 |
+
|
224 |
+
assert stride == 1
|
225 |
+
self.causal_padding = dilation * (kernel_size - 1)
|
226 |
+
|
227 |
+
def forward(self, x):
|
228 |
+
causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
229 |
+
return super().forward(causal_padded_x)
|
230 |
+
|
231 |
+
|
232 |
+
class GEGLU(nn.Module):
|
233 |
+
def forward(self, x):
|
234 |
+
x, gate = x.chunk(2, dim=-1)
|
235 |
+
return F.gelu(gate) * x
|
236 |
+
|
237 |
+
|
238 |
+
def FeedForward(dim, mult=4, causal_conv=False):
|
239 |
+
dim_inner = int(dim * mult * 2 / 3)
|
240 |
+
|
241 |
+
conv = None
|
242 |
+
if causal_conv:
|
243 |
+
conv = nn.Sequential(
|
244 |
+
Rearrange("b n d -> b d n"),
|
245 |
+
CausalConv1d(dim_inner, dim_inner, 3),
|
246 |
+
Rearrange("b d n -> b n d"),
|
247 |
+
)
|
248 |
+
|
249 |
+
return Sequential(
|
250 |
+
nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim)
|
251 |
+
)
|
252 |
+
|
253 |
+
|
254 |
+
class Attention(nn.Module):
|
255 |
+
def __init__(
|
256 |
+
self,
|
257 |
+
dim,
|
258 |
+
*,
|
259 |
+
dim_context=None,
|
260 |
+
causal=False,
|
261 |
+
dim_head=64,
|
262 |
+
heads=8,
|
263 |
+
dropout=0.0,
|
264 |
+
use_flash=False,
|
265 |
+
cross_attn_include_queries=False,
|
266 |
+
):
|
267 |
+
super().__init__()
|
268 |
+
self.scale = dim_head**-0.5
|
269 |
+
self.heads = heads
|
270 |
+
self.cross_attn_include_queries = cross_attn_include_queries
|
271 |
+
|
272 |
+
dim_inner = dim_head * heads
|
273 |
+
dim_context = default(dim_context, dim)
|
274 |
+
|
275 |
+
self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
|
276 |
+
self.to_q = nn.Linear(dim, dim_inner, bias=False)
|
277 |
+
self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
|
278 |
+
self.to_out = nn.Linear(dim_inner, dim, bias=False)
|
279 |
+
|
280 |
+
def forward(self, x, context=None, mask=None):
|
281 |
+
h, has_context = self.heads, exists(context)
|
282 |
+
|
283 |
+
context = default(context, x)
|
284 |
+
|
285 |
+
if has_context and self.cross_attn_include_queries:
|
286 |
+
context = torch.cat((x, context), dim=-2)
|
287 |
+
|
288 |
+
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
|
289 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
290 |
+
|
291 |
+
out = self.attend(q, k, v, mask=mask)
|
292 |
+
|
293 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
294 |
+
return self.to_out(out)
|
295 |
+
|
296 |
+
|
297 |
+
class PerceiverResampler(nn.Module):
|
298 |
+
def __init__(
|
299 |
+
self,
|
300 |
+
*,
|
301 |
+
dim,
|
302 |
+
depth=2,
|
303 |
+
dim_context=None,
|
304 |
+
num_latents=32,
|
305 |
+
dim_head=64,
|
306 |
+
heads=8,
|
307 |
+
ff_mult=4,
|
308 |
+
use_flash_attn=False,
|
309 |
+
):
|
310 |
+
super().__init__()
|
311 |
+
dim_context = default(dim_context, dim)
|
312 |
+
|
313 |
+
self.proj_context = (
|
314 |
+
nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
|
315 |
+
)
|
316 |
+
|
317 |
+
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
318 |
+
nn.init.normal_(self.latents, std=0.02)
|
319 |
+
|
320 |
+
self.layers = nn.ModuleList([])
|
321 |
+
for _ in range(depth):
|
322 |
+
self.layers.append(
|
323 |
+
nn.ModuleList(
|
324 |
+
[
|
325 |
+
Attention(
|
326 |
+
dim=dim,
|
327 |
+
dim_head=dim_head,
|
328 |
+
heads=heads,
|
329 |
+
use_flash=use_flash_attn,
|
330 |
+
cross_attn_include_queries=True,
|
331 |
+
),
|
332 |
+
FeedForward(dim=dim, mult=ff_mult),
|
333 |
+
]
|
334 |
+
)
|
335 |
+
)
|
336 |
+
|
337 |
+
self.norm = RMSNorm(dim)
|
338 |
+
|
339 |
+
def forward(self, x, mask=None):
|
340 |
+
batch = x.shape[0]
|
341 |
+
|
342 |
+
x = self.proj_context(x)
|
343 |
+
|
344 |
+
latents = repeat(self.latents, "n d -> b n d", b=batch)
|
345 |
+
|
346 |
+
for attn, ff in self.layers:
|
347 |
+
latents = attn(latents, x, mask=mask) + latents
|
348 |
+
latents = ff(latents) + latents
|
349 |
+
|
350 |
+
return self.norm(latents)
|
351 |
+
|
352 |
+
|
353 |
+
if __name__ == "__main__":
|
354 |
+
model = PerceiverResampler(dim=256, dim_context=80)
|
355 |
+
x = torch.randn(8, 200, 80)
|
356 |
+
out = model(x)
|
357 |
+
print(out.shape) # [8, 32, 80]
|
358 |
+
|
359 |
+
num_params = sum(param.numel() for param in model.parameters())
|
360 |
+
print("{} M".format(num_params / 1e6))
|
sparktts/modules/speaker/pooling_layers.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Shuai Wang ([email protected])
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Pooling functions to aggregate frame-level deep features
|
16 |
+
into segment-level speaker embeddings
|
17 |
+
|
18 |
+
High-order statistics are surprisingly effective, TSDP acts similarly as TSTP,
|
19 |
+
even though we remove the mean statistic, on Voxceleb.
|
20 |
+
"""
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
|
26 |
+
|
27 |
+
class TAP(nn.Module):
|
28 |
+
"""
|
29 |
+
Temporal average pooling, only first-order mean is considered
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, in_dim=0, **kwargs):
|
33 |
+
super(TAP, self).__init__()
|
34 |
+
self.in_dim = in_dim
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
pooling_mean = x.mean(dim=-1)
|
38 |
+
# To be compatable with 2D input
|
39 |
+
pooling_mean = pooling_mean.flatten(start_dim=1)
|
40 |
+
return pooling_mean
|
41 |
+
|
42 |
+
def get_out_dim(self):
|
43 |
+
self.out_dim = self.in_dim
|
44 |
+
return self.out_dim
|
45 |
+
|
46 |
+
|
47 |
+
class TSDP(nn.Module):
|
48 |
+
"""
|
49 |
+
Temporal standard deviation pooling, only second-order std is considered
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, in_dim=0, **kwargs):
|
53 |
+
super(TSDP, self).__init__()
|
54 |
+
self.in_dim = in_dim
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
# The last dimension is the temporal axis
|
58 |
+
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
|
59 |
+
pooling_std = pooling_std.flatten(start_dim=1)
|
60 |
+
return pooling_std
|
61 |
+
|
62 |
+
def get_out_dim(self):
|
63 |
+
self.out_dim = self.in_dim
|
64 |
+
return self.out_dim
|
65 |
+
|
66 |
+
|
67 |
+
class TSTP(nn.Module):
|
68 |
+
"""
|
69 |
+
Temporal statistics pooling, concatenate mean and std, which is used in
|
70 |
+
x-vector
|
71 |
+
Comment: simple concatenation can not make full use of both statistics
|
72 |
+
"""
|
73 |
+
|
74 |
+
def __init__(self, in_dim=0, **kwargs):
|
75 |
+
super(TSTP, self).__init__()
|
76 |
+
self.in_dim = in_dim
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
# The last dimension is the temporal axis
|
80 |
+
pooling_mean = x.mean(dim=-1)
|
81 |
+
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
|
82 |
+
pooling_mean = pooling_mean.flatten(start_dim=1)
|
83 |
+
pooling_std = pooling_std.flatten(start_dim=1)
|
84 |
+
stats = torch.cat((pooling_mean, pooling_std), 1)
|
85 |
+
return stats
|
86 |
+
|
87 |
+
def get_out_dim(self):
|
88 |
+
self.out_dim = self.in_dim * 2
|
89 |
+
return self.out_dim
|
90 |
+
|
91 |
+
|
92 |
+
class ASTP(nn.Module):
|
93 |
+
""" Attentive statistics pooling: Channel- and context-dependent
|
94 |
+
statistics pooling, first used in ECAPA_TDNN.
|
95 |
+
"""
|
96 |
+
|
97 |
+
def __init__(self,
|
98 |
+
in_dim,
|
99 |
+
bottleneck_dim=128,
|
100 |
+
global_context_att=False,
|
101 |
+
**kwargs):
|
102 |
+
super(ASTP, self).__init__()
|
103 |
+
self.in_dim = in_dim
|
104 |
+
self.global_context_att = global_context_att
|
105 |
+
|
106 |
+
# Use Conv1d with stride == 1 rather than Linear, then we don't
|
107 |
+
# need to transpose inputs.
|
108 |
+
if global_context_att:
|
109 |
+
self.linear1 = nn.Conv1d(
|
110 |
+
in_dim * 3, bottleneck_dim,
|
111 |
+
kernel_size=1) # equals W and b in the paper
|
112 |
+
else:
|
113 |
+
self.linear1 = nn.Conv1d(
|
114 |
+
in_dim, bottleneck_dim,
|
115 |
+
kernel_size=1) # equals W and b in the paper
|
116 |
+
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
|
117 |
+
kernel_size=1) # equals V and k in the paper
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
"""
|
121 |
+
x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
|
122 |
+
or a 4-dimensional tensor in resnet architecture (B,C,F,T)
|
123 |
+
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
124 |
+
"""
|
125 |
+
if len(x.shape) == 4:
|
126 |
+
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
|
127 |
+
assert len(x.shape) == 3
|
128 |
+
|
129 |
+
if self.global_context_att:
|
130 |
+
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
131 |
+
context_std = torch.sqrt(
|
132 |
+
torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x)
|
133 |
+
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
134 |
+
else:
|
135 |
+
x_in = x
|
136 |
+
|
137 |
+
# DON'T use ReLU here! ReLU may be hard to converge.
|
138 |
+
alpha = torch.tanh(
|
139 |
+
self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
|
140 |
+
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
141 |
+
mean = torch.sum(alpha * x, dim=2)
|
142 |
+
var = torch.sum(alpha * (x**2), dim=2) - mean**2
|
143 |
+
std = torch.sqrt(var.clamp(min=1e-7))
|
144 |
+
return torch.cat([mean, std], dim=1)
|
145 |
+
|
146 |
+
def get_out_dim(self):
|
147 |
+
self.out_dim = 2 * self.in_dim
|
148 |
+
return self.out_dim
|
149 |
+
|
150 |
+
|
151 |
+
class MHASTP(torch.nn.Module):
|
152 |
+
""" Multi head attentive statistics pooling
|
153 |
+
Reference:
|
154 |
+
Self Multi-Head Attention for Speaker Recognition
|
155 |
+
https://arxiv.org/pdf/1906.09890.pdf
|
156 |
+
"""
|
157 |
+
|
158 |
+
def __init__(self,
|
159 |
+
in_dim,
|
160 |
+
layer_num=2,
|
161 |
+
head_num=2,
|
162 |
+
d_s=1,
|
163 |
+
bottleneck_dim=64,
|
164 |
+
**kwargs):
|
165 |
+
super(MHASTP, self).__init__()
|
166 |
+
assert (in_dim % head_num
|
167 |
+
) == 0 # make sure that head num can be divided by input_dim
|
168 |
+
self.in_dim = in_dim
|
169 |
+
self.head_num = head_num
|
170 |
+
d_model = int(in_dim / head_num)
|
171 |
+
channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
|
172 |
+
if d_s > 1:
|
173 |
+
d_s = d_model
|
174 |
+
else:
|
175 |
+
d_s = 1
|
176 |
+
self.d_s = d_s
|
177 |
+
channel_dims[0], channel_dims[-1] = d_model, d_s
|
178 |
+
heads_att_trans = []
|
179 |
+
for i in range(self.head_num):
|
180 |
+
att_trans = nn.Sequential()
|
181 |
+
for i in range(layer_num - 1):
|
182 |
+
att_trans.add_module(
|
183 |
+
'att_' + str(i),
|
184 |
+
nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1))
|
185 |
+
att_trans.add_module('tanh' + str(i), nn.Tanh())
|
186 |
+
att_trans.add_module(
|
187 |
+
'att_' + str(layer_num - 1),
|
188 |
+
nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num],
|
189 |
+
1, 1))
|
190 |
+
heads_att_trans.append(att_trans)
|
191 |
+
self.heads_att_trans = nn.ModuleList(heads_att_trans)
|
192 |
+
|
193 |
+
def forward(self, input):
|
194 |
+
"""
|
195 |
+
input: a 3-dimensional tensor in xvector architecture
|
196 |
+
or a 4-dimensional tensor in resnet architecture
|
197 |
+
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
198 |
+
"""
|
199 |
+
if len(input.shape) == 4: # B x F x T
|
200 |
+
input = input.reshape(input.shape[0],
|
201 |
+
input.shape[1] * input.shape[2],
|
202 |
+
input.shape[3])
|
203 |
+
assert len(input.shape) == 3
|
204 |
+
bs, f_dim, t_dim = input.shape
|
205 |
+
chunks = torch.chunk(input, self.head_num, 1)
|
206 |
+
# split
|
207 |
+
chunks_out = []
|
208 |
+
# for i in range(self.head_num):
|
209 |
+
# att_score = self.heads_att_trans[i](chunks[i])
|
210 |
+
for i, layer in enumerate(self.heads_att_trans):
|
211 |
+
att_score = layer(chunks[i])
|
212 |
+
alpha = F.softmax(att_score, dim=-1)
|
213 |
+
mean = torch.sum(alpha * chunks[i], dim=2)
|
214 |
+
var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2
|
215 |
+
std = torch.sqrt(var.clamp(min=1e-7))
|
216 |
+
chunks_out.append(torch.cat((mean, std), dim=1))
|
217 |
+
out = torch.cat(chunks_out, dim=1)
|
218 |
+
return out
|
219 |
+
|
220 |
+
def get_out_dim(self):
|
221 |
+
self.out_dim = 2 * self.in_dim
|
222 |
+
return self.out_dim
|
223 |
+
|
224 |
+
|
225 |
+
class MQMHASTP(torch.nn.Module):
|
226 |
+
""" An attentive pooling
|
227 |
+
Reference:
|
228 |
+
multi query multi head attentive statistics pooling
|
229 |
+
https://arxiv.org/pdf/2110.05042.pdf
|
230 |
+
Args:
|
231 |
+
in_dim: the feature dimension of input
|
232 |
+
layer_num: the number of layer in the pooling layer
|
233 |
+
query_num: the number of querys
|
234 |
+
head_num: the number of heads
|
235 |
+
bottleneck_dim: the bottleneck dimension
|
236 |
+
|
237 |
+
SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
|
238 |
+
https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
|
239 |
+
MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
|
240 |
+
https://arxiv.org/pdf/1906.09890.pdf
|
241 |
+
AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
|
242 |
+
https://arxiv.org/pdf/1803.10963.pdf
|
243 |
+
VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
|
244 |
+
http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
|
245 |
+
"""
|
246 |
+
|
247 |
+
def __init__(self,
|
248 |
+
in_dim,
|
249 |
+
layer_num=2,
|
250 |
+
query_num=2,
|
251 |
+
head_num=8,
|
252 |
+
d_s=2,
|
253 |
+
bottleneck_dim=64,
|
254 |
+
**kwargs):
|
255 |
+
super(MQMHASTP, self).__init__()
|
256 |
+
self.n_query = nn.ModuleList([
|
257 |
+
MHASTP(in_dim,
|
258 |
+
layer_num=layer_num,
|
259 |
+
head_num=head_num,
|
260 |
+
d_s=d_s,
|
261 |
+
bottleneck_dim=bottleneck_dim) for i in range(query_num)
|
262 |
+
])
|
263 |
+
self.query_num = query_num
|
264 |
+
self.in_dim = in_dim
|
265 |
+
|
266 |
+
def forward(self, input):
|
267 |
+
"""
|
268 |
+
input: a 3-dimensional tensor in xvector architecture
|
269 |
+
or a 4-dimensional tensor in resnet architecture
|
270 |
+
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
271 |
+
"""
|
272 |
+
if len(input.shape) == 4: # B x F x T
|
273 |
+
input = input.reshape(input.shape[0],
|
274 |
+
input.shape[1] * input.shape[2],
|
275 |
+
input.shape[3])
|
276 |
+
assert len(input.shape) == 3
|
277 |
+
res = []
|
278 |
+
for i, layer in enumerate(self.n_query):
|
279 |
+
res.append(layer(input))
|
280 |
+
out = torch.cat(res, dim=-1)
|
281 |
+
return out
|
282 |
+
|
283 |
+
def get_out_dim(self):
|
284 |
+
self.out_dim = self.in_dim * 2 * self.query_num
|
285 |
+
return self.out_dim
|
286 |
+
|
287 |
+
|
288 |
+
if __name__ == '__main__':
|
289 |
+
data = torch.randn(16, 512, 10, 35)
|
290 |
+
# model = StatisticsPooling()
|
291 |
+
model = MQMHASTP(512 * 10)
|
292 |
+
model = MHASTP(512 * 10)
|
293 |
+
model = MQMHASTP(512 * 10, context=False)
|
294 |
+
print(model)
|
295 |
+
|
296 |
+
out = model(data)
|
297 |
+
print(out.shape)
|
298 |
+
print(model.get_out_dim())
|
sparktts/modules/speaker/speaker_encoder.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
from typing import List, Tuple
|
20 |
+
from sparktts.modules.fsq.residual_fsq import ResidualFSQ
|
21 |
+
from sparktts.modules.speaker.ecapa_tdnn import ECAPA_TDNN_GLOB_c512
|
22 |
+
from sparktts.modules.speaker.perceiver_encoder import PerceiverResampler
|
23 |
+
|
24 |
+
"""
|
25 |
+
x-vector + d-vector
|
26 |
+
"""
|
27 |
+
|
28 |
+
|
29 |
+
class SpeakerEncoder(nn.Module):
|
30 |
+
"""
|
31 |
+
|
32 |
+
Args:
|
33 |
+
input_dim (int): acoustic feature dimension
|
34 |
+
out_dim (int): output dimension of x-vector and d-vector
|
35 |
+
latent_dim (int): latent dimension before quantization
|
36 |
+
token_num (int): sequence length of speaker tokens
|
37 |
+
fsq_levels (List[int]): number of levels for each quantizer
|
38 |
+
fsq_num_quantizers (int): number of quantizers
|
39 |
+
|
40 |
+
Return:
|
41 |
+
speaker_embs: (B, T2, out_dim)
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
input_dim: int = 100,
|
47 |
+
out_dim: int = 512,
|
48 |
+
latent_dim: int = 128,
|
49 |
+
token_num: int = 32,
|
50 |
+
fsq_levels: List[int] = [4, 4, 4, 4, 4, 4],
|
51 |
+
fsq_num_quantizers: int = 1,
|
52 |
+
):
|
53 |
+
super(SpeakerEncoder, self).__init__()
|
54 |
+
|
55 |
+
self.speaker_encoder = ECAPA_TDNN_GLOB_c512(
|
56 |
+
feat_dim=input_dim, embed_dim=out_dim
|
57 |
+
)
|
58 |
+
self.perceiver_sampler = PerceiverResampler(
|
59 |
+
dim=latent_dim, dim_context=512 * 3, num_latents=token_num
|
60 |
+
)
|
61 |
+
self.quantizer = ResidualFSQ(
|
62 |
+
levels=fsq_levels,
|
63 |
+
num_quantizers=fsq_num_quantizers,
|
64 |
+
dim=latent_dim,
|
65 |
+
is_channel_first=True,
|
66 |
+
quantize_dropout=False,
|
67 |
+
)
|
68 |
+
|
69 |
+
self.project = nn.Linear(latent_dim * token_num, out_dim)
|
70 |
+
|
71 |
+
def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
|
72 |
+
zq = self.quantizer.get_codes_from_indices(indices.transpose(1, 2))
|
73 |
+
return zq.transpose(1, 2)
|
74 |
+
|
75 |
+
def get_indices(self, mels: torch.Tensor) -> torch.Tensor:
|
76 |
+
mels = mels.transpose(1, 2)
|
77 |
+
x = self.perceiver_sampler(mels).transpose(1, 2)
|
78 |
+
zq, indices = self.quantizer(x)
|
79 |
+
return indices
|
80 |
+
|
81 |
+
def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
82 |
+
"""
|
83 |
+
Args:
|
84 |
+
mels: (B, D_mel, T1)
|
85 |
+
|
86 |
+
Return:
|
87 |
+
x_vector: (B, out_dim)
|
88 |
+
d_vector: (B, out_dim)
|
89 |
+
"""
|
90 |
+
# mels = mels.transpose(1,2)
|
91 |
+
|
92 |
+
x_vector, features = self.speaker_encoder(mels, True)
|
93 |
+
x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
|
94 |
+
zq, indices = self.quantizer(x) # zq: (B, latent_dim, T2, latent_dim)
|
95 |
+
x = zq.reshape(zq.shape[0], -1)
|
96 |
+
d_vector = self.project(x)
|
97 |
+
|
98 |
+
return x_vector, d_vector
|
99 |
+
|
100 |
+
def tokenize(self, mels: torch.Tensor) -> torch.Tensor:
|
101 |
+
"""tokenize the input mel spectrogram"""
|
102 |
+
_, features = self.speaker_encoder(mels, True)
|
103 |
+
x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
|
104 |
+
zq, indices = self.quantizer(x)
|
105 |
+
return indices
|
106 |
+
|
107 |
+
def detokenize(self, indices: torch.Tensor) -> torch.Tensor:
|
108 |
+
"""detokenize the input indices to d-vector"""
|
109 |
+
zq = self.quantizer.get_output_from_indices(indices.transpose(1, 2)).transpose(1, 2)
|
110 |
+
x = zq.reshape(zq.shape[0], -1)
|
111 |
+
d_vector = self.project(x)
|
112 |
+
return d_vector
|
113 |
+
|
114 |
+
if __name__ == "__main__":
|
115 |
+
model = SpeakerEncoder(
|
116 |
+
input_dim=100,
|
117 |
+
latent_dim=128,
|
118 |
+
token_num=32,
|
119 |
+
fsq_levels=[4, 4, 4, 4, 4, 4],
|
120 |
+
fsq_num_quantizers=1,
|
121 |
+
)
|
122 |
+
mel = torch.randn(8, 200, 100)
|
123 |
+
x_vector, d_vector = model(mel)
|
124 |
+
print("x-vector shape", x_vector.shape)
|
125 |
+
print("d-vector shape", d_vector.shape)
|
126 |
+
|
127 |
+
indices = model.tokenize(mel)
|
128 |
+
print("indices shape", indices.shape)
|
129 |
+
d_vector_post = model.detokenize(indices)
|
130 |
+
print("d-vector shape", d_vector_post.shape)
|
131 |
+
if d_vector_post.all() == d_vector.all():
|
132 |
+
print("d-vector post and d-vector are the same")
|
133 |
+
else:
|
134 |
+
print("d-vector post and d-vector are different")
|
135 |
+
num_params = sum(param.numel() for param in model.parameters())
|
136 |
+
print("{} M".format(num_params / 1e6))
|
sparktts/modules/vq/factorized_vector_quantize.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Heavily based on https://github.com/lucidrains/vector-quantize-pytorch
|
17 |
+
|
18 |
+
|
19 |
+
from typing import Any, Dict
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from einops import rearrange
|
25 |
+
from torch.nn.utils import weight_norm
|
26 |
+
|
27 |
+
|
28 |
+
def WNConv1d(*args, **kwargs):
|
29 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
30 |
+
|
31 |
+
|
32 |
+
def ema_inplace(moving_avg, new, decay):
|
33 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
34 |
+
|
35 |
+
|
36 |
+
class FactorizedVectorQuantize(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
input_dim: int,
|
40 |
+
codebook_size: int,
|
41 |
+
codebook_dim: int,
|
42 |
+
commitment: float,
|
43 |
+
codebook_loss_weight: float = 1.0,
|
44 |
+
decay: float = 0.99,
|
45 |
+
threshold_ema_dead_code: float = 2,
|
46 |
+
momentum: float = 0.99,
|
47 |
+
**kwargs,
|
48 |
+
):
|
49 |
+
super().__init__()
|
50 |
+
self.input_dim = input_dim
|
51 |
+
self.codebook_size = codebook_size
|
52 |
+
self.codebook_dim = codebook_dim
|
53 |
+
self.commitment = commitment
|
54 |
+
self.codebook_loss_weight = codebook_loss_weight
|
55 |
+
self.decay = decay
|
56 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
57 |
+
self.momentum = momentum
|
58 |
+
|
59 |
+
if input_dim != self.codebook_dim:
|
60 |
+
self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1)
|
61 |
+
self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1)
|
62 |
+
|
63 |
+
else:
|
64 |
+
self.in_project = nn.Identity()
|
65 |
+
self.out_project = nn.Identity()
|
66 |
+
|
67 |
+
self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
|
68 |
+
self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
|
69 |
+
|
70 |
+
def forward(self, z: torch.Tensor) -> Dict[str, Any]:
|
71 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
72 |
+
the corresponding codebook vectors
|
73 |
+
|
74 |
+
Parameters
|
75 |
+
----------
|
76 |
+
z : Tensor[B x D x T]
|
77 |
+
|
78 |
+
Returns
|
79 |
+
-------
|
80 |
+
Tensor[B x D x T]
|
81 |
+
Quantized continuous representation of input
|
82 |
+
Tensor[1]
|
83 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
84 |
+
entries
|
85 |
+
Tensor[1]
|
86 |
+
Codebook loss to update the codebook
|
87 |
+
Tensor[B x T]
|
88 |
+
Codebook indices (quantized discrete representation of input)
|
89 |
+
Tensor[B x D x T]
|
90 |
+
Projected latents (continuous representation of input before quantization)
|
91 |
+
"""
|
92 |
+
# transpose since we use linear
|
93 |
+
|
94 |
+
# Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
|
95 |
+
z_e = self.in_project(z)
|
96 |
+
z_q, indices, dists = self.decode_latents(z_e)
|
97 |
+
|
98 |
+
# statistic the usage of codes
|
99 |
+
embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype)
|
100 |
+
avg_probs = torch.mean(embed_onehot.reshape(-1, self.codebook_size), dim=0)
|
101 |
+
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
|
102 |
+
|
103 |
+
active_num = (embed_onehot.sum(0).sum(0) > 0).sum()
|
104 |
+
if self.training:
|
105 |
+
# We do the expiry of code at that point as buffers are in sync
|
106 |
+
# and all the workers will take the same decision.
|
107 |
+
ema_inplace(self.cluster_size, embed_onehot.sum(0).sum(0), self.decay)
|
108 |
+
active_num = sum(self.cluster_size > self.threshold_ema_dead_code)
|
109 |
+
|
110 |
+
if self.training:
|
111 |
+
commit_loss = (
|
112 |
+
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
113 |
+
* self.commitment
|
114 |
+
)
|
115 |
+
|
116 |
+
codebook_loss = (
|
117 |
+
F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
118 |
+
* self.codebook_loss_weight
|
119 |
+
)
|
120 |
+
|
121 |
+
else:
|
122 |
+
commit_loss = torch.zeros(0, device=z.device)
|
123 |
+
codebook_loss = torch.zeros(0, device=z.device)
|
124 |
+
|
125 |
+
z_q = (
|
126 |
+
z_e + (z_q - z_e).detach()
|
127 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
128 |
+
|
129 |
+
z_q = self.out_project(z_q)
|
130 |
+
|
131 |
+
vq_loss = (commit_loss + codebook_loss).mean()
|
132 |
+
|
133 |
+
return {
|
134 |
+
"z_q": z_q,
|
135 |
+
"indices": indices,
|
136 |
+
"dists": dists,
|
137 |
+
"vq_loss": vq_loss,
|
138 |
+
"perplexity": perplexity,
|
139 |
+
"active_num": active_num.float(),
|
140 |
+
}
|
141 |
+
|
142 |
+
def vq2emb(self, vq, out_proj=True):
|
143 |
+
emb = self.embed_code(vq)
|
144 |
+
if out_proj:
|
145 |
+
emb = self.out_project(emb)
|
146 |
+
return emb
|
147 |
+
|
148 |
+
def tokenize(self, z: torch.Tensor) -> torch.Tensor:
|
149 |
+
"""tokenize the input tensor"""
|
150 |
+
z_e = self.in_project(z)
|
151 |
+
_, indices, _ = self.decode_latents(z_e)
|
152 |
+
return indices
|
153 |
+
|
154 |
+
def detokenize(self, indices):
|
155 |
+
"""detokenize the input indices"""
|
156 |
+
z_q = self.decode_code(indices)
|
157 |
+
z_q = self.out_project(z_q)
|
158 |
+
return z_q
|
159 |
+
|
160 |
+
def get_emb(self):
|
161 |
+
return self.codebook.weight
|
162 |
+
|
163 |
+
def embed_code(self, embed_id):
|
164 |
+
return F.embedding(embed_id, self.codebook.weight)
|
165 |
+
|
166 |
+
def decode_code(self, embed_id):
|
167 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
168 |
+
|
169 |
+
def decode_latents(self, latents):
|
170 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
171 |
+
codebook = self.codebook.weight
|
172 |
+
|
173 |
+
# L2 normalize encodings and codebook
|
174 |
+
encodings = F.normalize(encodings)
|
175 |
+
codebook = F.normalize(codebook)
|
176 |
+
|
177 |
+
# Compute euclidean distance between encodings and codebook,
|
178 |
+
# with L2 normalization, the distance is equal to cosine distance
|
179 |
+
dist = (
|
180 |
+
encodings.pow(2).sum(1, keepdim=True)
|
181 |
+
- 2 * encodings @ codebook.t()
|
182 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
183 |
+
)
|
184 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
185 |
+
z_q = self.decode_code(indices)
|
186 |
+
|
187 |
+
return z_q, indices, dist
|
sparktts/utils/__init__.py
ADDED
File without changes
|
sparktts/utils/audio.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Description:
|
17 |
+
This script contains a collection of functions designed to handle various
|
18 |
+
audio processing.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import random
|
22 |
+
import soxr
|
23 |
+
import soundfile
|
24 |
+
import torch
|
25 |
+
import torchaudio
|
26 |
+
import numpy as np
|
27 |
+
|
28 |
+
from pathlib import Path
|
29 |
+
from typing import Tuple
|
30 |
+
from numpy.lib.stride_tricks import sliding_window_view
|
31 |
+
|
32 |
+
|
33 |
+
def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
|
34 |
+
"""
|
35 |
+
Normalize the volume of an audio signal.
|
36 |
+
|
37 |
+
Parameters:
|
38 |
+
audio (numpy array): Input audio signal array.
|
39 |
+
coeff (float): Target coefficient for normalization, default is 0.2.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
numpy array: The volume-normalized audio signal.
|
43 |
+
"""
|
44 |
+
# Sort the absolute values of the audio signal
|
45 |
+
temp = np.sort(np.abs(audio))
|
46 |
+
|
47 |
+
# If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
|
48 |
+
if temp[-1] < 0.1:
|
49 |
+
scaling_factor = max(
|
50 |
+
temp[-1], 1e-3
|
51 |
+
) # Prevent division by zero with a small constant
|
52 |
+
audio = audio / scaling_factor * 0.1
|
53 |
+
|
54 |
+
# Filter out values less than 0.01 from temp
|
55 |
+
temp = temp[temp > 0.01]
|
56 |
+
L = temp.shape[0] # Length of the filtered array
|
57 |
+
|
58 |
+
# If there are fewer than or equal to 10 significant values, return the audio without further processing
|
59 |
+
if L <= 10:
|
60 |
+
return audio
|
61 |
+
|
62 |
+
# Compute the average of the top 10% to 1% of values in temp
|
63 |
+
volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
|
64 |
+
|
65 |
+
# Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
|
66 |
+
audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
|
67 |
+
|
68 |
+
# Ensure the maximum absolute value in the audio does not exceed 1
|
69 |
+
max_value = np.max(np.abs(audio))
|
70 |
+
if max_value > 1:
|
71 |
+
audio = audio / max_value
|
72 |
+
|
73 |
+
return audio
|
74 |
+
|
75 |
+
|
76 |
+
def load_audio(
|
77 |
+
adfile: Path,
|
78 |
+
sampling_rate: int = None,
|
79 |
+
length: int = None,
|
80 |
+
volume_normalize: bool = False,
|
81 |
+
segment_duration: int = None,
|
82 |
+
) -> np.ndarray:
|
83 |
+
r"""Load audio file with target sampling rate and lsength
|
84 |
+
|
85 |
+
Args:
|
86 |
+
adfile (Path): path to audio file.
|
87 |
+
sampling_rate (int, optional): target sampling rate. Defaults to None.
|
88 |
+
length (int, optional): target audio length. Defaults to None.
|
89 |
+
volume_normalize (bool, optional): whether perform volume normalization. Defaults to False.
|
90 |
+
segment_duration (int): random select a segment with duration of {segment_duration}s.
|
91 |
+
Defualt to None which means the whole audio will be used.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
audio (np.ndarray): audio
|
95 |
+
"""
|
96 |
+
|
97 |
+
audio, sr = soundfile.read(adfile)
|
98 |
+
if len(audio.shape) > 1:
|
99 |
+
audio = audio[:, 0]
|
100 |
+
|
101 |
+
if sampling_rate is not None and sr != sampling_rate:
|
102 |
+
audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
|
103 |
+
sr = sampling_rate
|
104 |
+
|
105 |
+
if segment_duration is not None:
|
106 |
+
seg_length = int(sr * segment_duration)
|
107 |
+
audio = random_select_audio_segment(audio, seg_length)
|
108 |
+
|
109 |
+
# Audio volume normalize
|
110 |
+
if volume_normalize:
|
111 |
+
audio = audio_volume_normalize(audio)
|
112 |
+
# check the audio length
|
113 |
+
if length is not None:
|
114 |
+
assert abs(audio.shape[0] - length) < 1000
|
115 |
+
if audio.shape[0] > length:
|
116 |
+
audio = audio[:length]
|
117 |
+
else:
|
118 |
+
audio = np.pad(audio, (0, int(length - audio.shape[0])))
|
119 |
+
return audio
|
120 |
+
|
121 |
+
|
122 |
+
def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
|
123 |
+
"""get an audio segment given the length
|
124 |
+
|
125 |
+
Args:
|
126 |
+
audio (np.ndarray):
|
127 |
+
length (int): audio length = sampling_rate * duration
|
128 |
+
"""
|
129 |
+
if audio.shape[0] < length:
|
130 |
+
audio = np.pad(audio, (0, int(length - audio.shape[0])))
|
131 |
+
start_index = random.randint(0, audio.shape[0] - length)
|
132 |
+
end_index = int(start_index + length)
|
133 |
+
|
134 |
+
return audio[start_index:end_index]
|
135 |
+
|
136 |
+
|
137 |
+
def audio_highpass_filter(audio, sample_rate, highpass_cutoff_freq):
|
138 |
+
"""apply highpass fileter to audio
|
139 |
+
|
140 |
+
Args:
|
141 |
+
audio (np.ndarray):
|
142 |
+
sample_rate (ind):
|
143 |
+
highpass_cutoff_freq (int):
|
144 |
+
"""
|
145 |
+
|
146 |
+
audio = torchaudio.functional.highpass_biquad(
|
147 |
+
torch.from_numpy(audio), sample_rate, cutoff_freq=highpass_cutoff_freq
|
148 |
+
)
|
149 |
+
return audio.numpy()
|
150 |
+
|
151 |
+
|
152 |
+
def stft(
|
153 |
+
x: torch.Tensor,
|
154 |
+
fft_size: int,
|
155 |
+
hop_size: int,
|
156 |
+
win_length: int,
|
157 |
+
window: str,
|
158 |
+
use_complex: bool = False,
|
159 |
+
) -> torch.Tensor:
|
160 |
+
"""Perform STFT and convert to magnitude spectrogram.
|
161 |
+
Args:
|
162 |
+
x (Tensor): Input signal tensor (B, T).
|
163 |
+
fft_size (int): FFT size.
|
164 |
+
hop_size (int): Hop size.
|
165 |
+
win_length (int): Window length.
|
166 |
+
window (str): Window function type.
|
167 |
+
Returns:
|
168 |
+
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
169 |
+
"""
|
170 |
+
|
171 |
+
x_stft = torch.stft(
|
172 |
+
x, fft_size, hop_size, win_length, window.to(x.device), return_complex=True
|
173 |
+
)
|
174 |
+
|
175 |
+
# clamp is needed to avoid nan or inf
|
176 |
+
if not use_complex:
|
177 |
+
return torch.sqrt(
|
178 |
+
torch.clamp(x_stft.real**2 + x_stft.imag**2, min=1e-7, max=1e3)
|
179 |
+
).transpose(2, 1)
|
180 |
+
else:
|
181 |
+
res = torch.cat([x_stft.real.unsqueeze(1), x_stft.imag.unsqueeze(1)], dim=1)
|
182 |
+
res = res.transpose(2, 3) # [B, 2, T, F]
|
183 |
+
return res
|
184 |
+
|
185 |
+
|
186 |
+
def detect_speech_boundaries(
|
187 |
+
wav: np.ndarray,
|
188 |
+
sample_rate: int,
|
189 |
+
window_duration: float = 0.1,
|
190 |
+
energy_threshold: float = 0.01,
|
191 |
+
margin_factor: int = 2
|
192 |
+
) -> Tuple[int, int]:
|
193 |
+
"""Detect the start and end points of speech in an audio signal using RMS energy.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
wav: Input audio signal array with values in [-1, 1]
|
197 |
+
sample_rate: Audio sample rate in Hz
|
198 |
+
window_duration: Duration of detection window in seconds
|
199 |
+
energy_threshold: RMS energy threshold for speech detection
|
200 |
+
margin_factor: Factor to determine extra margin around detected boundaries
|
201 |
+
|
202 |
+
Returns:
|
203 |
+
tuple: (start_index, end_index) of speech segment
|
204 |
+
|
205 |
+
Raises:
|
206 |
+
ValueError: If the audio contains only silence
|
207 |
+
"""
|
208 |
+
window_size = int(window_duration * sample_rate)
|
209 |
+
margin = margin_factor * window_size
|
210 |
+
step_size = window_size // 10
|
211 |
+
|
212 |
+
# Create sliding windows using stride tricks to avoid loops
|
213 |
+
windows = sliding_window_view(wav, window_size)[::step_size]
|
214 |
+
|
215 |
+
# Calculate RMS energy for each window
|
216 |
+
energy = np.sqrt(np.mean(windows ** 2, axis=1))
|
217 |
+
speech_mask = energy >= energy_threshold
|
218 |
+
|
219 |
+
if not np.any(speech_mask):
|
220 |
+
raise ValueError("No speech detected in audio (only silence)")
|
221 |
+
|
222 |
+
start = max(0, np.argmax(speech_mask) * step_size - margin)
|
223 |
+
end = min(len(wav), (len(speech_mask) - 1 - np.argmax(speech_mask[::-1])) * step_size + margin)
|
224 |
+
|
225 |
+
return start, end
|
226 |
+
|
227 |
+
|
228 |
+
def remove_silence_on_both_ends(
|
229 |
+
wav: np.ndarray,
|
230 |
+
sample_rate: int,
|
231 |
+
window_duration: float = 0.1,
|
232 |
+
volume_threshold: float = 0.01
|
233 |
+
) -> np.ndarray:
|
234 |
+
"""Remove silence from both ends of an audio signal.
|
235 |
+
|
236 |
+
Args:
|
237 |
+
wav: Input audio signal array
|
238 |
+
sample_rate: Audio sample rate in Hz
|
239 |
+
window_duration: Duration of detection window in seconds
|
240 |
+
volume_threshold: Amplitude threshold for silence detection
|
241 |
+
|
242 |
+
Returns:
|
243 |
+
np.ndarray: Audio signal with silence removed from both ends
|
244 |
+
|
245 |
+
Raises:
|
246 |
+
ValueError: If the audio contains only silence
|
247 |
+
"""
|
248 |
+
start, end = detect_speech_boundaries(
|
249 |
+
wav,
|
250 |
+
sample_rate,
|
251 |
+
window_duration,
|
252 |
+
volume_threshold
|
253 |
+
)
|
254 |
+
return wav[start:end]
|
255 |
+
|
256 |
+
|
257 |
+
|
258 |
+
def hertz_to_mel(pitch: float) -> float:
|
259 |
+
"""
|
260 |
+
Converts a frequency from the Hertz scale to the Mel scale.
|
261 |
+
|
262 |
+
Parameters:
|
263 |
+
- pitch: float or ndarray
|
264 |
+
Frequency in Hertz.
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
- mel: float or ndarray
|
268 |
+
Frequency in Mel scale.
|
269 |
+
"""
|
270 |
+
mel = 2595 * np.log10(1 + pitch / 700)
|
271 |
+
return mel
|
sparktts/utils/file.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Description:
|
17 |
+
This script contains a collection of functions designed to handle various
|
18 |
+
file reading and writing operations. It provides utilities to read from files,
|
19 |
+
write data to files, and perform file manipulation tasks.
|
20 |
+
"""
|
21 |
+
|
22 |
+
|
23 |
+
import os
|
24 |
+
import json
|
25 |
+
import json
|
26 |
+
import csv
|
27 |
+
|
28 |
+
from tqdm import tqdm
|
29 |
+
from typing import List, Dict, Any, Set, Union
|
30 |
+
from pathlib import Path
|
31 |
+
from omegaconf import OmegaConf, DictConfig
|
32 |
+
|
33 |
+
|
34 |
+
def resolve_symbolic_link(symbolic_link_path: Path) -> Path:
|
35 |
+
"""
|
36 |
+
Resolves the absolute path of a symbolic link.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
symbolic_link_path (Path): The path to the symbolic link.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
Path: The absolute path that the symbolic link points to.
|
43 |
+
"""
|
44 |
+
|
45 |
+
link_directory = os.path.dirname(symbolic_link_path)
|
46 |
+
target_path_relative = os.readlink(symbolic_link_path)
|
47 |
+
return os.path.join(link_directory, target_path_relative)
|
48 |
+
|
49 |
+
|
50 |
+
def write_jsonl(metadata: List[dict], file_path: Path) -> None:
|
51 |
+
"""Writes a list of dictionaries to a JSONL file.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
metadata : List[dict]
|
55 |
+
A list of dictionaries, each representing a piece of meta.
|
56 |
+
file_path : Path
|
57 |
+
The file path to save the JSONL file
|
58 |
+
|
59 |
+
This function writes each dictionary in the list to a new line in the specified file.
|
60 |
+
"""
|
61 |
+
with open(file_path, "w", encoding="utf-8") as f:
|
62 |
+
for meta in tqdm(metadata, desc="writing jsonl"):
|
63 |
+
# Convert dictionary to JSON string and write it to the file with a newline
|
64 |
+
json_str = json.dumps(meta, ensure_ascii=False) + "\n"
|
65 |
+
f.write(json_str)
|
66 |
+
print(f"jsonl saved to {file_path}")
|
67 |
+
|
68 |
+
|
69 |
+
def read_jsonl(file_path: Path) -> List[dict]:
|
70 |
+
"""
|
71 |
+
Reads a JSONL file and returns a list of dictionaries.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
file_path : Path
|
75 |
+
The path to the JSONL file to be read.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
List[dict]
|
79 |
+
A list of dictionaries parsed from each line of the JSONL file.
|
80 |
+
"""
|
81 |
+
metadata = []
|
82 |
+
# Open the file for reading
|
83 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
84 |
+
# Split the file into lines
|
85 |
+
lines = f.read().splitlines()
|
86 |
+
# Process each line
|
87 |
+
for line in lines:
|
88 |
+
# Convert JSON string back to dictionary and append to list
|
89 |
+
meta = json.loads(line)
|
90 |
+
metadata.append(meta)
|
91 |
+
# Return the list of metadata
|
92 |
+
return metadata
|
93 |
+
|
94 |
+
def read_json_as_jsonl(file_path: Path) -> List[dict]:
|
95 |
+
metadata = []
|
96 |
+
with open(file_path, 'r', encoding='utf-8') as infile:
|
97 |
+
data = json.load(infile)
|
98 |
+
for k in sorted(data.keys()):
|
99 |
+
meta = {'index': k}
|
100 |
+
meta.update(data[k])
|
101 |
+
metadata.append(meta)
|
102 |
+
return metadata
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
def decode_unicode_strings(meta: Dict[str, Any]) -> Dict[str, Any]:
|
107 |
+
processed_meta = {}
|
108 |
+
for k, v in meta.items():
|
109 |
+
if isinstance(v, str):
|
110 |
+
processed_meta[k] = v.encode("utf-8").decode("unicode_escape")
|
111 |
+
else:
|
112 |
+
processed_meta[k] = v
|
113 |
+
return processed_meta
|
114 |
+
|
115 |
+
|
116 |
+
def load_config(config_path: Path) -> DictConfig:
|
117 |
+
"""Loads a configuration file and optionally merges it with a base configuration.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
config_path (Path): Path to the configuration file.
|
121 |
+
"""
|
122 |
+
# Load the initial configuration from the given path
|
123 |
+
config = OmegaConf.load(config_path)
|
124 |
+
|
125 |
+
# Check if there is a base configuration specified and merge if necessary
|
126 |
+
if config.get("base_config", None) is not None:
|
127 |
+
base_config = OmegaConf.load(config["base_config"])
|
128 |
+
config = OmegaConf.merge(base_config, config)
|
129 |
+
|
130 |
+
return config
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
def jsonl_to_csv(jsonl_file_path: str, csv_file_path: str) -> None:
|
135 |
+
"""
|
136 |
+
Converts a JSONL file to a CSV file.
|
137 |
+
|
138 |
+
This function reads a JSONL file, determines all unique keys present in the file,
|
139 |
+
and writes the data to a CSV file with columns for all these keys.
|
140 |
+
"""
|
141 |
+
|
142 |
+
all_keys = set()
|
143 |
+
data_rows = []
|
144 |
+
|
145 |
+
# Read the JSONL file once to extract keys and collect data
|
146 |
+
with open(jsonl_file_path, 'r') as file:
|
147 |
+
for line in file:
|
148 |
+
data = json.loads(line.strip())
|
149 |
+
data_rows.append(data)
|
150 |
+
all_keys.update(data.keys())
|
151 |
+
|
152 |
+
# Convert the set of keys to a sorted list for consistent column order
|
153 |
+
sorted_keys = sorted(all_keys)
|
154 |
+
|
155 |
+
# Write the data to a CSV file
|
156 |
+
with open(csv_file_path, 'w', newline='') as csvfile:
|
157 |
+
writer = csv.DictWriter(csvfile, fieldnames=sorted_keys)
|
158 |
+
|
159 |
+
# Write the header row
|
160 |
+
writer.writeheader()
|
161 |
+
|
162 |
+
# Write each row of data
|
163 |
+
for data in data_rows:
|
164 |
+
writer.writerow(data)
|
165 |
+
|
166 |
+
print(f"CSV file has been created at {csv_file_path}")
|
167 |
+
|
168 |
+
|
169 |
+
def save_metadata(data, filename, headers=None):
|
170 |
+
"""
|
171 |
+
Save metadata to a file.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
data (list of dict): Metadata to be saved.
|
175 |
+
filename (str): Name of the file to save the metadata.
|
176 |
+
headers (list of str): The order of column names to be saved; defaults to the keys from the first dictionary in data if not provided.
|
177 |
+
"""
|
178 |
+
# Set headers to keys from the first dictionary in data if not explicitly provided
|
179 |
+
if headers is None:
|
180 |
+
headers = list(data[0].keys())
|
181 |
+
|
182 |
+
with open(filename, "w", encoding="utf-8") as file:
|
183 |
+
# Write the headers to the file
|
184 |
+
file.write("|".join(headers) + "\n")
|
185 |
+
for entry in data:
|
186 |
+
# Retrieve values in the order of headers, replacing any '|' characters with a space to prevent formatting errors
|
187 |
+
formatted_values = [str(entry.get(key, "")).replace("|", " ") for key in headers]
|
188 |
+
# Write the formatted values to the file
|
189 |
+
file.write("|".join(formatted_values) + "\n")
|
190 |
+
|
191 |
+
|
192 |
+
def read_metadata(filename, headers=None):
|
193 |
+
"""
|
194 |
+
Read metadata from a file.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
filename (str): The file from which to read the metadata.
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
list of dict: The metadata read from the file.
|
201 |
+
list of str: The headers used in the file.
|
202 |
+
"""
|
203 |
+
with open(filename, "r", encoding="utf-8") as file:
|
204 |
+
lines = file.readlines()
|
205 |
+
|
206 |
+
data = []
|
207 |
+
# Set headers from the first line of the file if not provided
|
208 |
+
if headers is None:
|
209 |
+
headers = lines[0].strip().split("|")
|
210 |
+
lines = lines[1:]
|
211 |
+
|
212 |
+
for line in lines:
|
213 |
+
line = line.strip()
|
214 |
+
# Skip empty lines
|
215 |
+
if not line:
|
216 |
+
continue
|
217 |
+
# Split the line by '|' and pair with headers to form a dictionary
|
218 |
+
entry_data = dict(zip(headers, line.split("|")))
|
219 |
+
data.append(entry_data)
|
220 |
+
|
221 |
+
return data, headers
|
sparktts/utils/parse_options.sh
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
|
4 |
+
# Arnab Ghoshal, Karel Vesely
|
5 |
+
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
13 |
+
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
14 |
+
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
15 |
+
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
16 |
+
# See the Apache 2 License for the specific language governing permissions and
|
17 |
+
# limitations under the License.
|
18 |
+
|
19 |
+
|
20 |
+
# Parse command-line options.
|
21 |
+
# To be sourced by another script (as in ". parse_options.sh").
|
22 |
+
# Option format is: --option-name arg
|
23 |
+
# and shell variable "option_name" gets set to value "arg."
|
24 |
+
# The exception is --help, which takes no arguments, but prints the
|
25 |
+
# $help_message variable (if defined).
|
26 |
+
|
27 |
+
|
28 |
+
###
|
29 |
+
### The --config file options have lower priority to command line
|
30 |
+
### options, so we need to import them first...
|
31 |
+
###
|
32 |
+
|
33 |
+
# Now import all the configs specified by command-line, in left-to-right order
|
34 |
+
# for ((argpos=1; argpos<$#; argpos++)); do
|
35 |
+
# if [ "${!argpos}" == "--config" ]; then
|
36 |
+
# argpos_plus1=$((argpos+1))
|
37 |
+
# config=${!argpos_plus1}
|
38 |
+
# [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
|
39 |
+
# . $config # source the config file.
|
40 |
+
# fi
|
41 |
+
# done
|
42 |
+
|
43 |
+
|
44 |
+
###
|
45 |
+
### No we process the command line options
|
46 |
+
###
|
47 |
+
while true; do
|
48 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
49 |
+
case "$1" in
|
50 |
+
# If the enclosing script is called with --help option, print the help
|
51 |
+
# message and exit. Scripts should put help messages in $help_message
|
52 |
+
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
|
53 |
+
else printf "$help_message\n" 1>&2 ; fi;
|
54 |
+
exit 0 ;;
|
55 |
+
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
|
56 |
+
exit 1 ;;
|
57 |
+
# If the first command-line argument begins with "--" (e.g. --foo-bar),
|
58 |
+
# then work out the variable name as $name, which will equal "foo_bar".
|
59 |
+
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
|
60 |
+
# Next we test whether the variable in question is undefned-- if so it's
|
61 |
+
# an invalid option and we die. Note: $0 evaluates to the name of the
|
62 |
+
# enclosing script.
|
63 |
+
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
|
64 |
+
# is undefined. We then have to wrap this test inside "eval" because
|
65 |
+
# foo_bar is itself inside a variable ($name).
|
66 |
+
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
67 |
+
|
68 |
+
oldval="`eval echo \\$$name`";
|
69 |
+
# Work out whether we seem to be expecting a Boolean argument.
|
70 |
+
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
|
71 |
+
was_bool=true;
|
72 |
+
else
|
73 |
+
was_bool=false;
|
74 |
+
fi
|
75 |
+
|
76 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
77 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
78 |
+
eval $name=\"$2\";
|
79 |
+
|
80 |
+
# Check that Boolean-valued arguments are really Boolean.
|
81 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
82 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
83 |
+
exit 1;
|
84 |
+
fi
|
85 |
+
shift 2;
|
86 |
+
;;
|
87 |
+
*) break;
|
88 |
+
esac
|
89 |
+
done
|
90 |
+
|
91 |
+
|
92 |
+
# Check for an empty argument to the --cmd option, which can easily occur as a
|
93 |
+
# result of scripting errors.
|
94 |
+
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
|
95 |
+
|
96 |
+
|
97 |
+
true; # so this script returns exit code 0.
|
sparktts/utils/token_parser.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TASK_TOKEN_MAP = {
|
2 |
+
"vc": "<|task_vc|>",
|
3 |
+
"tts": "<|task_tts|>",
|
4 |
+
"asr": "<|task_asr|>",
|
5 |
+
"s2s": "<|task_s2s|>",
|
6 |
+
"t2s": "<|task_t2s|>",
|
7 |
+
"understand": "<|task_understand|>",
|
8 |
+
"caption": "<|task_cap|>",
|
9 |
+
"controllable_tts": "<|task_controllable_tts|>",
|
10 |
+
"prompt_tts": "<|task_prompt_tts|>",
|
11 |
+
"speech_edit": "<|task_edit|>",
|
12 |
+
}
|
13 |
+
|
14 |
+
LEVELS_MAP = {
|
15 |
+
"very_low": 0,
|
16 |
+
"low": 1,
|
17 |
+
"moderate": 2,
|
18 |
+
"high": 3,
|
19 |
+
"very_high": 4,
|
20 |
+
}
|
21 |
+
|
22 |
+
LEVELS_MAP_UI = {
|
23 |
+
1: 'very_low',
|
24 |
+
2: 'low',
|
25 |
+
3: 'moderate',
|
26 |
+
4: 'high',
|
27 |
+
5: 'very_high'
|
28 |
+
}
|
29 |
+
|
30 |
+
GENDER_MAP = {
|
31 |
+
"female": 0,
|
32 |
+
"male": 1,
|
33 |
+
}
|
34 |
+
|
35 |
+
AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}
|
36 |
+
|
37 |
+
EMO_MAP = {
|
38 |
+
"UNKNOWN": 0,
|
39 |
+
"NEUTRAL": 1,
|
40 |
+
"ANGRY": 2,
|
41 |
+
"HAPPY": 3,
|
42 |
+
"SAD": 4,
|
43 |
+
"FEARFUL": 5,
|
44 |
+
"DISGUSTED": 6,
|
45 |
+
"SURPRISED": 7,
|
46 |
+
"SARCASTIC": 8,
|
47 |
+
"EXCITED": 9,
|
48 |
+
"SLEEPY": 10,
|
49 |
+
"CONFUSED": 11,
|
50 |
+
"EMPHASIS": 12,
|
51 |
+
"LAUGHING": 13,
|
52 |
+
"SINGING": 14,
|
53 |
+
"WORRIED": 15,
|
54 |
+
"WHISPER": 16,
|
55 |
+
"ANXIOUS": 17,
|
56 |
+
"NO-AGREEMENT": 18,
|
57 |
+
"APOLOGETIC": 19,
|
58 |
+
"CONCERNED": 20,
|
59 |
+
"ENUNCIATED": 21,
|
60 |
+
"ASSERTIVE": 22,
|
61 |
+
"ENCOURAGING": 23,
|
62 |
+
"CONTEMPT": 24,
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
class TokenParser:
|
67 |
+
"""Turn label to special token"""
|
68 |
+
|
69 |
+
def __init__(self):
|
70 |
+
pass
|
71 |
+
|
72 |
+
"""Parse the attributes of a person."""
|
73 |
+
|
74 |
+
def __init__(self):
|
75 |
+
pass
|
76 |
+
|
77 |
+
@staticmethod
|
78 |
+
def age(age: str) -> str:
|
79 |
+
"""Turn age token."""
|
80 |
+
age_id = AGE_MAP[age]
|
81 |
+
return f"<|age_{age_id}|>"
|
82 |
+
|
83 |
+
@staticmethod
|
84 |
+
def gender(gender: str) -> str:
|
85 |
+
"""Turn gender token."""
|
86 |
+
gender_id = GENDER_MAP[gender]
|
87 |
+
return f"<|gender_{gender_id}|>"
|
88 |
+
|
89 |
+
@staticmethod
|
90 |
+
def mel_value(mel: int):
|
91 |
+
"""Turn special token of mel scale pitch."""
|
92 |
+
mel = max(0, int(mel))
|
93 |
+
mel = min(1000, int(mel))
|
94 |
+
return f"<|pitch_value_{mel}|>"
|
95 |
+
|
96 |
+
@staticmethod
|
97 |
+
def mel_level(level: str):
|
98 |
+
"""Turn special token of mel level."""
|
99 |
+
level_tag = LEVELS_MAP[level]
|
100 |
+
return f"<|pitch_label_{level_tag}|>"
|
101 |
+
|
102 |
+
@staticmethod
|
103 |
+
def pitch_var_value(pitch_std: int):
|
104 |
+
"""Turn special token of pitch_std value."""
|
105 |
+
assert isinstance(pitch_std, int)
|
106 |
+
pitch_std = max(0, int(pitch_std))
|
107 |
+
pitch_std = min(10, int(pitch_std))
|
108 |
+
return f"<|pitch_var_value_{pitch_std}|>"
|
109 |
+
|
110 |
+
@staticmethod
|
111 |
+
def pitch_var_level(level: str):
|
112 |
+
"""Turn special token of pitch std level."""
|
113 |
+
level_tag = LEVELS_MAP[level]
|
114 |
+
return f"<|pitch_var_label_{level_tag}|>"
|
115 |
+
|
116 |
+
@staticmethod
|
117 |
+
def loudness_value(loudness: int):
|
118 |
+
"""Turn special toak of loudness value [0, 30]"""
|
119 |
+
assert loudness >= 0
|
120 |
+
loudness = max(0, int(loudness))
|
121 |
+
loudness = min(30, int(loudness))
|
122 |
+
return f"<|loudness_value_{loudness}|>"
|
123 |
+
|
124 |
+
@staticmethod
|
125 |
+
def loudness_level(level: str):
|
126 |
+
"""Turn special token of loudness level."""
|
127 |
+
level_tag = LEVELS_MAP[level]
|
128 |
+
return f"<|loudness_label_{level_tag}|>"
|
129 |
+
|
130 |
+
@staticmethod
|
131 |
+
def speed_value(speed: int):
|
132 |
+
"""Turn special token of speed value."""
|
133 |
+
speed = max(0, int(speed))
|
134 |
+
speed = min(10, int(speed))
|
135 |
+
return f"<|speed_value_{speed}|>"
|
136 |
+
|
137 |
+
@staticmethod
|
138 |
+
def speed_level(level: str):
|
139 |
+
"""Turn special token of speed level."""
|
140 |
+
level_tag = LEVELS_MAP[level]
|
141 |
+
return f"<|speed_label_{level_tag}|>"
|
142 |
+
|
143 |
+
@staticmethod
|
144 |
+
def task(task: str) -> str:
|
145 |
+
"""Turn special token of task."""
|
146 |
+
assert task in TASK_TOKEN_MAP.keys()
|
147 |
+
|
148 |
+
return TASK_TOKEN_MAP[task]
|
149 |
+
|
150 |
+
@staticmethod
|
151 |
+
def emotion(emotion: str):
|
152 |
+
emo_id = EMO_MAP[emotion]
|
153 |
+
|
154 |
+
return f"<|emotion_{emo_id}|>"
|
155 |
+
|
156 |
+
|
157 |
+
# test
|
158 |
+
if __name__ == "__main__":
|
159 |
+
from transformers import AutoTokenizer
|
160 |
+
|
161 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
162 |
+
"/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer"
|
163 |
+
)
|
164 |
+
|
165 |
+
tasks = ["tts", "tts", "understand", "controllable_tts", "prompt_tts"]
|
166 |
+
ages = ["Child", "Teenager", "Youth-Adult", "Middle-aged", "Elderly"]
|
167 |
+
genders = ["female", "female", "female", "male", "male"]
|
168 |
+
mels = [100, 200, 300, 400, 500]
|
169 |
+
mel_levels = ["very_low", "low", "moderate", "high", "very_high"]
|
170 |
+
loudnesses = [1, 10, 23, 19, 30]
|
171 |
+
loudness_levels = ["very_low", "low", "moderate", "high", "very_high"]
|
172 |
+
emotions = ["UNKNOWN", "NEUTRAL", "ANGRY", "HAPPY", "SAD"]
|
173 |
+
|
174 |
+
for i in range(5):
|
175 |
+
task = TokenParser.task(tasks[i])
|
176 |
+
age = TokenParser.age(ages[i])
|
177 |
+
gender = TokenParser.gender(genders[i])
|
178 |
+
mel = TokenParser.mel_value(mels[i])
|
179 |
+
mel_level = TokenParser.mel_level(mel_levels[i])
|
180 |
+
loudness = TokenParser.loudness_value(loudnesses[i])
|
181 |
+
loudness_level = TokenParser.loudness_level(loudness_levels[i])
|
182 |
+
emotion = TokenParser.emotion(emotions[i])
|
183 |
+
inputs = [task, age, gender, mel, mel_level, loudness, loudness_level, emotion]
|
184 |
+
inputs = "".join(inputs)
|
185 |
+
ids = tokenizer.encode(inputs, add_special_tokens=False)
|
186 |
+
print(ids)
|
187 |
+
print("decode", tokenizer.decode(ids))
|
src/demos/trump/trump_en.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:64fe26aa5fb432be85a183b6e48f7e1045c5c9fd4b8eb4faeeb5d4df5934f80f
|
3 |
+
size 476204
|
src/demos/zhongli/zhongli_en.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9b94df88681b9f8afe07d2081bcc354f8e97c83886a944f37b4979b83547d39f
|
3 |
+
size 389804
|
src/demos/余承东/yuchengdong_zh.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:90bcbe608c5c483d34213e18b7778e11ab55998d52bbe1b5f8bdd80f0473e7c2
|
3 |
+
size 496044
|
src/demos/刘德华/dehua_zh.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fe97bd7c679be4dfbbd30496bf9b192c43dfae4a497ae7d99e85841ea06e77f2
|
3 |
+
size 772524
|
src/demos/哪吒/nezha_zh.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e09b01a61e6965f42b2bbfd47a33730991276ad7e6892531ba844646f8c9601e
|
3 |
+
size 596524
|